all layer normalised and ablation on density embedding
birdyLinch committed Sep 22, 2024
1 parent e70cb2c commit ffe3b89
Showing 22 changed files with 1,652 additions and 77 deletions.
99 changes: 50 additions & 49 deletions mace/cli/plot_neighbor_distribution.py
@@ -64,6 +64,8 @@ def main() -> None:
args = tools.build_default_arg_parser().parse_args()
tag = tools.get_tag(name=args.name, seed=args.seed)

log_tag = tag + "_save_model_and_plot"

if args.device == "xpu":
try:
import intel_extension_for_pytorch as ipex
@@ -92,7 +94,7 @@ def main() -> None:

# Setup
tools.set_seeds(args.seed)
tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank)
tools.setup_logger(level=args.log_level, tag=log_tag, directory=args.log_dir, rank=rank)

if args.distributed:
torch.cuda.set_device(local_rank)
@@ -351,7 +353,7 @@ def main() -> None:
if 'avg_num_neighbors' in head_args and head_args.avg_num_neighbors > 0:
head_args.compute_avg_num_neighbors = False

if head_args.get("plot_neighbor_distribution", True):
if head_args.get("plot_neighbor_distribution", False):
if args.distributed:
plot_avg = False
plot_bar = True
@@ -700,9 +702,7 @@ def plot_species_neighbors(raw, element_symbols):
distance_transform=args.distance_transform,
correlation=args.correlation,
gate=modules.gate_dict[args.gate],
interaction_cls_first=modules.interaction_classes[
"RealAgnosticInteractionBlock"
],
interaction_cls_first=modules.interaction_classes[args.interaction_first],
MLP_irreps=o3.Irreps(args.MLP_irreps),
atomic_inter_scale=[v.std for v in args.heads.values()],
atomic_inter_shift=[0.0 for v in args.heads.values()],
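The first interaction block is no longer hard-coded to `RealAgnosticInteractionBlock`; it is looked up by name via `args.interaction_first`, so the ablation variants registered in `mace/modules/__init__.py` can be swapped in without code changes. A minimal sketch of that lookup, using one of the newly registered block names as an illustrative value:

```python
# Illustrative sketch, not part of the commit: resolve the first interaction
# block class from the registry by name, as the change above now does via
# args.interaction_first.
from mace import modules

block_name = "RealAgnosticDensityInjuctedNoScaleInteractionBlock"  # example registry key
interaction_cls_first = modules.interaction_classes[block_name]
print(interaction_cls_first.__name__)
```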
@@ -861,7 +861,7 @@ def plot_species_neighbors(raw, element_symbols):
if args.device == "xpu":
logging.info("Optimzing model and optimzier for XPU")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + "_train")
logger = tools.MetricsLogger(directory=args.results_dir, tag=log_tag + "_train")

lr_scheduler = LRScheduler(optimizer, args)

@@ -940,6 +940,7 @@ def plot_species_neighbors(raw, element_symbols):
device=device,
)
except Exception: # pylint: disable=W0703
print(model)
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
@@ -1079,50 +1080,50 @@ def plot_species_neighbors(raw, element_symbols):
# )
# all_data_loaders[test_name] = test_loader

#for swa_eval in swas:
# epoch = checkpoint_handler.load_latest(
# state=tools.CheckpointState(model, optimizer, lr_scheduler),
# swa=swa_eval,
# device=device,
# )
# model.to(device)
# if args.distributed:
# distributed_model = DDP(model, device_ids=[local_rank])
# model_to_evaluate = model if not args.distributed else distributed_model
# logging.info(f"Loaded model from epoch {epoch}")

# for param in model.parameters():
# param.requires_grad = False
# table = create_error_table(
# table_type=args.error_table,
# all_data_loaders=all_data_loaders,
# model=model_to_evaluate,
# loss_fn=loss_fn,
# output_args=output_args,
# log_wandb=args.wandb,
# device=device,
# distributed=args.distributed,
# )
# logging.info("\n" + str(table))

# if rank == 0:
# # Save entire model
# if swa_eval:
# model_path = Path(args.checkpoints_dir) / (tag + "_swa.model")
# else:
# model_path = Path(args.checkpoints_dir) / (tag + ".model")
# logging.info(f"Saving model to {model_path}")
# if args.save_cpu:
# model = model.to("cpu")
# torch.save(model, model_path)

# if swa_eval:
# torch.save(model, Path(args.model_dir) / (args.name + "_swa.model"))
# else:
# torch.save(model, Path(args.model_dir) / (args.name + ".model"))
for swa_eval in swas:
epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=swa_eval,
device=device,
)
model.to(device)
if args.distributed:
distributed_model = DDP(model, device_ids=[local_rank])
model_to_evaluate = model if not args.distributed else distributed_model
logging.info(f"Loaded model from epoch {epoch}")

#for param in model.parameters():
# param.requires_grad = False
#table = create_error_table(
# table_type=args.error_table,
# all_data_loaders=all_data_loaders,
# model=model_to_evaluate,
# loss_fn=loss_fn,
# output_args=output_args,
# log_wandb=args.wandb,
# device=device,
# distributed=args.distributed,
#)
#logging.info("\n" + str(table))

# if args.distributed:
# torch.distributed.barrier()
if rank == 0:
# Save entire model
if swa_eval:
model_path = Path(args.checkpoints_dir) / (tag + "_swa.model")
else:
model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}")
if args.save_cpu:
model = model.to("cpu")
torch.save(model, model_path)

if swa_eval:
torch.save(model, Path(args.model_dir) / (args.name + "_swa.model"))
else:
torch.save(model, Path(args.model_dir) / (args.name + ".model"))

if args.distributed:
torch.distributed.barrier()

logging.info("Done")
if args.distributed:
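The previously commented-out evaluation-and-save loop is now active: for each SWA setting the latest checkpoint is loaded and the full model object is written with `torch.save` to the checkpoints and model directories. A sketch of reloading such a file for inference, with an illustrative path (the real one depends on `--checkpoints_dir` and the run tag):

```python
# Illustrative sketch, not part of the commit: reload a model saved by the
# block above. The whole nn.Module was pickled, so the mace package must be
# importable at load time.
import torch

model = torch.load("checkpoints/run_name_swa.model",  # assumed path
                   map_location="cpu", weights_only=False)
model.eval()
```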
10 changes: 7 additions & 3 deletions mace/cli/run_train.py
@@ -721,6 +721,8 @@ def main() -> None:
else:
no_decay_interactions[name] = param

#import ipdb; ipdb.set_trace()

param_options = dict(
params=[
{
@@ -841,6 +843,11 @@ def main() -> None:
)

start_epoch = 0

logging.info(model)
logging.info(f"Number of parameters: {tools.count_parameters(model)}")
logging.info(f"Optimizer: {optimizer}")

if args.restart_latest:
try:
opt_start_epoch = checkpoint_handler.load_latest(
@@ -864,9 +871,6 @@
for group in optimizer.param_groups:
group["lr"] = args.lr

logging.info(model)
logging.info(f"Number of parameters: {tools.count_parameters(model)}")
logging.info(f"Optimizer: {optimizer}")

if args.wandb:
logging.info("Using Weights and Biases for logging")
14 changes: 14 additions & 0 deletions mace/modules/__init__.py
@@ -20,9 +20,17 @@
RealAgnosticInteractionGateBlock,
RealAgnosticResidualInteractionGateBlock,
RealAgnosticDensityNormalizedInteractionGateBlock,
RealAgnosticDensityNormalizedNoScaleInteractionGateBlock,
RealAgnosticDensityInjuctedInteractionGateBlock,
RealAgnosticDensityInjuctedNoScaleInteractionGateBlock,
RealAgnosticDensityInjuctedNoScaleNoBiasInteractionGateBlock,
RealAgnosticDensityInjuctedNodeAttrAttendInteractionGateBlock,
RealAgnosticDensityInjuctedNodeAttrAttendResidualInteractionGateBlock,
RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock,
RealAgnosticDensityInjuctedNoScaleResidualInteractionGateBlock,
RealAgnosticDensityInjuctUnnormalizedNoScaleInteractionGateBlock,


#RealAgnosticDensityNormalizedResidualInteractionGateBlock,
ResidualElementDependentInteractionBlock,
ScaleShiftBlock,
@@ -70,7 +78,13 @@
"RealAgnosticDensityNormalizedInteractionBlock": RealAgnosticDensityNormalizedInteractionGateBlock,
"RealAgnosticDensityInjuctedInteractionBlock": RealAgnosticDensityInjuctedInteractionGateBlock,
"RealAgnosticDensityInjuctedNoScaleInteractionBlock": RealAgnosticDensityInjuctedNoScaleInteractionGateBlock,
"RealAgnosticDensityInjuctedNoScaleNoBiasInteractionBlock": RealAgnosticDensityInjuctedNoScaleNoBiasInteractionGateBlock,
"RealAgnosticDensityInjuctedNodeAttrAttendInteractionBlock": RealAgnosticDensityInjuctedNodeAttrAttendInteractionGateBlock,
"RealAgnosticDensityInjuctedNodeAttrAttendResidualInteractionBlock": RealAgnosticDensityInjuctedNodeAttrAttendResidualInteractionGateBlock,
"RealAgnosticDensityNormalizedNoScaleInteractionBlock": RealAgnosticDensityNormalizedNoScaleInteractionGateBlock,
"RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionBlock": RealAgnosticDensityInjuctedNoScaleNoBiasResidualInteractionGateBlock,
"RealAgnosticDensityInjuctedNoScaleResidualInteractionBlock": RealAgnosticDensityInjuctedNoScaleResidualInteractionGateBlock,
"RealAgnosticDensityInjuctUnnormalizedNoScaleInteractionBlock": RealAgnosticDensityInjuctUnnormalizedNoScaleInteractionGateBlock,
}
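Registering the new density-embedding ablation blocks in `interaction_classes` is what makes them selectable by name (for example through `args.interaction_first`, as in the change above). A small sketch, assuming an installed `mace` package, that lists every registered block name:

```python
# Sketch, not part of the commit: enumerate the registered interaction blocks,
# including the newly added density-injected/normalized ablation variants.
from mace import modules

for name in sorted(modules.interaction_classes):
    print(name)
```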

scaling_classes: Dict[str, Callable] = {
[remaining file diffs not loaded]
