From 93435475bbdcb5ba739117a1fabaab3e1793c7b8 Mon Sep 17 00:00:00 2001
From: jharrymoore
Date: Fri, 8 Mar 2024 12:47:38 +0000
Subject: [PATCH 001/101] small changes for intel GPUs

---
 mace/cli/run_train.py     | 11 ++++++++++-
 mace/tools/arg_parser.py  |  2 +-
 mace/tools/torch_tools.py |  4 ++++
 3 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index c63f1c4d..00c62d39 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -29,6 +29,11 @@ def main() -> None:
     args = tools.build_default_arg_parser().parse_args()
     tag = tools.get_tag(name=args.name, seed=args.seed)
+    if args.device == "xpu":
+        try:
+            import intel_extension_for_pytorch as ipex
+        except ImportError:
+            raise ImportError("Error: Intel extension for PyTorch not found, but XPU device was specified")
     # Setup
     tools.set_seeds(args.seed)
@@ -333,6 +338,7 @@ def main() -> None:
     else:
         raise RuntimeError(f"Unknown model: '{args.model}'")
+    model.to(device)
 
     # Optimizer
@@ -381,9 +387,12 @@ def main() -> None:
         optimizer = torch.optim.AdamW(**param_options)
     else:
         optimizer = torch.optim.Adam(**param_options)
-
+    if args.device == "xpu":
+        logging.info("Optimizing model and optimizer for XPU")
+        model, optimizer = ipex.optimize(model, optimizer=optimizer)
     logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + "_train")
+
     lr_scheduler = LRScheduler(optimizer, args)
 
     swa: Optional[tools.SWAContainer] = None
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index e16be03f..19175a6c 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -41,7 +41,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         "--device",
         help="select device",
         type=str,
-        choices=["cpu", "cuda", "mps"],
+        choices=["cpu", "cuda", "mps", "xpu"],
         default="cpu",
     )
     parser.add_argument(
diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py
index 349f1e3b..41926222 100644
--- a/mace/tools/torch_tools.py
+++ b/mace/tools/torch_tools.py
@@ -63,6 +63,10 @@ def init_device(device_str: str) -> torch.device:
         assert torch.backends.mps.is_available(), "No MPS backend is available!"
logging.info("Using MPS GPU acceleration") return torch.device("mps") + if device_str == "xpu": + torch.xpu.is_available() + return torch.device("xpu") + logging.info("Using CPU") return torch.device("cpu") From 054952b55e59d755e9ed3b5b8aa29741deb5c297 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:38:51 +0000 Subject: [PATCH 002/101] add theory --- mace/cli/run_train.py | 10 +++++++--- mace/data/atomic_data.py | 15 ++++++++++++++- mace/data/utils.py | 14 +++++++++++++- mace/modules/blocks.py | 14 ++++++++------ mace/modules/models.py | 3 ++- 5 files changed, 44 insertions(+), 12 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 2401ba6b..9ebc40ba 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -91,7 +91,7 @@ def main() -> None: ".xyz" ), "valid_file if given must be same format as train_file" config_type_weights = get_config_type_weights(args.config_type_weights) - collections, atomic_energies_dict = get_dataset_from_xyz( + collections, atomic_energies_dict, theories = get_dataset_from_xyz( train_path=args.train_file, valid_path=args.valid_file, valid_fraction=args.valid_fraction, @@ -192,11 +192,15 @@ def main() -> None: if args.train_file.endswith(".xyz"): train_set = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, theories=theories + ) for config in collections.train ] valid_set = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, theories=theories + ) for config in collections.valid ] elif args.train_file.endswith(".h5"): diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index edb91b14..e4e6b46a 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -52,6 +52,7 @@ def __init__( unit_shifts: torch.Tensor, # [n_edges, 3] cell: Optional[torch.Tensor], # [3,3] weight: Optional[torch.Tensor], # [,] + theory: Optional[torch.Tensor], # [,] energy_weight: Optional[torch.Tensor], # [,] forces_weight: Optional[torch.Tensor], # [,] stress_weight: Optional[torch.Tensor], # [,] @@ -72,6 +73,7 @@ def __init__( assert unit_shifts.shape[1] == 3 assert len(node_attrs.shape) == 2 assert weight is None or len(weight.shape) == 0 + assert theory is None or len(theory.shape) == 0 assert energy_weight is None or len(energy_weight.shape) == 0 assert forces_weight is None or len(forces_weight.shape) == 0 assert stress_weight is None or len(stress_weight.shape) == 0 @@ -93,6 +95,7 @@ def __init__( "cell": cell, "node_attrs": node_attrs, "weight": weight, + "theory": theory, "energy_weight": energy_weight, "forces_weight": forces_weight, "stress_weight": stress_weight, @@ -108,7 +111,11 @@ def __init__( @classmethod def from_config( - cls, config: Configuration, z_table: AtomicNumberTable, cutoff: float + cls, + config: Configuration, + z_table: AtomicNumberTable, + cutoff: float, + theories: set, ) -> "AtomicData": edge_index, shifts, unit_shifts = get_neighborhood( positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell @@ -119,6 +126,11 @@ def from_config( num_classes=len(z_table), ) + theory = torch.nn.functional.one_hot( + torch.tensor([theories.index(config.theory)], dtype=torch.long), + num_classes=len(theories), + ).squeeze(0) + cell = ( torch.tensor(config.cell, dtype=torch.get_default_dtype()) if config.cell is not None @@ -200,6 
+212,7 @@ def from_config( cell=cell, node_attrs=one_hot, weight=weight, + theory=theory, energy_weight=energy_weight, forces_weight=forces_weight, stress_weight=stress_weight, diff --git a/mace/data/utils.py b/mace/data/utils.py index 74363ff8..481de2bb 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -47,6 +47,7 @@ class Configuration: stress_weight: float = 1.0 # weight of config stress in loss virials_weight: float = 1.0 # weight of config virial in loss config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config + theory: Optional[str] = "Default" # theory used to compute the config Configurations = List[Configuration] @@ -78,6 +79,7 @@ def config_from_atoms_list( virials_key="virials", dipole_key="dipole", charges_key="charges", + theory_key="theory", config_type_weights: Dict[str, float] = None, ) -> Configurations: """Convert list of ase.Atoms into Configurations""" @@ -95,6 +97,7 @@ def config_from_atoms_list( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + theory_key=theory_key, config_type_weights=config_type_weights, ) ) @@ -109,6 +112,7 @@ def config_from_atoms( virials_key="virials", dipole_key="dipole", charges_key="charges", + theory_key="theory", config_type_weights: Dict[str, float] = None, ) -> Configuration: """Convert ase.Atoms to Configuration""" @@ -136,6 +140,8 @@ def config_from_atoms( stress_weight = atoms.info.get("config_stress_weight", 1.0) virials_weight = atoms.info.get("config_virials_weight", 1.0) + theory = atoms.info.get(theory_key, "Default") + # fill in missing quantities but set their weight to 0.0 if energy is None: energy = 0.0 @@ -163,6 +169,7 @@ def config_from_atoms( dipole=dipole, charges=charges, weight=weight, + theory=theory, energy_weight=energy_weight, forces_weight=forces_weight, stress_weight=stress_weight, @@ -198,6 +205,7 @@ def load_from_xyz( virials_key: str = "virials", dipole_key: str = "dipole", charges_key: str = "charges", + theory_key: str = "theory", extract_atomic_energies: bool = False, ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") @@ -232,6 +240,9 @@ def load_from_xyz( logging.info("Using isolated atom energies from training file") atoms_list = atoms_without_iso_atoms + theories = set() + for atoms in atoms_list: + theories.add(atoms.info.get(theory_key, "Default")) configs = config_from_atoms_list( atoms_list, @@ -242,8 +253,9 @@ def load_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + theory_key=theory_key, ) - return atomic_energies_dict, configs + return atomic_energies_dict, configs, theories def compute_average_E0s( diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index fa0c1085..b95c893c 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -44,9 +44,9 @@ def forward( @compile_mode("script") class LinearReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps): + def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")): super().__init__() - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=o3.Irreps("0e")) + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] return self.linear(x) # [n_nodes, 1] @@ -55,15 +55,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [... 
@compile_mode("script") class NonLinearReadoutBlock(torch.nn.Module): def __init__( - self, irreps_in: o3.Irreps, MLP_irreps: o3.Irreps, gate: Optional[Callable] + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Optional[Callable], + irrep_out: o3.Irreps = o3.Irreps("0e"), ): super().__init__() self.hidden_irreps = MLP_irreps self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) - self.linear_2 = o3.Linear( - irreps_in=self.hidden_irreps, irreps_out=o3.Irreps("0e") - ) + self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] x = self.non_linearity(self.linear_1(x)) diff --git a/mace/modules/models.py b/mace/modules/models.py index 776b33a2..229e9fe3 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -61,6 +61,7 @@ def __init__( distance_transform: str = "None", radial_MLP: Optional[List[int]] = None, radial_type: Optional[str] = "bessel", + theories: Optional[List[str]] = ["Default"], ): super().__init__() self.register_buffer( @@ -161,7 +162,7 @@ def __init__( self.products.append(prod) if i == num_interactions - 2: self.readouts.append( - NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) + NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate, ) ) else: self.readouts.append(LinearReadoutBlock(hidden_irreps)) From 3bd561182c7b0df8e54084bfe6ac0845e8b5a878 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 18 Mar 2024 20:40:53 +0000 Subject: [PATCH 003/101] fix different level of theory --- mace/data/atomic_data.py | 8 +++----- mace/data/utils.py | 2 +- mace/modules/models.py | 21 +++++++++++++++++---- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index e4e6b46a..7ed172ba 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -115,7 +115,7 @@ def from_config( config: Configuration, z_table: AtomicNumberTable, cutoff: float, - theories: set, + theories: Optional[list] = ["Default"], ) -> "AtomicData": edge_index, shifts, unit_shifts = get_neighborhood( positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell @@ -126,10 +126,8 @@ def from_config( num_classes=len(z_table), ) - theory = torch.nn.functional.one_hot( - torch.tensor([theories.index(config.theory)], dtype=torch.long), - num_classes=len(theories), - ).squeeze(0) + theory = torch.tensor(theories.index(config.theory), dtype=torch.long) + print("theory", theory) cell = ( torch.tensor(config.cell, dtype=torch.get_default_dtype()) diff --git a/mace/data/utils.py b/mace/data/utils.py index 481de2bb..09201339 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -243,7 +243,7 @@ def load_from_xyz( theories = set() for atoms in atoms_list: theories.add(atoms.info.get(theory_key, "Default")) - + theories = list(theories) configs = config_from_atoms_list( atoms_list, config_type_weights=config_type_weights, diff --git a/mace/modules/models.py b/mace/modules/models.py index 229e9fe3..cc981dc9 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -162,10 +162,17 @@ def __init__( self.products.append(prod) if i == num_interactions - 2: self.readouts.append( - NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate, ) + NonLinearReadoutBlock( + hidden_irreps_out, + (len(theories) * o3.Irreps("16x0e")).simplify(), + gate, + 
o3.Irreps(f"{len(theories)}x0e"), + ) ) else: - self.readouts.append(LinearReadoutBlock(hidden_irreps)) + self.readouts.append( + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(theories)}x0e")) + ) def forward( self, @@ -179,6 +186,7 @@ def forward( # Setup data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) + print("theory", data["theory"]) num_graphs = data["ptr"].numel() - 1 displacement = torch.zeros( (num_graphs, 3, 3), @@ -246,9 +254,12 @@ def forward( node_attrs=data["node_attrs"], ) node_feats_list.append(node_feats) - node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + node_energies = readout(node_feats) # [n_nodes, ] energy = scatter_sum( - src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs + src=node_energies[:, data["theory"]], + index=data["batch"], + dim=-1, + dim_size=num_graphs, ) # [n_graphs,] energies.append(energy) node_energies_list.append(node_energies) @@ -311,6 +322,7 @@ def forward( # Setup data["positions"].requires_grad_(True) data["node_attrs"].requires_grad_(True) + print("theory", data["theory"]) num_graphs = data["ptr"].numel() - 1 displacement = torch.zeros( (num_graphs, 3, 3), @@ -372,6 +384,7 @@ def forward( ) node_feats_list.append(node_feats) node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } + # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) # print("node_es_list", node_es_list) From 0cc5dd44fdb528e2a877420097fff205bb116ab4 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 22 Mar 2024 12:10:43 +0000 Subject: [PATCH 004/101] add E0 multi reference --- mace/modules/blocks.py | 8 +++++--- mace/modules/models.py | 33 ++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index b95c893c..982fe1cc 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -132,12 +132,14 @@ class AtomicEnergiesBlock(torch.nn.Module): def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): super().__init__() - assert len(atomic_energies.shape) == 1 + # assert len(atomic_energies.shape) == 1 self.register_buffer( "atomic_energies", - torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), - ) # [n_elements, ] + torch.atleast_2d( + torch.tensor(atomic_energies, dtype=torch.get_default_dtype()) + ), + ) # [n_elements, n_theories] def forward( self, x: torch.Tensor # one-hot of elements [..., n_elements] diff --git a/mace/modules/models.py b/mace/modules/models.py index cc981dc9..f4f12da2 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -132,7 +132,9 @@ def __init__( self.products = torch.nn.ModuleList([prod]) self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearReadoutBlock(hidden_irreps)) + self.readouts.append( + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(theories)}x0e")) + ) for i in range(num_interactions - 1): if i == num_interactions - 2: @@ -209,9 +211,9 @@ def forward( # Atomic energies node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + e0 = scatter_sum(src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs)[ + torch.arange(num_graphs, data["theory"]) + ] # [n_graphs, n_theories] # Embeddings node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = get_edge_vectors_and_lengths( @@ -254,11 +256,13 @@ def forward( 
node_attrs=data["node_attrs"], ) node_feats_list.append(node_feats) - node_energies = readout(node_feats) # [n_nodes, ] + node_energies = readout(node_feats)[ + torch.arange(node_feats.shape[0]), data["theory"][data["batch"]] + ] # [n_nodes, len(theories)] energy = scatter_sum( - src=node_energies[:, data["theory"]], + src=node_energies, index=data["batch"], - dim=-1, + dim=0, dim_size=num_graphs, ) # [n_graphs,] energies.append(energy) @@ -322,7 +326,6 @@ def forward( # Setup data["positions"].requires_grad_(True) data["node_attrs"].requires_grad_(True) - print("theory", data["theory"]) num_graphs = data["ptr"].numel() - 1 displacement = torch.zeros( (num_graphs, 3, 3), @@ -345,9 +348,9 @@ def forward( # Atomic energies node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + e0 = scatter_sum(src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs)[ + torch.arange(num_graphs), data["theory"] + ] # [n_graphs, num_theories] # Embeddings node_feats = self.node_embedding(data["node_attrs"]) @@ -383,7 +386,11 @@ def forward( node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] ) node_feats_list.append(node_feats) - node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } + node_es_list.append( + readout(node_feats)[ + torch.arange(node_feats.shape[0]), data["theory"][data["batch"]] + ].squeeze(-1) + ) # {[n_nodes, ], } # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) @@ -507,7 +514,7 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: node_e0 = self.atomic_energies_fn(data.node_attrs) e0 = scatter_sum( src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] + ) # [n_graphs, n_theories] # Embeddings node_feats = self.node_embedding(data.node_attrs) From 73354e7323b89d967fb1732f5ca6cb41d2e8cd80 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 23 Mar 2024 12:22:38 +0000 Subject: [PATCH 005/101] fix multi reference E0 and model --- mace/calculators/mace.py | 4 ++- mace/cli/run_train.py | 38 +++++--------------- mace/data/utils.py | 54 +++++++++++++++++------------ mace/modules/blocks.py | 23 ++++++++----- mace/modules/models.py | 57 +++++++++++++++++++----------- mace/modules/utils.py | 69 +++++++++++++++++++++++++++---------- mace/tools/scripts_utils.py | 25 +++++++++++--- tests/test_models.py | 53 ++++++++++++++++++++++++++++ tests/test_modules.py | 2 +- 9 files changed, 220 insertions(+), 105 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 843883e4..68922b4d 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -203,7 +203,9 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if self.model_type in ["MACE", "EnergyDipoleMACE"]: batch = next(iter(data_loader)).to(self.device) - node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"]) + node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + torch.arange(batch.num_nodes), batch.theory[batch.batch] + ] compute_stress = True else: compute_stress = False diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 9ebc40ba..1e37304c 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -31,6 +31,7 @@ get_config_type_weights, get_dataset_from_xyz, get_files_with_suffix, + dict_to_array, ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import 
load_foundations @@ -137,10 +138,12 @@ def main() -> None: if atomic_energies_dict is None or len(atomic_energies_dict) == 0: if args.train_file.endswith(".xyz"): atomic_energies_dict = get_atomic_energies( - args.E0s, collections.train, z_table + args.E0s, collections.train, z_table, theories ) else: - atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table) + atomic_energies_dict = get_atomic_energies( + args.E0s, None, z_table, theories + ) if args.model == "AtomicDipolesMACE": atomic_energies = None @@ -161,33 +164,10 @@ def main() -> None: else: compute_energy = True compute_dipole = False - if atomic_energies_dict is None or len(atomic_energies_dict) == 0: - if args.E0s is not None: - logging.info( - "Atomic Energies not in training file, using command line argument E0s" - ) - if args.E0s.lower() == "average": - logging.info( - "Computing average Atomic Energies using least squares regression" - ) - atomic_energies_dict = data.compute_average_E0s( - collections.train, z_table - ) - else: - try: - atomic_energies_dict = ast.literal_eval(args.E0s) - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError( - f"E0s specified invalidly, error {e} occurred" - ) from e - else: - raise RuntimeError( - "E0s not found in training file and not specified in command line" - ) - atomic_energies: np.ndarray = np.array( - [atomic_energies_dict[z] for z in z_table.zs] - ) + # atomic_energies: np.ndarray = np.array( + # [atomic_energies_dict[z] for z in z_table.zs] + # ) + atomic_energies = dict_to_array(atomic_energies_dict) logging.info(f"Atomic energies: {atomic_energies.tolist()}") if args.train_file.endswith(".xyz"): diff --git a/mace/data/utils.py b/mace/data/utils.py index 09201339..2ccdc136 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -207,6 +207,7 @@ def load_from_xyz( charges_key: str = "charges", theory_key: str = "theory", extract_atomic_energies: bool = False, + keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") @@ -222,7 +223,10 @@ def load_from_xyz( isolated_atom_config = atoms.info.get("config_type") == "IsolatedAtom" if isolated_atom_config: if energy_key in atoms.info.keys(): - atomic_energies_dict[atoms.get_atomic_numbers()[0]] = ( + theory = atoms.info.get(theory_key, "Default") + if theory not in atomic_energies_dict: + atomic_energies_dict[theory] = {} + atomic_energies_dict[theory][atoms.get_atomic_numbers()[0]] = ( atoms.info[energy_key] ) else: @@ -230,16 +234,19 @@ def load_from_xyz( f"Configuration '{idx}' is marked as 'IsolatedAtom' " "but does not contain an energy. Zero energy will be used." 
) - atomic_energies_dict[atoms.get_atomic_numbers()[0]] = np.zeros( - 1 + theory = atoms.info.get(theory_key, "Default") + if theory not in atomic_energies_dict: + atomic_energies_dict[theory] = {} + atomic_energies_dict[theory][atoms.get_atomic_numbers()[0]] = ( + np.zeros(1) ) else: atoms_without_iso_atoms.append(atoms) if len(atomic_energies_dict) > 0: logging.info("Using isolated atom energies from training file") - - atoms_list = atoms_without_iso_atoms + if not keep_isolated_atoms: + atoms_list = atoms_without_iso_atoms theories = set() for atoms in atoms_list: theories.add(atoms.info.get(theory_key, "Default")) @@ -259,7 +266,7 @@ def load_from_xyz( def compute_average_E0s( - collections_train: Configurations, z_table: AtomicNumberTable + collections_train: Configurations, z_table: AtomicNumberTable, theories: List[str] ) -> Dict[int, float]: """ Function to compute the average interaction energy of each chemical element @@ -269,22 +276,25 @@ def compute_average_E0s( len_zs = len(z_table) A = np.zeros((len_train, len_zs)) B = np.zeros(len_train) - for i in range(len_train): - B[i] = collections_train[i].energy - for j, z in enumerate(z_table.zs): - A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) - try: - E0s = np.linalg.lstsq(A, B, rcond=None)[0] - atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): - atomic_energies_dict[z] = E0s[i] - except np.linalg.LinAlgError: - logging.warning( - "Failed to compute E0s using least squares regression, using the same for all atoms" - ) - atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): - atomic_energies_dict[z] = 0.0 + for theory in theories: + for i in range(len_train): + if collections_train[i].theory != theory: + continue + B[i] = collections_train[i].energy + for j, z in enumerate(z_table.zs): + A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) + try: + E0s = np.linalg.lstsq(A, B, rcond=None)[0] + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[theory][z] = E0s[i] + except np.linalg.LinAlgError: + logging.warning( + "Failed to compute E0s using least squares regression, using the same for all atoms" + ) + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[theory][z] = 0.0 return atomic_energies_dict diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 982fe1cc..94216276 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -136,15 +136,16 @@ def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): self.register_buffer( "atomic_energies", - torch.atleast_2d( - torch.tensor(atomic_energies, dtype=torch.get_default_dtype()) - ), + torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), ) # [n_elements, n_theories] def forward( self, x: torch.Tensor # one-hot of elements [..., n_elements] ) -> torch.Tensor: # [..., ] - return torch.matmul(x, self.atomic_energies) + print("self.atomic_energies.T", torch.atleast_2d(self.atomic_energies).T) + print("self.atomic_energies.T", self.atomic_energies.T.shape) + print("x", x.shape) + return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) def __repr__(self): formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) @@ -742,14 +743,20 @@ class ScaleShiftBlock(torch.nn.Module): def __init__(self, scale: float, shift: float): super().__init__() self.register_buffer( - "scale", torch.tensor(scale, dtype=torch.get_default_dtype()) + "scale", + torch.atleast_1d(torch.tensor(scale, 
dtype=torch.get_default_dtype())), ) self.register_buffer( - "shift", torch.tensor(shift, dtype=torch.get_default_dtype()) + "shift", + torch.atleast_1d(torch.tensor(shift, dtype=torch.get_default_dtype())), ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.scale * x + self.shift + def forward(self, x: torch.Tensor, theory: torch.Tensor) -> torch.Tensor: + print("theory", theory.shape) + print("x", x.shape) + print("self.scale", self.scale.shape) + print("self.shift", self.shift.shape) + return self.scale[theory] * x + self.shift[theory] def __repr__(self): return ( diff --git a/mace/modules/models.py b/mace/modules/models.py index f4f12da2..a4d12780 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -189,6 +189,7 @@ def forward( data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) print("theory", data["theory"]) + num_atoms_arange = torch.arange(data["positions"].shape[0]) num_graphs = data["ptr"].numel() - 1 displacement = torch.zeros( (num_graphs, 3, 3), @@ -210,10 +211,13 @@ def forward( ) # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum(src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs)[ - torch.arange(num_graphs, data["theory"]) - ] # [n_graphs, n_theories] + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["theory"][data["batch"]] + ] + print("node e0", node_e0.shape) + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, n_theories] # Embeddings node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = get_edge_vectors_and_lengths( @@ -237,6 +241,8 @@ def forward( pair_energy = torch.zeros_like(e0) # Interactions + print("pair_energy", pair_energy) + print("pair_node_energy", pair_node_energy) energies = [e0, pair_energy] node_energies_list = [node_e0, pair_node_energy] node_feats_list = [] @@ -257,7 +263,7 @@ def forward( ) node_feats_list.append(node_feats) node_energies = readout(node_feats)[ - torch.arange(node_feats.shape[0]), data["theory"][data["batch"]] + num_atoms_arange, data["theory"][data["batch"]] ] # [n_nodes, len(theories)] energy = scatter_sum( src=node_energies, @@ -267,7 +273,7 @@ def forward( ) # [n_graphs,] energies.append(energy) node_energies_list.append(node_energies) - + print("node_energies", node_energies) # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) @@ -327,6 +333,7 @@ def forward( data["positions"].requires_grad_(True) data["node_attrs"].requires_grad_(True) num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -347,10 +354,12 @@ def forward( ) # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) - e0 = scatter_sum(src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs)[ - torch.arange(num_graphs), data["theory"] - ] # [n_graphs, num_theories] + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["theory"][data["batch"]] + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, num_theories] # Embeddings node_feats = self.node_embedding(data["node_attrs"]) @@ -387,19 +396,18 @@ def forward( ) node_feats_list.append(node_feats) node_es_list.append( - readout(node_feats)[ - torch.arange(node_feats.shape[0]), data["theory"][data["batch"]] - ].squeeze(-1) + readout(node_feats)[num_atoms_arange, 
data["theory"][data["batch"]]] ) # {[n_nodes, ], } # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) - # print("node_es_list", node_es_list) + print("node_es_list", node_es_list) # Sum over interactions node_inter_es = torch.sum( torch.stack(node_es_list, dim=0), dim=0 ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) + print("node_inter_es", node_inter_es.shape) + node_inter_es = self.scale_shift(node_inter_es, data["theory"][data["batch"]]) # Sum over nodes in graph inter_e = scatter_sum( @@ -409,6 +417,7 @@ def forward( # Add E_0 and (scaled) interaction energy total_energy = e0 + inter_e node_energy = node_e0 + node_inter_es + print("node_energy", node_energy.shape) forces, virials, stress = get_outputs( energy=inter_e, positions=data["positions"], @@ -509,9 +518,12 @@ def __init__( def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: # Setup data.positions.requires_grad = True + num_atoms_arange = torch.arange(data.positions.shape[0]) # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["theory"][data["batch"]] + ] e0 = scatter_sum( src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs ) # [n_graphs, n_theories] @@ -572,9 +584,11 @@ def __init__( def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: # Setup data.positions.requires_grad = True - + num_atoms_arange = torch.arange(data.positions.shape[0]) # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["theory"][data["batch"]] + ] e0 = scatter_sum( src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs ) # [n_graphs,] @@ -606,7 +620,7 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: node_inter_es = torch.sum( torch.stack(node_es_list, dim=0), dim=0 ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) + node_inter_es = self.scale_shift(node_inter_es, data["theory"][data["batch"]]) # Sum over nodes in graph inter_e = scatter_sum( @@ -966,6 +980,7 @@ def forward( data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -986,7 +1001,9 @@ def forward( ) # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["theory"][data["batch"]] + ] e0 = scatter_sum( src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs ) # [n_graphs,] diff --git a/mace/modules/utils.py b/mace/modules/utils.py index caac465d..07cf6d7b 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -14,7 +14,7 @@ from scipy.constants import c, e from mace.tools import to_numpy -from mace.tools.scatter import scatter_sum +from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum from mace.tools.torch_geometric.batch import Batch from .blocks import AtomicEnergiesBlock @@ -192,20 +192,25 @@ def compute_mean_std_atomic_inter_energy( atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) avg_atom_inter_es_list = [] + theory_list = [] for batch in data_loader: node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + 
src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), batch.theory] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] avg_atom_inter_es_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } + theory_list.append(batch.theory) avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] - mean = to_numpy(torch.mean(avg_atom_inter_es)).item() - std = to_numpy(torch.std(avg_atom_inter_es)).item() + theory = torch.cat(theory_list, dim=0) # [total_n_graphs] + # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() + # std = to_numpy(torch.std(avg_atom_inter_es)).item() + mean = scatter_mean(src=avg_atom_inter_es, index=theory, dim=0).squeeze(-1) + std = scatter_std(src=avg_atom_inter_es, index=theory, dim=0).squeeze(-1) std = _check_non_zero(std) return mean, std @@ -215,10 +220,11 @@ def _compute_mean_std_atomic_inter_energy( batch: Batch, atomic_energies_fn: AtomicEnergiesBlock, ) -> Tuple[torch.Tensor, torch.Tensor]: + theory = batch.theory node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), theory] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energies = (batch.energy - graph_e0s) / graph_sizes return atom_energies @@ -232,23 +238,38 @@ def compute_mean_rms_energy_forces( atom_energy_list = [] forces_list = [] + theory_list = [] + theory_batch = [] for batch in data_loader: + theory = batch.theory node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), theory] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energy_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + theory_list.append(theory) + theory_batch.append(theory[batch.batch]) atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - - mean = to_numpy(torch.mean(atom_energies)).item() - rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + theory = torch.cat(theory_list, dim=0) # [total_n_graphs] + theory_batch = torch.cat(theory_batch, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + mean = scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1) + print("theory", theory_batch.shape) + print("forces", forces.shape) + rms = to_numpy( + torch.sqrt( + scatter_mean(src=torch.square(forces), index=theory_batch, dim=0).mean(-1) + ) + ).item() rms = _check_non_zero(rms) return mean, rms @@ -258,10 +279,11 @@ def _compute_mean_rms_energy_forces( batch: Batch, atomic_energies_fn: AtomicEnergiesBlock, ) -> Tuple[torch.Tensor, torch.Tensor]: + theory = batch.theory node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), theory] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } forces = batch.forces # {[n_graphs*n_atoms,3], } @@ -292,17 
+314,20 @@ def compute_statistics( atom_energy_list = [] forces_list = [] num_neighbors = [] + theory_list = [] for batch in data_loader: + theory = batch.theory node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), theory] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energy_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + theory_list.append(theory) # {[n_graphs], } _, receivers = batch.edge_index _, counts = torch.unique(receivers, return_counts=True) @@ -310,9 +335,15 @@ def compute_statistics( atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - - mean = to_numpy(torch.mean(atom_energies)).item() - rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + theory = torch.cat(theory_list, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + mean = scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1) + # do the mean for each theory + # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + rms = to_numpy( + torch.sqrt(scatter_mean(src=torch.square(forces), index=theory, dim=0)) + ).item() avg_num_neighbors = torch.mean( torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index e095ed3e..e20fcbf3 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -10,6 +10,7 @@ import os from typing import Dict, List, Optional, Tuple +import numpy as np import torch import torch.distributed from prettytable import PrettyTable @@ -40,7 +41,7 @@ def get_dataset_from_xyz( charges_key: str = "charges", ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" - atomic_energies_dict, all_train_configs = data.load_from_xyz( + atomic_energies_dict, all_train_configs, theories = data.load_from_xyz( file_path=train_path, config_type_weights=config_type_weights, energy_key=energy_key, @@ -55,7 +56,7 @@ def get_dataset_from_xyz( f"Loaded {len(all_train_configs)} training configurations from '{train_path}'" ) if valid_path is not None: - _, valid_configs = data.load_from_xyz( + _, valid_configs, _ = data.load_from_xyz( file_path=valid_path, config_type_weights=config_type_weights, energy_key=energy_key, @@ -80,7 +81,7 @@ def get_dataset_from_xyz( test_configs = [] if test_path is not None: - _, all_test_configs = data.load_from_xyz( + _, all_test_configs, _ = data.load_from_xyz( file_path=test_path, config_type_weights=config_type_weights, energy_key=energy_key, @@ -97,6 +98,7 @@ def get_dataset_from_xyz( return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), atomic_energies_dict, + theories, ) @@ -115,7 +117,7 @@ def get_config_type_weights(ct_weights): return config_type_weights -def get_atomic_energies(E0s, train_collection, z_table) -> dict: +def get_atomic_energies(E0s, train_collection, z_table, theories) -> dict: if E0s is not None: logging.info( "Atomic Energies not in training file, using command line argument E0s" @@ -128,7 +130,7 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: try: assert train_collection is not None atomic_energies_dict = 
data.compute_average_E0s( - train_collection, z_table + train_collection, z_table, theories ) except Exception as e: raise RuntimeError( @@ -215,6 +217,19 @@ def custom_key(key): return (2, key) +def dict_to_array(data): + unique_keys = set() + for inner_dict in data.values(): + unique_keys.update(inner_dict.keys()) + sorted_keys = sorted(unique_keys) + result_array = np.zeros((len(data), len(sorted_keys))) + for default_index, (_, inner_dict) in enumerate(data.items()): + for key, value in inner_dict.items(): + key_index = sorted_keys.index(key) + result_array[default_index][key_index] = value + return np.squeeze(result_array) + + class LRScheduler: def __init__(self, optimizer, args) -> None: self.scheduler = args.scheduler diff --git a/tests/test_models.py b/tests/test_models.py index 18ef536b..5eb9f516 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -196,3 +196,56 @@ def test_energy_dipole_mace(): np.array(rot @ output["dipole"][0].detach().numpy()), output["dipole"][1].detach().numpy(), ) + + +def test_mace_multi_reference(): + atomic_energies = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("96x0e + 96x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=8, + atomic_numbers=table.zs, + distance_transform=True, + pair_repulsion=True, + correlation=3, + theories=["Default", "dft"], + # radial_type="chebyshev", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.1], + ) + model = modules.ScaleShiftMACE(**model_config) + model_compiled = jit.compile(model) + config.theory = "Default" + config_rotated.theory = "dft" + atomic_data = data.AtomicData.from_config( + config, z_table=table, cutoff=3.0, theories=["Default", "dft"] + ) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0, theories=["Default", "dft"] + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output1 = model(batch.to_dict(), training=True) + output2 = model_compiled(batch.to_dict(), training=True) + assert torch.allclose(output1["energy"][0], output2["energy"][0]) + assert output2["energy"].shape[0] == 2 diff --git a/tests/test_modules.py b/tests/test_modules.py index a5166bd2..5539ceb1 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -106,7 +106,7 @@ def test_atomic_energies(self): ) batch = next(iter(data_loader)) - energies = energies_block(batch.node_attrs) + energies = energies_block(batch.node_attrs).squeeze(-1) out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") out = to_numpy(out) assert np.allclose(out, np.array([5.0, 5.0])) From 7d4a3302db7410b510fa3fe28d04818651332dfd Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 25 Mar 2024 13:40:26 +0000 Subject: [PATCH 006/101] fix the scale shift --- mace/modules/blocks.py | 24 ++++++++++++++---------- tests/test_foundations.py | 2 +- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 
94216276..68b126f6 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -142,9 +142,6 @@ def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): def forward( self, x: torch.Tensor # one-hot of elements [..., n_elements] ) -> torch.Tensor: # [..., ] - print("self.atomic_energies.T", torch.atleast_2d(self.atomic_energies).T) - print("self.atomic_energies.T", self.atomic_energies.T.shape) - print("x", x.shape) return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) def __repr__(self): @@ -752,13 +749,20 @@ def __init__(self, scale: float, shift: float): ) def forward(self, x: torch.Tensor, theory: torch.Tensor) -> torch.Tensor: - print("theory", theory.shape) - print("x", x.shape) - print("self.scale", self.scale.shape) - print("self.shift", self.shift.shape) - return self.scale[theory] * x + self.shift[theory] + return ( + torch.atleast_1d(self.scale)[theory] * x + + torch.atleast_1d(self.shift)[theory] + ) def __repr__(self): - return ( - f"{self.__class__.__name__}(scale={self.scale:.6f}, shift={self.shift:.6f})" + formatted_scale = ( + ", ".join([f"{x:.4f}" for x in self.scale]) + if self.scale.numel() > 1 + else f"{self.scale.item():.4f}" + ) + formatted_shift = ( + ", ".join([f"{x:.4f}" for x in self.shift]) + if self.shift.numel() > 1 + else f"{self.shift.item():.4f}" ) + return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})" diff --git a/tests/test_foundations.py b/tests/test_foundations.py index 3a5a8b3d..dcb7d360 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -35,7 +35,7 @@ atomic_energies = np.array([0.0, 0.0, 0.0], dtype=float) -@pytest.skip("Problem with the float type", allow_module_level=True) +# @pytest.skip("Problem with the float type", allow_module_level=True) def test_foundations(): # Create MACE model model_config = dict( From 4d0f83207b76f3fb0caf835ccdcb1f43ff9a965d Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 26 Mar 2024 15:21:31 +0000 Subject: [PATCH 007/101] fix the E0, statistics, add further point sampling --- .gitignore | 2 + mace/calculators/mace.py | 14 +- mace/cli/fine_tuning_select.py | 254 +++++++++++++++++++++++++++++++++ mace/cli/run_train.py | 27 +++- mace/data/atomic_data.py | 8 +- mace/data/utils.py | 9 +- mace/modules/blocks.py | 26 +++- mace/modules/irreps_tools.py | 12 ++ mace/modules/models.py | 26 ++-- mace/modules/utils.py | 8 +- mace/tools/arg_parser.py | 7 + mace/tools/utils.py | 133 ++++++++++++----- setup.cfg | 1 + tests/test_foundations.py | 69 +++++++++ 14 files changed, 528 insertions(+), 68 deletions(-) create mode 100644 mace/cli/fine_tuning_select.py diff --git a/.gitignore b/.gitignore index 801dc9ce..a7bcd07f 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,5 @@ dist/ # DS_Store .DS_Store +*.model +*.pt diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 68922b4d..3118000e 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -128,6 +128,10 @@ def __init__( [int(z) for z in self.models[0].atomic_numbers] ) self.charges_key = charges_key + try: + self.theories = self.models[0].theories + except: + self.theories = ["Default"] model_dtype = get_model_dtype(self.models[0]) if default_dtype == "": print( @@ -193,7 +197,10 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max + config, + 
z_table=self.z_table, + cutoff=self.r_max, + theories=self.theories, ) ], batch_size=1, @@ -300,7 +307,10 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max + config, + z_table=self.z_table, + cutoff=self.r_max, + theories=self.theories, ) ], batch_size=1, diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py new file mode 100644 index 00000000..47780b72 --- /dev/null +++ b/mace/cli/fine_tuning_select.py @@ -0,0 +1,254 @@ +########################################################################################### +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import logging +import typing as t + + +import ase.data +import ase.io +import numpy as np +import torch + +from mace.calculators import MACECalculator, mace_mp +from tqdm import tqdm + +from mace import data +import pandas as pd +from mace.tools import torch_geometric, torch_tools, utils + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--configs_pt", + help="path to XYZ configurations for the pretraining", + required=True, + ) + parser.add_argument( + "--configs_ft", + help="path to XYZ configurations for the finetuning", + required=True, + ) + parser.add_argument( + "--num_samples", + help="number of samples to select for the pretraining", + type=int, + ) + parser.add_argument("--model", help="path to model", required=True) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--descriptors", help="path to descriptors", required=False, default=None + ) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--info_prefix", + help="prefix for energy, forces and stress keys", + type=str, + default="MACE_", + ) + parser.add_argument( + "--theory_pt", + help="level of theory for the pretraining set", + type=str, + default=None, + ) + parser.add_argument( + "--theory_ft", + help="level of theory for the finetuning set", + type=str, + default=None, + ) + return parser.parse_args() + + +def calculate_descriptors( + atoms: t.List[ase.Atoms | ase.Atom], calc: MACECalculator, cutoffs: None | dict +) -> None: + print("Calculating descriptors") + for mol in tqdm(atoms): + descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) + # average descriptors over atoms for each element + descriptors_dict = { + element: np.mean(descriptors[mol.symbols == element], axis=0) + for element in np.unique(mol.symbols) + } + mol.info["mace_descriptors"] = descriptors_dict + + +def filter_atoms( + atoms: ase.Atoms, element_subset: list[str], filtering_type: str +) -> bool: + """ + Filters atoms based on the provided filtering type and element subset. + + Parameters: + atoms (ase.Atoms): The atoms object to filter. + element_subset (list): The list of elements to consider during filtering. + filtering_type (str): The type of filtering to apply. Can be 'none', 'exclusive', or 'inclusive'. + 'none' - No filtering is applied. 
+ 'combinations' - Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. + 'exclusive' - Return true if `atoms` contains *only* elements in the subset, false otherwise. + 'inclusive' - Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. + + Returns: + bool: True if the atoms pass the filter, False otherwise. + """ + if filtering_type == "none": + return True + elif filtering_type == "combinations": + atom_symbols = np.unique(atoms.symbols) + return all( + [x in element_subset for x in atom_symbols] + ) # atoms must *only* contain elements in the subset + elif filtering_type == "exclusive": + atom_symbols = set([x for x in atoms.symbols]) + return atom_symbols == set(element_subset) + elif filtering_type == "inclusive": + atom_symbols = np.unique(atoms.symbols) + return all( + [x in atom_symbols for x in element_subset] + ) # atoms must *at least* contain elements in the subset + else: + raise ValueError( + f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'exclusive', or 'inclusive'." + ) + + +class FPS: + def __init__(self, atoms_list: list[ase.Atoms], n_samples: int): + self.n_samples = n_samples + self.atoms_list = atoms_list + # start from a random configuration + self.list_index = [np.random.randint(0, len(atoms_list))] + + def run(self) -> list[int]: + """ + Run the farthest point sampling algorithm. + """ + for _ in range(max(self.n_samples, len(self.atoms_list)) - 1): + self.update() + return self.list_index + + def update(self) -> list[int]: + """ + Compute the farthest point sampling for the index-th configuration. + """ + distance_matrix = self.compute_distance(self.list_index[-1]) + index_next = np.argmax(np.mean(distance_matrix, axis=1)) + self.list_index.append(index_next) + + def compute_distance(self, index: int) -> np.ndarray: + """ + Compute the distance matrix between the descriptor of the index-th configuration and all the other configurations. + """ + descriptors_filtered = self.filter_species(index) + # compute the distance matrix + distance_matrix = np.zeros((len(self.atoms_list), len(descriptors_filtered))) + descriptors_atoms_index = self.atoms_list[index].info["mace_descriptors"] + for zi, z in enumerate(descriptors_filtered): + distance_matrix[:, zi] = np.nan_to_num( + np.linalg.norm( + descriptors_filtered[z] - descriptors_atoms_index[z], + axis=1, + ) + ) + # put inf to zeros + return distance_matrix + + def filter_species(self, index: int) -> list[ase.Atoms]: + """ + Filter the configurations based on the species of the index-th configuration. 
+ """ + species_index = np.unique(self.atoms_list[index].symbols) + descriptors_species = {z: [] for z in species_index} + descriptors_index = self.atoms_list[index].info["mace_descriptors"] + for i, atoms in enumerate(self.atoms_list): + descriptors_atoms = atoms.info["mace_descriptors"] + for z in species_index: + descriptors_species[z].append( + descriptors_atoms[z] + if z in descriptors_atoms + else np.full_like(descriptors_index[z], np.nan) + ) + for z in species_index: + descriptors_species[z] = np.array(descriptors_species[z]) + return descriptors_species + + +def main(): + args = parse_args() + if args.model in ["small", "medium", "large"]: + calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype) + else: + calc = MACECalculator( + model_paths=args.model, device=args.device, default_dtype=args.default_dtype + ) + atoms_list_ft = ase.io.read(args.configs_ft, index=":") + all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) + print( + "Filtering configurations based on the finetuning set," + f"filtering type: combinations, elements: {all_species_ft}" + ) + + if args.descriptors is not None: + print("Loading descriptors") + descriptors = np.load(args.descriptors) + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + for i, atoms in enumerate(atoms_list_pt): + atoms.arrays["mace_descriptors"] = descriptors[i] + print( + "Filtering configurations based on the finetuning set," + f"filtering type: combinations, elements: {all_species_ft}" + ) + atoms_list_pt = [ + x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") + ] + + else: + print("Calculating descriptors for the pretraining set") + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + atoms_list_pt = [ + x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") + ] + calculate_descriptors(atoms_list_pt, calc, None) + if args.num_samples < len(atoms_list_pt): + print("Selecting configurations using Farthest Point Sampling") + fps_pt = FPS(atoms_list_pt, args.num_samples) + idx_pt = fps_pt.run() + atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + for atoms in atoms_list_pt: + # del atoms.info["mace_descriptors"] + atoms.info["pretrained"] = True + if args.theory_pt is not None: + atoms.info["theory"] = args.theory_pt + + print("Saving the selected configurations") + ase.io.write(args.output, atoms_list_pt, format="extxyz") + print("Saving a combined XYZ file") + for atoms in atoms_list_ft: + if args.theory_ft is not None: + atoms.info["theory"] = args.theory_ft + atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft + ase.io.write( + args.output.replace(".xyz", "_combined.xyz"), atoms_fps_pt_ft, format="extxyz" + ) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1e37304c..9ae73fc5 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -106,10 +106,26 @@ def main() -> None: dipole_key=args.dipole_key, charges_key=args.charges_key, ) + if args.theories is not None: + args.theories = ast.literal_eval(args.theories) + assert set(theories) == set(args.theories), ( + "Theories from command line and data do not match," + f"{set(theories)} != {set(args.theories)}" + ) + logging.info( + "Using theories from command line argument," + f" theories used: {args.theories}" + ) + theories = args.theories + else: + logging.info( + "Using theories extracted from data files," + f" theories used: {theories}" + ) logging.info( f"Total number of configurations: 
train={len(collections.train)}, valid={len(collections.valid)}, " - f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]" + f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," ) else: atomic_energies_dict = None @@ -414,6 +430,7 @@ def main() -> None: atomic_inter_shift=0.0, radial_MLP=ast.literal_eval(args.radial_MLP), radial_type=args.radial_type, + theories=theories, ) elif args.model == "ScaleShiftMACE": model = modules.ScaleShiftMACE( @@ -428,6 +445,7 @@ def main() -> None: atomic_inter_shift=args.mean, radial_MLP=ast.literal_eval(args.radial_MLP), radial_type=args.radial_type, + theories=theories, ) elif args.model == "ScaleShiftBOTNet": model = modules.ScaleShiftBOTNet( @@ -506,7 +524,6 @@ def main() -> None: max_L=args.max_L, ) model.to(device) - print(model) # Optimizer decay_interactions = {} @@ -705,7 +722,9 @@ def main() -> None: if args.train_file.endswith(".xyz"): for name, subset in collections.tests: test_sets[name] = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, theories=theories + ) for config in subset ] elif not args.multi_processed_test: @@ -738,7 +757,7 @@ def main() -> None: test_set, batch_size=args.valid_batch_size, shuffle=(test_sampler is None), - drop_last=test_set.drop_last, + drop_last=False, num_workers=args.num_workers, pin_memory=args.pin_memory, ) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 7ed172ba..45bbc969 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -125,9 +125,11 @@ def from_config( torch.tensor(indices, dtype=torch.long).unsqueeze(-1), num_classes=len(z_table), ) - - theory = torch.tensor(theories.index(config.theory), dtype=torch.long) - print("theory", theory) + try: + theory = torch.tensor(theories.index(config.theory), dtype=torch.long) + except: + print(f"Theory {config.theory} not found in {theories}") + theory = torch.tensor(0, dtype=torch.long) cell = ( torch.tensor(config.cell, dtype=torch.get_default_dtype()) diff --git a/mace/data/utils.py b/mace/data/utils.py index 2ccdc136..6849d989 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -274,9 +274,12 @@ def compute_average_E0s( """ len_train = len(collections_train) len_zs = len(z_table) - A = np.zeros((len_train, len_zs)) - B = np.zeros(len_train) + atomic_energies_dict = {} for theory in theories: + A = np.zeros((len_train, len_zs)) + B = np.zeros(len_train) + if theory not in atomic_energies_dict: + atomic_energies_dict[theory] = {} for i in range(len_train): if collections_train[i].theory != theory: continue @@ -285,14 +288,12 @@ def compute_average_E0s( A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) try: E0s = np.linalg.lstsq(A, B, rcond=None)[0] - atomic_energies_dict = {} for i, z in enumerate(z_table.zs): atomic_energies_dict[theory][z] = E0s[i] except np.linalg.LinAlgError: logging.warning( "Failed to compute E0s using least squares regression, using the same for all atoms" ) - atomic_energies_dict = {} for i, z in enumerate(z_table.zs): atomic_energies_dict[theory][z] = 0.0 return atomic_energies_dict diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 68b126f6..dd823f2d 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -16,6 +16,7 @@ from .irreps_tools import ( linear_out_irreps, + mask_theory, reshape_irreps, tp_out_irreps_with_instructions, ) @@ 
-48,7 +49,9 @@ def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")) super().__init__() self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out) - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + def forward( + self, x: torch.Tensor, theories: Optional[torch.Tensor] = None + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] return self.linear(x) # [n_nodes, 1] @@ -60,16 +63,26 @@ def __init__( MLP_irreps: o3.Irreps, gate: Optional[Callable], irrep_out: o3.Irreps = o3.Irreps("0e"), + num_theories: int = 1, ): super().__init__() self.hidden_irreps = MLP_irreps + self.num_theories = num_theories self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out) - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + def forward( + self, x: torch.Tensor, theories: Optional[torch.Tensor] = None + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] x = self.non_linearity(self.linear_1(x)) - return self.linear_2(x) # [n_nodes, 1] + if ( + hasattr(self, "num_theories") + and self.num_theories > 1 + and theories is not None + ): + x = mask_theory(x, theories, self.num_theories) + return self.linear_2(x) # [n_nodes, len(theories)] @compile_mode("script") @@ -145,7 +158,12 @@ def forward( return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) def __repr__(self): - formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) + formatted_energies = ", ".join( + [ + "[" + ", ".join([f"{x:.4f}" for x in group]) + "]" + for group in torch.atleast_2d(self.atomic_energies) + ] + ) return f"{self.__class__.__name__}(energies=[{formatted_energies}])" diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index 642f3fa8..4e7ce9f1 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -84,3 +84,15 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: field = field.reshape(batch, mul, d) out.append(field) return torch.cat(out, dim=-1) + + +def mask_theory( + x: torch.Tensor, theory: torch.Tensor, num_theories: int +) -> torch.Tensor: + mask = torch.zeros( + x.shape[0], x.shape[1] // num_theories, num_theories, device=x.device + ) + idx = torch.arange(mask.shape[0], device=x.device) + mask[idx, :, theory] = 1 + mask = mask.permute(0, 2, 1).reshape(x.shape) + return x * mask diff --git a/mace/modules/models.py b/mace/modules/models.py index a4d12780..c6cba12c 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -73,6 +73,7 @@ def __init__( self.register_buffer( "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) ) + self.theories = theories if isinstance(correlation, int): correlation = [correlation] * num_interactions # Embedding @@ -166,9 +167,10 @@ def __init__( self.readouts.append( NonLinearReadoutBlock( hidden_irreps_out, - (len(theories) * o3.Irreps("16x0e")).simplify(), + (len(theories) * MLP_irreps).simplify(), gate, o3.Irreps(f"{len(theories)}x0e"), + len(theories), ) ) else: @@ -191,6 +193,7 @@ def forward( print("theory", data["theory"]) num_atoms_arange = torch.arange(data["positions"].shape[0]) num_graphs = data["ptr"].numel() - 1 + node_theories = data["theory"][data["batch"]] displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -212,9 +215,9 @@ def forward( # Atomic energies node_e0 = 
self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, data["theory"][data["batch"]] + num_atoms_arange, node_theories ] - print("node e0", node_e0.shape) + # print("node e0", node_e0.shape) e0 = scatter_sum( src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs ) # [n_graphs, n_theories] @@ -262,8 +265,8 @@ def forward( node_attrs=data["node_attrs"], ) node_feats_list.append(node_feats) - node_energies = readout(node_feats)[ - num_atoms_arange, data["theory"][data["batch"]] + node_energies = readout(node_feats, node_theories)[ + num_atoms_arange, node_theories ] # [n_nodes, len(theories)] energy = scatter_sum( src=node_energies, @@ -334,6 +337,7 @@ def forward( data["node_attrs"].requires_grad_(True) num_graphs = data["ptr"].numel() - 1 num_atoms_arange = torch.arange(data["positions"].shape[0]) + node_theories = data["theory"][data["batch"]] displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -355,7 +359,7 @@ def forward( # Atomic energies node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, data["theory"][data["batch"]] + num_atoms_arange, node_theories ] e0 = scatter_sum( src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs @@ -396,18 +400,18 @@ def forward( ) node_feats_list.append(node_feats) node_es_list.append( - readout(node_feats)[num_atoms_arange, data["theory"][data["batch"]]] + readout(node_feats, node_theories)[num_atoms_arange, node_theories] ) # {[n_nodes, ], } # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) - print("node_es_list", node_es_list) + # print("node_es_list", node_es_list) # Sum over interactions node_inter_es = torch.sum( torch.stack(node_es_list, dim=0), dim=0 ) # [n_nodes, ] - print("node_inter_es", node_inter_es.shape) - node_inter_es = self.scale_shift(node_inter_es, data["theory"][data["batch"]]) + # print("node_inter_es", node_inter_es.shape) + node_inter_es = self.scale_shift(node_inter_es, node_theories) # Sum over nodes in graph inter_e = scatter_sum( @@ -417,7 +421,7 @@ def forward( # Add E_0 and (scaled) interaction energy total_energy = e0 + inter_e node_energy = node_e0 + node_inter_es - print("node_energy", node_energy.shape) + # print("node_energy", node_energy.shape) forces, virials, stress = get_outputs( energy=inter_e, positions=data["positions"], diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 07cf6d7b..0faaf399 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -161,11 +161,11 @@ def get_edge_vectors_and_lengths( def _check_non_zero(std): - if std == 0.0: + if np.any(std == 0): logging.warning( "Standard deviation of the scaling is zero, Changing to no scaling" ) - std = 1.0 + std[std == 0] = 1 return std @@ -263,13 +263,11 @@ def compute_mean_rms_energy_forces( # mean = to_numpy(torch.mean(atom_energies)).item() # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() mean = scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1) - print("theory", theory_batch.shape) - print("forces", forces.shape) rms = to_numpy( torch.sqrt( scatter_mean(src=torch.square(forces), index=theory_batch, dim=0).mean(-1) ) - ).item() + ) rms = _check_non_zero(rms) return mean, rms diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 9f4f3801..6e3b27ce 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -311,6 +311,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=None, required=False, ) + parser.add_argument( + "--theories", + help="List of 
theories in the training set", + type=str, + default=None, + required=False, + ) parser.add_argument( "--energy_key", help="Key of reference energies in training xyz", diff --git a/mace/tools/utils.py b/mace/tools/utils.py index 9edb85a2..abd293c8 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -8,7 +8,7 @@ import logging import os import sys -from typing import Any, Dict, Iterable, Optional, Sequence, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union import numpy as np import torch @@ -167,6 +167,10 @@ def load_foundations( """ assert model_foundations.r_max == model.r_max z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) + try: + foundations_theories = model_foundations.theories + except AttributeError: + foundations_theories = ["Default"] new_z_table = table num_species_foundations = len(z_table.zs) num_channels_foundation = ( @@ -195,22 +199,24 @@ def load_foundations( for j in range(4): # Assuming 4 layers in conv_tp_weights, layer_name = f"layer{j}" if j == 0: - getattr( - model.interactions[i].conv_tp_weights, layer_name - ).weight = torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, layer_name + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ) + .weight[:num_radial, :] + .clone() ) - .weight[:num_radial, :] - .clone() ) else: - getattr( - model.interactions[i].conv_tp_weights, layer_name - ).weight = torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, layer_name - ).weight.clone() + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() + ) ) model.interactions[i].linear.weight = torch.nn.Parameter( @@ -231,29 +237,42 @@ def load_foundations( .clone() / (num_species_foundations / num_species) ** 0.5 ) + else: + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + (max_L + 1) ** 2, + num_species_foundations, + num_channels_foundation, + )[:, :, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) # Transferring products for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[ - j - ].weights_max = torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) - - for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[ - k - ] = torch.nn.Parameter( + model.products[i].symmetric_contractions.contractions[j].weights_max = ( + torch.nn.Parameter( model_foundations.products[i] .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, :] + .weights_max[indices_weights, :, :] .clone() ) + ) + + for k in range(2): # Assuming 2 weights in each contraction + model.products[i].symmetric_contractions.contractions[j].weights[k] = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() + ) + ) model.products[i].linear.weight = 
torch.nn.Parameter( model_foundations.products[i].linear.weight.clone() @@ -261,20 +280,64 @@ def load_foundations( if load_readout: # Transferring readouts + # model.readouts[0].linear.weight[ + # : len(foundations_theories) * num_channels_foundation + # ] = torch.nn.Parameter(model_foundations.readouts[0].linear.weight.clone()) + model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight.view(num_channels_foundation, -1)[ + :, : len(foundations_theories) + ] = ( + model_foundations.readouts[0] + .linear.weight.view(num_channels_foundation, -1) + .clone() + ) model.readouts[0].linear.weight = torch.nn.Parameter( - model_foundations.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight + ) + shape_input_1 = ( + model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + ) + shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + # model.readouts[1].linear_1.weight[:shape_input_1] = torch.nn.Parameter( + # model_foundations.readouts[1].linear_1.weight.clone() + # ) + model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight.view(num_channels_foundation, -1)[ + :, :shape_input_1 + ] = ( + model_foundations.readouts[1] + .linear_1.weight.view(num_channels_foundation, -1) + .clone() ) - model.readouts[1].linear_1.weight = torch.nn.Parameter( - model_foundations.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight + ) + shape_input_2 = ( + model_foundations.readouts[1].linear_2.__dict__["irreps_out"].num_irreps + ) + # model.readouts[1].linear_2.weight[:shape_input_2] = torch.nn.Parameter( + # model_foundations.readouts[1].linear_2.weight.clone() + # ) + model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight.view(shape_output_1, -1)[ + :shape_input_1, :shape_input_2 + ] = model_foundations.readouts[1].linear_2.weight.view( + shape_input_1, -1 + ).clone() / ( + ((shape_input_1) / (shape_output_1)) ** 0.5 ) - model.readouts[1].linear_2.weight = torch.nn.Parameter( - model_foundations.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight ) if model_foundations.scale_shift is not None: if use_scale: - model.scale_shift.scale = model_foundations.scale_shift.scale.clone() + scale_shape = model_foundations.scale_shift.scale.shape + model.scale_shift.scale[: (len(scale_shape) + 1)] = ( + model_foundations.scale_shift.scale.clone() + ) if use_shift: - model.scale_shift.shift = model_foundations.scale_shift.shift.clone() + shift_shape = model_foundations.scale_shift.shift.shape + model.scale_shift.shift[: len(shift_shape)] = ( + model_foundations.scale_shift.shift.clone() + ) return model diff --git a/setup.cfg b/setup.cfg index 08c396f4..5a97efe1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,7 @@ console_scripts = mace_plot_train = mace.cli.plot_train:main mace_run_train = mace.cli.run_train:main mace_prepare_data = mace.cli.preprocess_data:main + mace_finetuning = mace.cli.fine_tuning_select:main [options.extras_require] wandb = wandb diff --git a/tests/test_foundations.py b/tests/test_foundations.py index dcb7d360..ad7f7aab 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -92,3 +92,72 @@ def test_foundations(): forces_loaded = model_loaded(batch)["forces"] forces = model(batch)["forces"] assert torch.allclose(forces, forces_loaded) + + +def test_multi_reference(): + config = data.Configuration( + 
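+        # minimal H2COH configuration tagged with a level of theory; forces and charges reuse placeholder arrays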
atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + forces=molecule("H2COH").positions, + energy=-1.5, + charges=molecule("H2COH").numbers, + dipole=np.array([-1.5, 1.5, 2.0]), + theory="MP2", + ) + table = tools.AtomicNumberTable([1, 6, 8]) + atomic_energies = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) + + # Create MACE model + model_config = dict( + r_max=6, + num_bessel=10, + num_polynomial_cutoff=5, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=3, + hidden_irreps=o3.Irreps("128x0e + 128x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=61, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.0], + theories=["MP2", "DFT"], + ) + model = modules.ScaleShiftMACE(**model_config) + calc_foundation = mace_mp(device="cpu", default_dtype="float64") + model_loaded = load_foundations( + model, + calc_foundation.models[0], + table=table, + load_readout=True, + use_shift=False, + max_L=1, + ) + atomic_data = data.AtomicData.from_config( + config, z_table=table, cutoff=6.0, theories=["MP2", "DFT"] + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + forces_loaded = model_loaded(batch)["forces"] + calc_foundation = mace_mp(device="cpu", default_dtype="float64") + atoms = molecule("H2COH") + atoms.calc = calc_foundation + forces = atoms.get_forces() + assert np.allclose( + forces, forces_loaded.detach().numpy()[:5, :], atol=1e-5, rtol=1e-5 + ) From 9cebf36b160e31a4d8305668dbcb32130c070e4b Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 26 Mar 2024 16:05:11 +0000 Subject: [PATCH 008/101] skip the descriptors if not required --- mace/cli/fine_tuning_select.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 47780b72..61163fa4 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -36,6 +36,8 @@ def parse_args() -> argparse.Namespace: "--num_samples", help="number of samples to select for the pretraining", type=int, + required=False, + default=None, ) parser.add_argument("--model", help="path to model", required=True) parser.add_argument("--output", help="output path", required=True) @@ -141,7 +143,7 @@ def run(self) -> list[int]: """ Run the farthest point sampling algorithm. 
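+        At most min(n_samples, len(atoms_list)) - 1 updates are performed after the random seed point.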
""" - for _ in range(max(self.n_samples, len(self.atoms_list)) - 1): + for _ in range(min(self.n_samples, len(self.atoms_list)) - 1): self.update() return self.list_index @@ -201,6 +203,7 @@ def main(): ) atoms_list_ft = ase.io.read(args.configs_ft, index=":") all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) + print( "Filtering configurations based on the finetuning set," f"filtering type: combinations, elements: {all_species_ft}" @@ -221,16 +224,19 @@ def main(): ] else: - print("Calculating descriptors for the pretraining set") atoms_list_pt = ase.io.read(args.configs_pt, index=":") atoms_list_pt = [ x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") ] - calculate_descriptors(atoms_list_pt, calc, None) - if args.num_samples < len(atoms_list_pt): + + if args.num_samples is not None and args.num_samples < len(atoms_list_pt): + if args.descriptors is None: + print("Calculating descriptors for the pretraining set") + calculate_descriptors(atoms_list_pt, calc, None) print("Selecting configurations using Farthest Point Sampling") fps_pt = FPS(atoms_list_pt, args.num_samples) idx_pt = fps_pt.run() + print(f"Selected {len(idx_pt)} configurations") atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] for atoms in atoms_list_pt: # del atoms.info["mace_descriptors"] From a83af14db473a8f5dde97f5b9978966e25554a82 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 26 Mar 2024 18:19:38 +0000 Subject: [PATCH 009/101] fix loading mace models --- mace/cli/run_train.py | 2 +- mace/tools/scripts_utils.py | 2 ++ mace/tools/utils.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 9ae73fc5..40ed6146 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -427,7 +427,7 @@ def main() -> None: ], MLP_irreps=o3.Irreps(args.MLP_irreps), atomic_inter_scale=args.std, - atomic_inter_shift=0.0, + atomic_inter_shift=[0.0] * len(theories), radial_MLP=ast.literal_eval(args.radial_MLP), radial_type=args.radial_type, theories=theories, diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index e20fcbf3..76dd2b38 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -218,6 +218,8 @@ def custom_key(key): def dict_to_array(data): + if not all(isinstance(value, dict) for value in data.values()): + return np.array(list(data.values())) unique_keys = set() for inner_dict in data.values(): unique_keys.update(inner_dict.keys()) diff --git a/mace/tools/utils.py b/mace/tools/utils.py index abd293c8..7e43f521 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -161,6 +161,7 @@ def load_foundations( use_shift=False, use_scale=True, max_L=2, + max_ell=3, ): """ Load the foundations of a model into a model for fine-tuning. 
@@ -242,7 +243,7 @@ def load_foundations( model_foundations.interactions[i] .skip_tp.weight.reshape( num_channels_foundation, - (max_L + 1) ** 2, + (max_ell + 1), num_species_foundations, num_channels_foundation, )[:, :, indices_weights, :] From 8c65cf261b00f1ea6bedec13239ef9348843ed8f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 27 Mar 2024 11:35:58 +0000 Subject: [PATCH 010/101] start from foundation model on all theories --- mace/tools/utils.py | 71 ++++++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/mace/tools/utils.py b/mace/tools/utils.py index 7e43f521..cd706dd7 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -172,6 +172,7 @@ def load_foundations( foundations_theories = model_foundations.theories except AttributeError: foundations_theories = ["Default"] + model_theories = model.theories new_z_table = table num_species_foundations = len(z_table.zs) num_channels_foundation = ( @@ -189,7 +190,10 @@ def load_foundations( .clone() / (num_species_foundations / num_species) ** 0.5 ) - + if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": + model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( + model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() + ) for i in range(int(model.num_interactions)): model.interactions[i].linear_up.weight = torch.nn.Parameter( model_foundations.interactions[i].linear_up.weight.clone() @@ -281,50 +285,64 @@ def load_foundations( if load_readout: # Transferring readouts - # model.readouts[0].linear.weight[ - # : len(foundations_theories) * num_channels_foundation - # ] = torch.nn.Parameter(model_foundations.readouts[0].linear.weight.clone()) model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() - model_readouts_zero_linear_weight.view(num_channels_foundation, -1)[ - :, : len(foundations_theories) - ] = ( + # model_readouts_zero_linear_weight.view(num_channels_foundation, -1)[ + # :, : len(foundations_theories) + # ] = ( + # model_foundations.readouts[0] + # .linear.weight.view(num_channels_foundation, -1) + # .clone() + # ) + model_readouts_zero_linear_weight = ( model_foundations.readouts[0] .linear.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_theories)) + .flatten() .clone() ) model.readouts[0].linear.weight = torch.nn.Parameter( model_readouts_zero_linear_weight ) + shape_input_1 = ( model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps ) shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps - # model.readouts[1].linear_1.weight[:shape_input_1] = torch.nn.Parameter( - # model_foundations.readouts[1].linear_1.weight.clone() - # ) model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() - model_readouts_one_linear_1_weight.view(num_channels_foundation, -1)[ - :, :shape_input_1 - ] = ( + # model_readouts_one_linear_1_weight.view(num_channels_foundation, -1)[ + # :, :shape_input_1 + # ] = ( + # model_foundations.readouts[1] + # .linear_1.weight.view(num_channels_foundation, -1) + # .clone() + # ) + model_readouts_one_linear_1_weight = ( model_foundations.readouts[1] .linear_1.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_theories)) + .flatten() .clone() ) model.readouts[1].linear_1.weight = torch.nn.Parameter( model_readouts_one_linear_1_weight ) - shape_input_2 = ( - model_foundations.readouts[1].linear_2.__dict__["irreps_out"].num_irreps - ) - # 
model.readouts[1].linear_2.weight[:shape_input_2] = torch.nn.Parameter( - # model_foundations.readouts[1].linear_2.weight.clone() + # shape_input_2 = ( + # model_foundations.readouts[1].linear_2.__dict__["irreps_out"].num_irreps # ) + model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() - model_readouts_one_linear_2_weight.view(shape_output_1, -1)[ - :shape_input_1, :shape_input_2 - ] = model_foundations.readouts[1].linear_2.weight.view( - shape_input_1, -1 - ).clone() / ( + # model_readouts_one_linear_2_weight.view(shape_output_1, -1)[ + # :shape_input_1, :shape_input_2 + # ] = model_foundations.readouts[1].linear_2.weight.view( + # shape_input_1, -1 + # ).clone() / ( + # ((shape_input_1) / (shape_output_1)) ** 0.5 + # ) + model_readouts_one_linear_2_weight = model_foundations.readouts[ + 1 + ].linear_2.weight.view(shape_input_1, -1).repeat( + len(model_theories), len(model_theories) + ).flatten().clone() / ( ((shape_input_1) / (shape_output_1)) ** 0.5 ) model.readouts[1].linear_2.weight = torch.nn.Parameter( @@ -332,10 +350,9 @@ def load_foundations( ) if model_foundations.scale_shift is not None: if use_scale: - scale_shape = model_foundations.scale_shift.scale.shape - model.scale_shift.scale[: (len(scale_shape) + 1)] = ( - model_foundations.scale_shift.scale.clone() - ) + model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( + len(model_theories) + ).clone() if use_shift: shift_shape = model_foundations.scale_shift.shift.shape model.scale_shift.shift[: len(shift_shape)] = ( From 7f07da04539ec187932051f06013c3273168c5b2 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 27 Mar 2024 12:25:58 +0000 Subject: [PATCH 011/101] add nice parsing of multiple theories --- mace/cli/run_train.py | 68 +++++++---- mace/data/hdf5_dataset.py | 12 +- mace/tools/train.py | 247 ++++++++++++-------------------------- 3 files changed, 130 insertions(+), 197 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 40ed6146..1a88a7b7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -193,21 +193,29 @@ def main() -> None: ) for config in collections.train ] - valid_set = [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, theories=theories - ) - for config in collections.valid - ] + valid_sets = {theory: [] for theory in theories} + for theory in theories: + valid_sets[theory] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, theories=theories + ) + for config in collections.valid + if config.theory == theory + ] + elif args.train_file.endswith(".h5"): - train_set = data.HDF5Dataset(args.train_file, r_max=args.r_max, z_table=z_table) - valid_set = data.HDF5Dataset(args.valid_file, r_max=args.r_max, z_table=z_table) + train_set = data.HDF5Dataset( + args.train_file, r_max=args.r_max, z_table=z_table, theories=theories + ) + valid_set = data.HDF5Dataset( + args.valid_file, r_max=args.r_max, z_table=z_table, theories=theories + ) else: # This case would be for when the file path is to a directory of multiple .h5 files train_set = data.dataset_from_sharded_hdf5( - args.train_file, r_max=args.r_max, z_table=z_table + args.train_file, r_max=args.r_max, z_table=z_table, theories=theories ) valid_set = data.dataset_from_sharded_hdf5( - args.valid_file, r_max=args.r_max, z_table=z_table + args.valid_file, r_max=args.r_max, z_table=z_table, theories=theories ) train_sampler, valid_sampler = None, None @@ -238,15 +246,28 @@ def main() -> 
None: pin_memory=args.pin_memory, num_workers=args.num_workers, ) - valid_loader = torch_geometric.dataloader.DataLoader( - dataset=valid_set, - batch_size=args.valid_batch_size, - sampler=valid_sampler, - shuffle=(valid_sampler is None), - drop_last=False, - pin_memory=args.pin_memory, - num_workers=args.num_workers, - ) + valid_loaders = {theories[i]: None for i in range(len(theories))} + if not isinstance(valid_sets, dict): + valid_sets = {"Default": valid_sets} + for theory, valid_set in valid_sets.items(): + valid_loaders[theory] = torch_geometric.dataloader.DataLoader( + dataset=valid_set, + batch_size=args.valid_batch_size, + sampler=valid_sampler, + shuffle=(valid_sampler is None), + drop_last=False, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + ) + # valid_loader = torch_geometric.dataloader.DataLoader( + # dataset=valid_set, + # batch_size=args.valid_batch_size, + # sampler=valid_sampler, + # shuffle=(valid_sampler is None), + # drop_last=False, + # pin_memory=args.pin_memory, + # num_workers=args.num_workers, + # ) # loss_fn: torch.nn.Module = get_loss_fn( # args.loss, @@ -689,7 +710,7 @@ def main() -> None: model=model, loss_fn=loss_fn, train_loader=train_loader, - valid_loader=valid_loader, + valid_loaders=valid_loaders, optimizer=optimizer, lr_scheduler=lr_scheduler, checkpoint_handler=checkpoint_handler, @@ -715,8 +736,9 @@ def main() -> None: all_data_loaders = { "train": train_loader, - "valid": valid_loader, } + for theory, valid_loader in valid_loaders.items(): + all_data_loaders[theory] = valid_loader test_sets = {} if args.train_file.endswith(".xyz"): @@ -732,14 +754,14 @@ def main() -> None: for test_file in test_files: name = os.path.splitext(os.path.basename(test_file))[0] test_sets[name] = data.HDF5Dataset( - test_file, r_max=args.r_max, z_table=z_table + test_file, r_max=args.r_max, z_table=z_table, theories=theories ) else: test_folders = glob(args.test_dir + "/*") for folder in test_folders: name = os.path.splitext(os.path.basename(test_file))[0] test_sets[name] = data.dataset_from_sharded_hdf5( - folder, r_max=args.r_max, z_table=z_table + folder, r_max=args.r_max, z_table=z_table, theories=theories ) for test_name, test_set in test_sets.items(): diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 5057fd7f..6193c409 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -58,6 +58,7 @@ def __getitem__(self, index): dipole=unpack_value(subgrp["dipole"][()]), charges=unpack_value(subgrp["charges"][()]), weight=unpack_value(subgrp["weight"][()]), + theory=unpack_value(subgrp["theory"][()]), energy_weight=unpack_value(subgrp["energy_weight"][()]), forces_weight=unpack_value(subgrp["forces_weight"][()]), stress_weight=unpack_value(subgrp["stress_weight"][()]), @@ -67,16 +68,21 @@ def __getitem__(self, index): cell=unpack_value(subgrp["cell"][()]), ) atomic_data = AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max + config, + z_table=self.z_table, + cutoff=self.r_max, + theories=self.kwargs.get("theories", ["Default"]), ) return atomic_data -def dataset_from_sharded_hdf5(files: List, z_table: AtomicNumberTable, r_max: float): +def dataset_from_sharded_hdf5( + files: List, z_table: AtomicNumberTable, r_max: float, **kwargs +): files = glob(files + "/*") datasets = [] for file in files: - datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max)) + datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max, **kwargs)) full_dataset = ConcatDataset(datasets) return full_dataset diff 
--git a/mace/tools/train.py b/mace/tools/train.py index 9bd7cdf4..ad683316 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -41,7 +41,14 @@ class SWAContainer: loss_fn: torch.nn.Module -def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): +def valid_err_log( + valid_loss, + eval_metrics, + logger, + log_errors, + epoch=None, + valid_loader_name="Default", +): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) @@ -49,7 +56,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -59,7 +66,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" + f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -69,37 +76,37 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" + f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" + f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" ) elif log_errors == "EnergyDipoleRMSE": error_e = 
eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" + f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" ) @@ -107,7 +114,7 @@ def train( model: torch.nn.Module, loss_fn: torch.nn.Module, train_loader: DataLoader, - valid_loader: DataLoader, + valid_loaders: Dict[str, DataLoader], optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, start_epoch: int, @@ -142,14 +149,19 @@ def train( epoch = start_epoch # log validation loss before _any_ training - valid_loss, eval_metrics = evaluate( - model=model, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, - ) - valid_err_log(valid_loss, eval_metrics, logger, log_errors) + valid_loss = 0.0 + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_theory, eval_metrics = evaluate( + model=model, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + valid_loss += valid_loss_theory + valid_err_log( + valid_loss, eval_metrics, logger, log_errors, None, valid_loader_name + ) while epoch < max_num_epochs: # LR scheduler and SWA update @@ -199,184 +211,77 @@ def train( ema.average_parameters() if ema is not None else nullcontext() ) with param_context: - valid_loss, eval_metrics = evaluate( - model=model_to_evaluate, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, - ) - valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch) - if ema is not None: - with ema.average_parameters(): - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=True, + valid_loss = 0.0 + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_theory, eval_metrics = evaluate( + model=model_to_evaluate, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + valid_loss += valid_loss_theory + valid_err_log( + valid_loss, + eval_metrics, + logger, + log_errors, + epoch, + valid_loader_name, ) - else: - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=True, - ) - if rank == 0: - eval_metrics["mode"] = "eval" - eval_metrics["epoch"] = epoch - logger.log(eval_metrics) - if log_errors == "PerAtomRMSE": - eval_metrics["mode"] = "eval" - eval_metrics["epoch"] = epoch - logger.log(eval_metrics) - if log_errors == "PerAtomRMSE": - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" - ) - elif ( - log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_stress_per_atom"] is not None - ): - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" - ) - elif ( - log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_virials_per_atom"] is not None 
- ): - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" - ) - elif log_errors == "TotalRMSE": - error_e = eval_metrics["rmse_e"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" - ) - elif log_errors == "PerAtomMAE": - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" - ) - elif log_errors == "TotalMAE": - error_e = eval_metrics["mae_e"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" - ) - elif log_errors == "DipoleRMSE": - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" - ) - elif log_errors == "EnergyDipoleRMSE": - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" - ) if log_wandb: wandb_log_dict = { + "theory": valid_loader_name, "epoch": epoch, "valid_loss": valid_loss, "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], "valid_rmse_f": eval_metrics["rmse_f"], } wandb.log(wandb_log_dict) + if valid_loss >= lowest_loss: patience_counter += 1 - if swa is not None: - if patience_counter >= patience and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" - ) - epoch = swa.start - elif patience_counter >= patience: + if patience_counter >= patience and epoch < swa.start: logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" + f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" ) - elif log_errors == "TotalRMSE": - error_e = eval_metrics["rmse_e"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 + epoch = swa.start + elif patience_counter >= patience and epoch >= swa.start: logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + f"Stopping optimization after {patience_counter} epochs without improvement" ) - elif log_errors == "PerAtomMAE": - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" - ) - elif log_errors == "TotalMAE": - error_e = eval_metrics["mae_e"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" - ) - elif log_errors == "DipoleRMSE": - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" - ) - elif log_errors == "EnergyDipoleRMSE": - error_e = 
eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" - ) - if log_wandb: - wandb_log_dict = { - "epoch": epoch, - "valid_loss": valid_loss, - "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], - "valid_rmse_f": eval_metrics["rmse_f"], - } - wandb.log(wandb_log_dict) - if valid_loss >= lowest_loss: - patience_counter += 1 - if patience_counter >= patience and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" - ) - epoch = swa.start - elif patience_counter >= patience: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" + break + if ema is not None: + with ema.average_parameters(): + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=True, ) - break + else: checkpoint_handler.save( state=CheckpointState(model, optimizer, lr_scheduler), epochs=epoch, - keep_last=keep_last, + keep_last=True, ) - else: - lowest_loss = valid_loss - patience_counter = 0 - if ema is not None: - with ema.average_parameters(): - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=keep_last, - ) - # keep_last = False - else: + else: + lowest_loss = valid_loss + patience_counter = 0 + if ema is not None: + with ema.average_parameters(): checkpoint_handler.save( state=CheckpointState(model, optimizer, lr_scheduler), epochs=epoch, keep_last=keep_last, ) - # keep_last = False + # keep_last = False + else: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=keep_last, + ) + # keep_last = False if distributed: torch.distributed.barrier() epoch += 1 From 2a7d9243bc66435b1aba6bd6c980ff217abfe5b9 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 27 Mar 2024 13:05:27 +0000 Subject: [PATCH 012/101] fix wandb logging --- .gitignore | 1 + mace/cli/run_train.py | 5 +++++ mace/modules/utils.py | 12 +++++++----- mace/tools/train.py | 10 ++++++---- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index a7bcd07f..7bf13ea2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ dist/ .DS_Store *.model *.pt +/wandb diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1a88a7b7..0fe47f34 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -690,6 +690,11 @@ def main() -> None: wandb_config = {} args_dict = vars(args) + + for key, value in args_dict.items(): + if isinstance(value, np.ndarray): + args_dict[key] = value.tolist() + args_dict_json = json.dumps(args_dict) for key in args.wandb_log_hypers: wandb_config[key] = args_dict[key] diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 0faaf399..4e5fec6d 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -209,8 +209,10 @@ def compute_mean_std_atomic_inter_energy( theory = torch.cat(theory_list, dim=0) # [total_n_graphs] # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() # std = to_numpy(torch.std(avg_atom_inter_es)).item() - mean = scatter_mean(src=avg_atom_inter_es, index=theory, dim=0).squeeze(-1) - std = scatter_std(src=avg_atom_inter_es, index=theory, dim=0).squeeze(-1) + mean = to_numpy( + 
scatter_mean(src=avg_atom_inter_es, index=theory, dim=0).squeeze(-1) + ) + std = to_numpy(scatter_std(src=avg_atom_inter_es, index=theory, dim=0).squeeze(-1)) std = _check_non_zero(std) return mean, std @@ -262,7 +264,7 @@ def compute_mean_rms_energy_forces( # mean = to_numpy(torch.mean(atom_energies)).item() # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() - mean = scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1) + mean = to_numpy(scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1)) rms = to_numpy( torch.sqrt( scatter_mean(src=torch.square(forces), index=theory_batch, dim=0).mean(-1) @@ -336,12 +338,12 @@ def compute_statistics( theory = torch.cat(theory_list, dim=0) # [total_n_graphs] # mean = to_numpy(torch.mean(atom_energies)).item() - mean = scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1) + mean = to_numpy(scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1)) # do the mean for each theory # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() rms = to_numpy( torch.sqrt(scatter_mean(src=torch.square(forces), index=theory, dim=0)) - ).item() + ) avg_num_neighbors = torch.mean( torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) diff --git a/mace/tools/train.py b/mace/tools/train.py index ad683316..3bd61b76 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -212,6 +212,7 @@ def train( ) with param_context: valid_loss = 0.0 + wandb_log_dict = {} for valid_loader_name, valid_loader in valid_loaders.items(): valid_loss_theory, eval_metrics = evaluate( model=model_to_evaluate, @@ -230,14 +231,15 @@ def train( valid_loader_name, ) if log_wandb: - wandb_log_dict = { - "theory": valid_loader_name, + wandb_log_dict[valid_loader_name] = { "epoch": epoch, - "valid_loss": valid_loss, + "valid_loss": valid_loss_theory, "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], "valid_rmse_f": eval_metrics["rmse_f"], } - wandb.log(wandb_log_dict) + + if log_wandb: + wandb.log(wandb_log_dict) if valid_loss >= lowest_loss: patience_counter += 1 From d40e50ffd9eb6de968c094ee5545b52bdbcc04d8 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 27 Mar 2024 13:24:19 +0000 Subject: [PATCH 013/101] Update run_train.py --- mace/cli/run_train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index e36b3860..7eed8696 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -45,7 +45,9 @@ def main() -> None: try: import intel_extension_for_pytorch as ipex except ImportError: - raise ImportError("Error: Intel extension for PyTorch not found, but XPU device was specified") + raise ImportError( + "Error: Intel extension for PyTorch not found, but XPU device was specified" + ) if args.distributed: try: distr_env = DistributedEnvironment() @@ -437,7 +439,6 @@ def main() -> None: args.std = 1.0 logging.info("No scaling selected") elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": - print("args.model", args.model) args.mean, args.std = modules.scaling_classes[args.scaling]( train_loader, atomic_energies ) @@ -603,7 +604,6 @@ def main() -> None: model, optimizer = ipex.optimize(model, optimizer=optimizer) logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + "_train") - lr_scheduler = LRScheduler(optimizer, args) swa: Optional[tools.SWAContainer] = None From ab8e8b47e092d89e5aea82e1996b3075970aeea6 Mon Sep 17 00:00:00 2001 From: Ilyes 
Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 1 Apr 2024 14:24:20 +0100 Subject: [PATCH 014/101] add data parrellel multi theory --- mace/cli/preprocess_data.py | 52 ++++++++++++++++++------------------- mace/cli/run_train.py | 46 ++++++++++++++++++++++---------- mace/tools/scripts_utils.py | 26 +++++++++++++++---- 3 files changed, 80 insertions(+), 44 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 23dfd3f2..8af61f6f 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -179,32 +179,32 @@ def multi_train_hdf5(process): for i in processes: i.join() - - logging.info("Computing statistics") - if len(atomic_energies_dict) == 0: - atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) - atomic_energies: np.ndarray = np.array( - [atomic_energies_dict[z] for z in z_table.zs] - ) - logging.info(f"Atomic energies: {atomic_energies.tolist()}") - _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] - avg_num_neighbors, mean, std=pool_compute_stats(_inputs) - logging.info(f"Average number of neighbors: {avg_num_neighbors}") - logging.info(f"Mean: {mean}") - logging.info(f"Standard deviation: {std}") - - # save the statistics as a json - statistics = { - "atomic_energies": str(atomic_energies_dict), - "avg_num_neighbors": avg_num_neighbors, - "mean": mean, - "std": std, - "atomic_numbers": str(z_table.zs), - "r_max": args.r_max, - } - - with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514 - json.dump(statistics, f) + if args.compute_statistics: + logging.info("Computing statistics") + if len(atomic_energies_dict) == 0: + atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) + atomic_energies: np.ndarray = np.array( + [atomic_energies_dict[z] for z in z_table.zs] + ) + logging.info(f"Atomic energies: {atomic_energies.tolist()}") + _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] + avg_num_neighbors, mean, std=pool_compute_stats(_inputs) + logging.info(f"Average number of neighbors: {avg_num_neighbors}") + logging.info(f"Mean: {mean}") + logging.info(f"Standard deviation: {std}") + + # save the statistics as a json + statistics = { + "atomic_energies": str(atomic_energies_dict), + "avg_num_neighbors": avg_num_neighbors, + "mean": mean, + "std": std, + "atomic_numbers": str(z_table.zs), + "r_max": args.r_max, + } + + with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514 + json.dump(statistics, f) logging.info("Preparing validation set") if args.shuffle: diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 7eed8696..33c94da3 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -32,6 +32,7 @@ get_dataset_from_xyz, get_files_with_suffix, dict_to_array, + check_folder_subfolder, ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import load_foundations @@ -91,7 +92,11 @@ def main() -> None: args.std = statistics["std"] args.avg_num_neighbors = statistics["avg_num_neighbors"] args.compute_avg_num_neighbors = False - args.E0s = statistics["atomic_energies"] + args.E0s = ( + statistics["atomic_energies"] + if not args.E0s.endswith(".json") + else args.E0s + ) # Data preparation if args.train_file.endswith(".xyz"): @@ -222,9 +227,21 @@ def main() -> None: train_set = data.dataset_from_sharded_hdf5( args.train_file, r_max=args.r_max, z_table=z_table, 
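+            # the full theories list is passed so the theory -> index mapping stays identical across shards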
theories=theories ) - valid_set = data.dataset_from_sharded_hdf5( - args.valid_file, r_max=args.r_max, z_table=z_table, theories=theories - ) + # check if the folder has subfolders for each theory by opening args.valid_file folder + if check_folder_subfolder(args.valid_file): + valid_sets = {} + for theory in theories: + valid_sets[theory] = data.dataset_from_sharded_hdf5( + os.path.join(args.valid_file, theory), + r_max=args.r_max, + z_table=z_table, + theories=theories, + ) + else: + valid_set = data.dataset_from_sharded_hdf5( + args.valid_file, r_max=args.r_max, z_table=z_table, theories=theories + ) + valid_sets = {"Default": valid_set} train_sampler, valid_sampler = None, None if args.distributed: @@ -236,14 +253,17 @@ def main() -> None: drop_last=True, seed=args.seed, ) - valid_sampler = torch.utils.data.distributed.DistributedSampler( - valid_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, - ) + valid_samplers = {} + for theory, valid_set in valid_sets.items(): + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + valid_samplers[theory] = valid_sampler train_loader = torch_geometric.dataloader.DataLoader( dataset=train_set, @@ -261,7 +281,7 @@ def main() -> None: valid_loaders[theory] = torch_geometric.dataloader.DataLoader( dataset=valid_set, batch_size=args.valid_batch_size, - sampler=valid_sampler, + sampler=valid_samplers[theory] if args.distributed else None, shuffle=(valid_sampler is None), drop_last=False, pin_memory=args.pin_memory, diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 76dd2b38..1234692e 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -6,6 +6,7 @@ import ast import dataclasses +import json import logging import os from typing import Dict, List, Optional, Tuple @@ -137,11 +138,17 @@ def get_atomic_energies(E0s, train_collection, z_table, theories) -> dict: f"Could not compute average E0s if no training xyz given, error {e} occured" ) from e else: - try: - atomic_energies_dict = ast.literal_eval(E0s) - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError(f"E0s specified invalidly, error {e} occured") from e + if E0s.endswith(".json"): + logging.info(f"Loading atomic energies from {E0s}") + atomic_energies_dict = json.load(open(E0s, "r")) + else: + try: + atomic_energies_dict = ast.literal_eval(E0s) + assert isinstance(atomic_energies_dict, dict) + except Exception as e: + raise RuntimeError( + f"E0s specified invalidly, error {e} occured" + ) from e else: raise RuntimeError( "E0s not found in training file and not specified in command line" @@ -447,3 +454,12 @@ def create_error_table( ] ) return table + + +def check_folder_subfolder(folder_path): + entries = os.listdir(folder_path) + for entry in entries: + full_path = os.path.join(folder_path, entry) + if os.path.isdir(full_path): + return True + return False From 0571b923e2326b76b96d3dfd5793b931343310b7 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 1 Apr 2024 14:38:13 +0100 Subject: [PATCH 015/101] add level of theories in test --- mace/cli/preprocess_data.py | 2 +- mace/data/utils.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 8af61f6f..c9fb4142 100644 --- a/mace/cli/preprocess_data.py +++ 
b/mace/cli/preprocess_data.py @@ -124,7 +124,7 @@ def main(): os.makedirs(args.h5_prefix + sub_dir) # Data preparation - collections, atomic_energies_dict = get_dataset_from_xyz( + collections, atomic_energies_dict, _ = get_dataset_from_xyz( train_path=args.train_file, valid_path=args.valid_file, valid_fraction=args.valid_fraction, diff --git a/mace/data/utils.py b/mace/data/utils.py index 6849d989..c9a80cb4 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -187,11 +187,12 @@ def test_config_types( test_by_ct = [] all_cts = [] for conf in test_configs: - if conf.config_type not in all_cts: - all_cts.append(conf.config_type) - test_by_ct.append((conf.config_type, [conf])) + config_type_name = conf.config_type + "_" + conf.theory + if config_type_name not in all_cts: + all_cts.append(config_type_name) + test_by_ct.append((config_type_name, [conf])) else: - ind = all_cts.index(conf.config_type) + ind = all_cts.index(config_type_name) test_by_ct[ind][1].append(conf) return test_by_ct From d03d4018dc16abf21015df201f2557835e4e8d8a Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 1 Apr 2024 21:17:52 +0100 Subject: [PATCH 016/101] fix E0 extraction --- mace/tools/scripts_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 1234692e..88c2e787 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -230,11 +230,12 @@ def dict_to_array(data): unique_keys = set() for inner_dict in data.values(): unique_keys.update(inner_dict.keys()) - sorted_keys = sorted(unique_keys) + unique_keys = list(unique_keys) + sorted_keys = sorted([int(key) for key in unique_keys]) result_array = np.zeros((len(data), len(sorted_keys))) for default_index, (_, inner_dict) in enumerate(data.items()): for key, value in inner_dict.items(): - key_index = sorted_keys.index(key) + key_index = sorted_keys.index(int(key)) result_array[default_index][key_index] = value return np.squeeze(result_array) From eb3e8f44d14abf2a9636e89cac416456f5bd71bd Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 2 Apr 2024 13:40:26 +0100 Subject: [PATCH 017/101] add options to finetuning selection --- mace/cli/fine_tuning_select.py | 68 ++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 61163fa4..23aca7b6 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -39,7 +39,9 @@ def parse_args() -> argparse.Namespace: required=False, default=None, ) - parser.add_argument("--model", help="path to model", required=True) + parser.add_argument( + "--model", help="path to model", default="small", required=False + ) parser.add_argument("--output", help="output path", required=True) parser.add_argument( "--descriptors", help="path to descriptors", required=False, default=None @@ -76,6 +78,25 @@ def parse_args() -> argparse.Namespace: type=str, default=None, ) + parser.add_argument( + "--filtering_type", + help="filtering type", + type=str, + choices=[None, "combinations", "exclusive", "inclusive"], + default=None, + ) + parser.add_argument( + "--weight_ft", + help="weight for the finetuning set", + type=float, + default=1.0, + ) + parser.add_argument( + "--weight_pt", + help="weight for the pretraining set", + type=float, + default=1.0, + ) return parser.parse_args() @@ -202,32 +223,38 @@ def main(): 
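# A worked example of the dict_to_array() fix above: atomic-number keys are
# sorted numerically (as ints) rather than lexicographically, so the column
# order is 1, 8, 11 and not "1", "11", "8".  The reference energies below are
# made-up values for illustration.
import numpy as np


def e0_dict_to_array(data):
    """Convert {theory: {atomic_number: E0}} into a (n_theories, n_elements) array."""
    keys = sorted({int(k) for inner in data.values() for k in inner})
    out = np.zeros((len(data), len(keys)))
    for row, inner in enumerate(data.values()):
        for key, value in inner.items():
            out[row, keys.index(int(key))] = value
    return np.squeeze(out)


print(e0_dict_to_array({"DFT": {"1": -13.6, "8": -432.0, "11": -162.0},
                        "MP2": {"1": -13.7, "8": -433.1, "11": -162.5}}))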
model_paths=args.model, device=args.device, default_dtype=args.default_dtype ) atoms_list_ft = ase.io.read(args.configs_ft, index=":") - all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) - print( - "Filtering configurations based on the finetuning set," - f"filtering type: combinations, elements: {all_species_ft}" - ) - - if args.descriptors is not None: - print("Loading descriptors") - descriptors = np.load(args.descriptors) - atoms_list_pt = ase.io.read(args.configs_pt, index=":") - for i, atoms in enumerate(atoms_list_pt): - atoms.arrays["mace_descriptors"] = descriptors[i] + if args.filtering_type != None: + all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) print( "Filtering configurations based on the finetuning set," f"filtering type: combinations, elements: {all_species_ft}" ) - atoms_list_pt = [ - x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") - ] + if args.descriptors is not None: + print("Loading descriptors") + descriptors = np.load(args.descriptors) + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + for i, atoms in enumerate(atoms_list_pt): + atoms.arrays["mace_descriptors"] = descriptors[i] + print( + "Filtering configurations based on the finetuning set," + f"filtering type: combinations, elements: {all_species_ft}" + ) + atoms_list_pt = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] + else: + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + atoms_list_pt = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") - atoms_list_pt = [ - x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") - ] if args.num_samples is not None and args.num_samples < len(atoms_list_pt): if args.descriptors is None: @@ -241,6 +268,7 @@ def main(): for atoms in atoms_list_pt: # del atoms.info["mace_descriptors"] atoms.info["pretrained"] = True + atoms.info["config_weight"] = args.weight_pt if args.theory_pt is not None: atoms.info["theory"] = args.theory_pt @@ -248,6 +276,8 @@ def main(): ase.io.write(args.output, atoms_list_pt, format="extxyz") print("Saving a combined XYZ file") for atoms in atoms_list_ft: + atoms.info["pretrained"] = False + atoms.info["config_weight"] = args.weight_ft if args.theory_ft is not None: atoms.info["theory"] = args.theory_ft atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft From 0bc7af396c6707fba863771a9186352673ae30e4 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:53:11 -0500 Subject: [PATCH 018/101] fpsample --- .gitignore | 1 + mace/cli/fine_tuning_select.py | 89 ++++++++++++++++++---------------- 2 files changed, 47 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index 7bf13ea2..5674da75 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ dist/ *.model *.pt /wandb +*.xyz diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 23aca7b6..1b76cb47 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -18,6 +18,7 @@ from mace import data import pandas as pd from mace.tools import torch_geometric, torch_tools, utils +import fpsample def parse_args() -> argparse.Namespace: @@ -157,61 +158,45 @@ class FPS: def __init__(self, atoms_list: list[ase.Atoms], n_samples: int): self.n_samples = n_samples self.atoms_list = atoms_list + self.species = np.unique([x.symbol 
for atoms in atoms_list for x in atoms]) + self.species_dict = {x: i for i, x in enumerate(self.species)} # start from a random configuration self.list_index = [np.random.randint(0, len(atoms_list))] + self.assemble_descriptors() - def run(self) -> list[int]: + def run( + self, + ) -> list[int]: """ Run the farthest point sampling algorithm. """ - for _ in range(min(self.n_samples, len(self.atoms_list)) - 1): - self.update() + print(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape) + print("n_samples", self.n_samples) + self.list_index = fpsample.fps_npdu_kdtree_sampling( + self.descriptors_dataset.reshape(len(self.atoms_list), -1), self.n_samples + ) return self.list_index - def update(self) -> list[int]: - """ - Compute the farthest point sampling for the index-th configuration. + def assemble_descriptors(self) -> np.ndarray: """ - distance_matrix = self.compute_distance(self.list_index[-1]) - index_next = np.argmax(np.mean(distance_matrix, axis=1)) - self.list_index.append(index_next) - - def compute_distance(self, index: int) -> np.ndarray: + Assemble the descriptors for all the configurations. """ - Compute the distance matrix between the descriptor of the index-th configuration and all the other configurations. - """ - descriptors_filtered = self.filter_species(index) - # compute the distance matrix - distance_matrix = np.zeros((len(self.atoms_list), len(descriptors_filtered))) - descriptors_atoms_index = self.atoms_list[index].info["mace_descriptors"] - for zi, z in enumerate(descriptors_filtered): - distance_matrix[:, zi] = np.nan_to_num( - np.linalg.norm( - descriptors_filtered[z] - descriptors_atoms_index[z], - axis=1, + self.descriptors_dataset = np.float32( + 10e10 + * np.ones( + ( + len(self.atoms_list), + len(self.species), + len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), ) ) - # put inf to zeros - return distance_matrix - - def filter_species(self, index: int) -> list[ase.Atoms]: - """ - Filter the configurations based on the species of the index-th configuration. 
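# Standalone sketch of the farthest-point-sampling call the FPS class now
# delegates to: fpsample works on a 2-D float array, which is why the class
# above flattens its (n_configs, n_species, n_channels) descriptor tensor
# before sampling.  The random descriptors below are placeholders.
import numpy as np
import fpsample

rng = np.random.default_rng(0)
descriptors = rng.normal(size=(200, 64)).astype(np.float32)  # (n_configs, n_features)
n_samples = 20
selected = fpsample.fps_npdu_kdtree_sampling(descriptors, n_samples)
print(sorted(int(i) for i in selected))  # indices of the selected configurations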
- """ - species_index = np.unique(self.atoms_list[index].symbols) - descriptors_species = {z: [] for z in species_index} - descriptors_index = self.atoms_list[index].info["mace_descriptors"] + ) for i, atoms in enumerate(self.atoms_list): - descriptors_atoms = atoms.info["mace_descriptors"] - for z in species_index: - descriptors_species[z].append( - descriptors_atoms[z] - if z in descriptors_atoms - else np.full_like(descriptors_index[z], np.nan) + descriptors = atoms.info["mace_descriptors"] + for z in descriptors: + self.descriptors_dataset[i, self.species_dict[z]] = np.float32( + descriptors[z] ) - for z in species_index: - descriptors_species[z] = np.array(descriptors_species[z]) - return descriptors_species def main(): @@ -232,10 +217,10 @@ def main(): ) if args.descriptors is not None: print("Loading descriptors") - descriptors = np.load(args.descriptors) + descriptors = np.load(args.descriptors, allow_pickle=True) atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): - atoms.arrays["mace_descriptors"] = descriptors[i] + atoms.info["mace_descriptors"] = descriptors[i] print( "Filtering configurations based on the finetuning set," f"filtering type: combinations, elements: {all_species_ft}" @@ -255,11 +240,29 @@ def main(): ] else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") + if args.descriptors is not None: + print( + "Loading descriptors for the pretraining set from {}".format( + args.descriptors + ) + ) + descriptors = np.load(args.descriptors, allow_pickle=True) + for i, atoms in enumerate(atoms_list_pt): + atoms.info["mace_descriptors"] = descriptors[i] if args.num_samples is not None and args.num_samples < len(atoms_list_pt): if args.descriptors is None: print("Calculating descriptors for the pretraining set") calculate_descriptors(atoms_list_pt, calc, None) + descriptors_list = [ + atoms.info["mace_descriptors"] for atoms in atoms_list_pt + ] + print( + "Saving descriptors at {}".format( + args.output.replace(".xyz", "descriptors.npy") + ) + ) + np.save(args.output.replace(".xyz", "descriptors.npy"), descriptors_list) print("Selecting configurations using Farthest Point Sampling") fps_pt = FPS(atoms_list_pt, args.num_samples) idx_pt = fps_pt.run() From 9b732ee9778a14d85d9cfc75f354f8540a5b022c Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 16 Apr 2024 14:38:11 +0100 Subject: [PATCH 019/101] fixed tests --- mace/calculators/mace.py | 6 +- mace/data/utils.py | 3 + mace/modules/blocks.py | 4 +- mace/tools/finetuning_utils.py | 53 +++++++-- mace/tools/utils.py | 208 --------------------------------- 5 files changed, 55 insertions(+), 219 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 737a2a7a..da11a7db 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -233,7 +233,11 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if self.model_type in ["MACE", "EnergyDipoleMACE"]: batch = next(iter(data_loader)).to(self.device) - node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"]) + node_theories = batch["theory"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + num_atoms_arange, node_theories + ] compute_stress = not self.use_compile else: compute_stress = False diff --git a/mace/data/utils.py b/mace/data/utils.py index c0de7a9e..702e15d9 100644 --- a/mace/data/utils.py +++ 
b/mace/data/utils.py @@ -322,6 +322,7 @@ def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: grp["virials"] = data.virials grp["dipole"] = data.dipole grp["charges"] = data.charges + grp["theory"] = data.theory def save_AtomicData_to_HDF5(data, i, h5_file) -> None: @@ -344,6 +345,7 @@ def save_AtomicData_to_HDF5(data, i, h5_file) -> None: grp["virials"] = data.virials grp["dipole"] = data.dipole grp["charges"] = data.charges + grp["theory"] = data.theory def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: @@ -357,6 +359,7 @@ def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> N subgroup["forces"] = write_value(config.forces) subgroup["stress"] = write_value(config.stress) subgroup["virials"] = write_value(config.virials) + subgroup["theory"] = write_value(config.theory) subgroup["dipole"] = write_value(config.dipole) subgroup["charges"] = write_value(config.charges) subgroup["cell"] = write_value(config.cell) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 62be7e95..4b38ea1f 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -764,11 +764,11 @@ def __init__(self, scale: float, shift: float): super().__init__() self.register_buffer( "scale", - torch.atleast_1d(torch.tensor(scale, dtype=torch.get_default_dtype())), + torch.tensor(scale, dtype=torch.get_default_dtype()), ) self.register_buffer( "shift", - torch.atleast_1d(torch.tensor(shift, dtype=torch.get_default_dtype())), + torch.tensor(shift, dtype=torch.get_default_dtype()), ) def forward(self, x: torch.Tensor, theory: torch.Tensor) -> torch.Tensor: diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 9d264de5..13aa6c64 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -69,12 +69,18 @@ def load_foundations( use_shift=False, use_scale=True, max_L=2, + max_ell=3, ): """ Load the foundations of a model into a model for fine-tuning. 
""" assert model_foundations.r_max == model.r_max z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) + try: + foundations_theories = model_foundations.theories + except AttributeError: + foundations_theories = ["Default"] + model_theories = model.theories new_z_table = table num_species_foundations = len(z_table.zs) num_channels_foundation = ( @@ -84,7 +90,6 @@ def load_foundations( indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] num_radial = model.radial_embedding.out_dim num_species = len(indices_weights) - max_ell = model.spherical_harmonics._lmax model.node_embedding.linear.weight = torch.nn.Parameter( model_foundations.node_embedding.linear.weight.view( num_species_foundations, -1 @@ -97,7 +102,6 @@ def load_foundations( model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() ) - for i in range(int(model.num_interactions)): model.interactions[i].linear_up.weight = torch.nn.Parameter( model_foundations.interactions[i].linear_up.weight.clone() @@ -159,6 +163,7 @@ def load_foundations( .clone() / (num_species_foundations / num_species) ** 0.5 ) + # Transferring products for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 @@ -188,20 +193,52 @@ def load_foundations( if load_readout: # Transferring readouts + model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight = ( + model_foundations.readouts[0] + .linear.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_theories)) + .flatten() + .clone() + ) model.readouts[0].linear.weight = torch.nn.Parameter( - model_foundations.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight ) + shape_input_1 = ( + model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + ) + shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight = ( + model_foundations.readouts[1] + .linear_1.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_theories)) + .flatten() + .clone() + ) model.readouts[1].linear_1.weight = torch.nn.Parameter( - model_foundations.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight + ) + model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight = model_foundations.readouts[ + 1 + ].linear_2.weight.view(shape_input_1, -1).repeat( + len(model_theories), len(model_theories) + ).flatten().clone() / ( + ((shape_input_1) / (shape_output_1)) ** 0.5 ) - model.readouts[1].linear_2.weight = torch.nn.Parameter( - model_foundations.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight ) if model_foundations.scale_shift is not None: if use_scale: - model.scale_shift.scale = model_foundations.scale_shift.scale.clone() + model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( + len(model_theories) + ).clone() if use_shift: - model.scale_shift.shift = model_foundations.scale_shift.shift.clone() + shift_shape = model_foundations.scale_shift.shift.shape + model.scale_shift.shift[: len(shift_shape)] = ( + model_foundations.scale_shift.shift.clone() + ) return model diff --git a/mace/tools/utils.py b/mace/tools/utils.py index cd706dd7..c33b7b3b 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -151,211 +151,3 @@ def log(self, d: 
Dict[str, Any]) -> None: with open(self.path, mode="a", encoding="utf-8") as f: f.write(json.dumps(d, cls=UniversalEncoder)) f.write("\n") - - -def load_foundations( - model: torch.nn.Module, - model_foundations: torch.nn.Module, - table: AtomicNumberTable, - load_readout=False, - use_shift=False, - use_scale=True, - max_L=2, - max_ell=3, -): - """ - Load the foundations of a model into a model for fine-tuning. - """ - assert model_foundations.r_max == model.r_max - z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) - try: - foundations_theories = model_foundations.theories - except AttributeError: - foundations_theories = ["Default"] - model_theories = model.theories - new_z_table = table - num_species_foundations = len(z_table.zs) - num_channels_foundation = ( - model_foundations.node_embedding.linear.weight.shape[0] - // num_species_foundations - ) - indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] - num_radial = model.radial_embedding.out_dim - num_species = len(indices_weights) - model.node_embedding.linear.weight = torch.nn.Parameter( - model_foundations.node_embedding.linear.weight.view( - num_species_foundations, -1 - )[indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": - model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( - model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() - ) - for i in range(int(model.num_interactions)): - model.interactions[i].linear_up.weight = torch.nn.Parameter( - model_foundations.interactions[i].linear_up.weight.clone() - ) - model.interactions[i].avg_num_neighbors = model_foundations.interactions[ - i - ].avg_num_neighbors - for j in range(4): # Assuming 4 layers in conv_tp_weights, - layer_name = f"layer{j}" - if j == 0: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ) - .weight[:num_radial, :] - .clone() - ) - ) - else: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ).weight.clone() - ) - ) - - model.interactions[i].linear.weight = torch.nn.Parameter( - model_foundations.interactions[i].linear.weight.clone() - ) - if ( - model.interactions[i].__class__.__name__ - == "RealAgnosticResidualInteractionBlock" - ): - model.interactions[i].skip_tp.weight = torch.nn.Parameter( - model_foundations.interactions[i] - .skip_tp.weight.reshape( - num_channels_foundation, - num_species_foundations, - num_channels_foundation, - )[:, indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - else: - model.interactions[i].skip_tp.weight = torch.nn.Parameter( - model_foundations.interactions[i] - .skip_tp.weight.reshape( - num_channels_foundation, - (max_ell + 1), - num_species_foundations, - num_channels_foundation, - )[:, :, indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - - # Transferring products - for i in range(2): # Assuming 2 products modules - max_range = max_L + 1 if i == 0 else 1 - for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[j].weights_max = ( - torch.nn.Parameter( - model_foundations.products[i] - 
.symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) - ) - - for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[k] = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, :] - .clone() - ) - ) - - model.products[i].linear.weight = torch.nn.Parameter( - model_foundations.products[i].linear.weight.clone() - ) - - if load_readout: - # Transferring readouts - model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() - # model_readouts_zero_linear_weight.view(num_channels_foundation, -1)[ - # :, : len(foundations_theories) - # ] = ( - # model_foundations.readouts[0] - # .linear.weight.view(num_channels_foundation, -1) - # .clone() - # ) - model_readouts_zero_linear_weight = ( - model_foundations.readouts[0] - .linear.weight.view(num_channels_foundation, -1) - .repeat(1, len(model_theories)) - .flatten() - .clone() - ) - model.readouts[0].linear.weight = torch.nn.Parameter( - model_readouts_zero_linear_weight - ) - - shape_input_1 = ( - model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps - ) - shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps - model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() - # model_readouts_one_linear_1_weight.view(num_channels_foundation, -1)[ - # :, :shape_input_1 - # ] = ( - # model_foundations.readouts[1] - # .linear_1.weight.view(num_channels_foundation, -1) - # .clone() - # ) - model_readouts_one_linear_1_weight = ( - model_foundations.readouts[1] - .linear_1.weight.view(num_channels_foundation, -1) - .repeat(1, len(model_theories)) - .flatten() - .clone() - ) - model.readouts[1].linear_1.weight = torch.nn.Parameter( - model_readouts_one_linear_1_weight - ) - # shape_input_2 = ( - # model_foundations.readouts[1].linear_2.__dict__["irreps_out"].num_irreps - # ) - - model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() - # model_readouts_one_linear_2_weight.view(shape_output_1, -1)[ - # :shape_input_1, :shape_input_2 - # ] = model_foundations.readouts[1].linear_2.weight.view( - # shape_input_1, -1 - # ).clone() / ( - # ((shape_input_1) / (shape_output_1)) ** 0.5 - # ) - model_readouts_one_linear_2_weight = model_foundations.readouts[ - 1 - ].linear_2.weight.view(shape_input_1, -1).repeat( - len(model_theories), len(model_theories) - ).flatten().clone() / ( - ((shape_input_1) / (shape_output_1)) ** 0.5 - ) - model.readouts[1].linear_2.weight = torch.nn.Parameter( - model_readouts_one_linear_2_weight - ) - if model_foundations.scale_shift is not None: - if use_scale: - model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( - len(model_theories) - ).clone() - if use_shift: - shift_shape = model_foundations.scale_shift.shift.shape - model.scale_shift.shift[: len(shift_shape)] = ( - model_foundations.scale_shift.shift.clone() - ) - return model From f5ab9c5d45e09ac49c7874c44a14ba83c813d5ab Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 16 Apr 2024 17:46:06 +0100 Subject: [PATCH 020/101] add option to not filter elements --- .../foundations_models/mp_vasp_e0.json | 91 +++++++++++++++++++ mace/cli/run_train.py | 28 ++++-- mace/tools/__init__.py | 4 +- mace/tools/arg_parser.py | 7 ++ mace/tools/finetuning_utils.py | 17 +++- tests/test_foundations.py | 9 +- 6 files changed, 138 insertions(+), 18 
deletions(-) create mode 100644 mace/calculators/foundations_models/mp_vasp_e0.json diff --git a/mace/calculators/foundations_models/mp_vasp_e0.json b/mace/calculators/foundations_models/mp_vasp_e0.json new file mode 100644 index 00000000..01771879 --- /dev/null +++ b/mace/calculators/foundations_models/mp_vasp_e0.json @@ -0,0 +1,91 @@ +{ + "pbe": { + "1": -1.11734008, + "2": 0.00096759, + "3": -0.29754725, + "4": -0.01781697, + "5": -0.26885011, + "6": -1.26173507, + "7": -3.12438806, + "8": -1.54838784, + "9": -0.51882044, + "10": -0.01241601, + "11": -0.22883163, + "12": -0.00951015, + "13": -0.21630193, + "14": -0.8263903, + "15": -1.88816619, + "16": -0.89160769, + "17": -0.25828273, + "18": -0.04925973, + "19": -0.22697913, + "20": -0.0927795, + "21": -2.11396364, + "22": -2.50054871, + "23": -3.70477179, + "24": -5.60261985, + "25": -5.32541181, + "26": -3.52004933, + "27": -1.93555024, + "28": -0.9351969, + "29": -0.60025846, + "30": -0.1651332, + "31": -0.32990651, + "32": -0.77971828, + "33": -1.68367812, + "34": -0.76941032, + "35": -0.22213843, + "36": -0.0335879, + "37": -0.1881724, + "38": -0.06826294, + "39": -2.17084228, + "40": -2.28579303, + "41": -3.13429753, + "42": -4.60211419, + "43": -3.45201492, + "44": -2.38073513, + "45": -1.46855515, + "46": -1.4773126, + "47": -0.33954585, + "48": -0.16843877, + "49": -0.35470981, + "50": -0.83642657, + "51": -1.41101987, + "52": -0.65740879, + "53": -0.18964571, + "54": -0.00857582, + "55": -0.13771876, + "56": -0.03457659, + "57": -0.45580806, + "58": -1.3309175, + "59": -0.29671824, + "60": -0.30391193, + "61": -0.30898427, + "62": -0.25470891, + "63": -8.38001538, + "64": -10.38896525, + "65": -0.3059505, + "66": -0.30676216, + "67": -0.30874667, + "69": -0.25190039, + "70": -0.06431414, + "71": -0.31997586, + "72": -3.52770927, + "73": -3.54492209, + "75": -4.70108713, + "76": -2.88257209, + "77": -1.46779304, + "78": -0.50269936, + "79": -0.28801193, + "80": -0.12454674, + "81": -0.31737194, + "82": -0.77644932, + "83": -1.32627283, + "89": -0.26827152, + "90": -0.90817426, + "91": -2.47653193, + "92": -4.90438537, + "93": -7.63378961, + "94": -10.77237713 + } +} \ No newline at end of file diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index ec91f743..a64af37e 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -35,7 +35,10 @@ check_folder_subfolder, ) from mace.tools.slurm_distributed import DistributedEnvironment -from mace.tools.finetuning_utils import load_foundations, extract_config_mace_model +from mace.tools.finetuning_utils import ( + load_foundations_elements, + extract_config_mace_model, +) def main() -> None: @@ -552,13 +555,22 @@ def main() -> None: raise RuntimeError(f"Unknown model: '{args.model}'") if args.foundation_model is not None: - model = load_foundations( - model, - model_foundation, - z_table, - load_readout=True, - max_L=args.max_L, - ) + if args.foundation_filter_elements: + model = load_foundations_elements( + model, + model_foundation, + z_table, + load_readout=True, + max_L=args.max_L, + ) + else: + model = load_foundations_elements( + model, + model_foundation, + z_table, + load_readout=False, + max_L=args.max_L, + ) model.to(device) # Optimizer diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index 64dc08cc..5f851483 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -31,7 +31,7 @@ setup_logger, ) -from .finetuning_utils import load_foundations, extract_load +from .finetuning_utils import load_foundations_elements, extract_load __all__ = 
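# Sketch of consuming the mp_vasp_e0.json table added above.  The file maps a
# functional name ("pbe") to {atomic_number_as_string: E0}, so the string keys
# are converted to int before use; the element list below is hypothetical.
import json

with open("mace/calculators/foundations_models/mp_vasp_e0.json", encoding="utf-8") as f:
    e0_tables = json.load(f)

zs = [1, 6, 8]
pbe_e0s = {z: e0_tables["pbe"][str(z)] for z in zs}
print(pbe_e0s)  # {1: -1.11734008, 6: -1.26173507, 8: -1.54838784}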
[ "TensorDict", @@ -66,7 +66,7 @@ "cartesian_to_spherical", "voigt_to_matrix", "init_wandb", - "load_foundations", + "load_foundations_elements", "extract_load", "build_preprocess_arg_parser", ] diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index bd5a9bd5..9afa4b12 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -324,6 +324,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=None, required=False, ) + parser.add_argument( + "--foundation_filter_elements", + help="Filter element during fine-tuning", + type=bool, + default=True, + required=False, + ) parser.add_argument( "--theories", help="List of theories in the training set", diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 13aa6c64..f148c275 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -61,7 +61,7 @@ def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: return model_copy.to(map_location) -def load_foundations( +def load_foundations_elements( model: torch.nn.Module, model_foundations: torch.nn.Module, table: AtomicNumberTable, @@ -76,10 +76,6 @@ def load_foundations( """ assert model_foundations.r_max == model.r_max z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) - try: - foundations_theories = model_foundations.theories - except AttributeError: - foundations_theories = ["Default"] model_theories = model.theories new_z_table = table num_species_foundations = len(z_table.zs) @@ -242,3 +238,14 @@ def load_foundations( model_foundations.scale_shift.shift.clone() ) return model + + +def load_foundations( + model, + model_foundations, +): + for name, param in model_foundations.named_parameters(): + if name in model.state_dict().keys(): + if "readouts" not in name: + model.state_dict()[name].copy_(param) + return model diff --git a/tests/test_foundations.py b/tests/test_foundations.py index f8d296f2..22f0f2f8 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -12,7 +12,10 @@ from mace.tools.utils import ( AtomicNumberTable, ) -from mace.tools.finetuning_utils import load_foundations, extract_config_mace_model +from mace.tools.finetuning_utils import ( + load_foundations_elements, + extract_config_mace_model, +) torch.set_default_dtype(torch.float64) config = data.Configuration( @@ -72,7 +75,7 @@ def test_foundations(): default_dtype="float64", ) model_foundations = calc.models[0] - model_loaded = load_foundations( + model_loaded = load_foundations_elements( model, model_foundations, table=table, @@ -138,7 +141,7 @@ def test_multi_reference(): ) model = modules.ScaleShiftMACE(**model_config) calc_foundation = mace_mp(device="cpu", default_dtype="float64") - model_loaded = load_foundations( + model_loaded = load_foundations_elements( model, calc_foundation.models[0], table=table, From f552d6c89b5f19365706580c12d2dadf4a6a6ed4 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 18 Apr 2024 18:35:02 +0100 Subject: [PATCH 021/101] fix scale shift for foundation model --- mace/cli/run_train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index a64af37e..130978ea 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -430,10 +430,14 @@ def main() -> None: model_config["num_elements"] = len(z_table) args.max_L = model_config["hidden_irreps"].lmax if args.model == "MACE" and calc.models[0].__class__.__name__ == 
"MACE": - model_config["atomic_inter_shift"] = 0.0 + model_config["atomic_inter_shift"] = [0.0] * len(theories) else: - model_config["atomic_inter_shift"] = args.mean + model_config["atomic_inter_shift"] = [args.mean] * len(theories) + model_config["atomic_inter_scale"] = [model_config["atomic_inter_scale"]] * len( + theories + ) args.model = "FoundationMACE" + model_config["theories"] = args.theories else: logging.info("Building model") if args.num_channels is not None and args.max_L is not None: From c92876971218c942b83cc3d220bb7688c1f51d9c Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 22 Apr 2024 11:51:40 +0100 Subject: [PATCH 022/101] add mutlihead test, fix scale shift --- mace/cli/run_train.py | 10 ++-- mace/modules/blocks.py | 4 +- mace/tools/finetuning_utils.py | 9 ++-- mace/tools/scripts_utils.py | 13 ++++-- tests/test_run_train.py | 84 ++++++++++++++++++++++++++++++++++ 5 files changed, 105 insertions(+), 15 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 130978ea..ad3b625f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -108,7 +108,9 @@ def main() -> None: return_raw_model=True, ) else: - model_foundation = torch.load(args.foundation_model, map_location=device) + model_foundation = torch.load( + args.foundation_model, map_location=args.device + ) logging.info( f"Using foundation model {args.foundation_model} as initial checkpoint." ) @@ -299,7 +301,6 @@ def main() -> None: seed=args.seed, ) valid_samplers[theory] = valid_sampler - train_loader = torch_geometric.dataloader.DataLoader( dataset=train_set, batch_size=args.batch_size, @@ -433,9 +434,7 @@ def main() -> None: model_config["atomic_inter_shift"] = [0.0] * len(theories) else: model_config["atomic_inter_shift"] = [args.mean] * len(theories) - model_config["atomic_inter_scale"] = [model_config["atomic_inter_scale"]] * len( - theories - ) + model_config["atomic_inter_scale"] = [1.0] * len(theories) args.model = "FoundationMACE" model_config["theories"] = args.theories else: @@ -576,6 +575,7 @@ def main() -> None: max_L=args.max_L, ) model.to(device) + logging.info(model) # Optimizer decay_interactions = {} diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 4b38ea1f..62be7e95 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -764,11 +764,11 @@ def __init__(self, scale: float, shift: float): super().__init__() self.register_buffer( "scale", - torch.tensor(scale, dtype=torch.get_default_dtype()), + torch.atleast_1d(torch.tensor(scale, dtype=torch.get_default_dtype())), ) self.register_buffer( "shift", - torch.tensor(shift, dtype=torch.get_default_dtype()), + torch.atleast_1d(torch.tensor(shift, dtype=torch.get_default_dtype())), ) def forward(self, x: torch.Tensor, theory: torch.Tensor) -> torch.Tensor: diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index f148c275..8e25accc 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -66,7 +66,7 @@ def load_foundations_elements( model_foundations: torch.nn.Module, table: AtomicNumberTable, load_readout=False, - use_shift=False, + use_shift=True, use_scale=True, max_L=2, max_ell=3, @@ -233,10 +233,9 @@ def load_foundations_elements( len(model_theories) ).clone() if use_shift: - shift_shape = model_foundations.scale_shift.shift.shape - model.scale_shift.shift[: len(shift_shape)] = ( - model_foundations.scale_shift.shift.clone() - ) + model.scale_shift.shift = 
model_foundations.scale_shift.shift.repeat( + len(model_theories) + ).clone() return model diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 1acd454e..092edef5 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -78,9 +78,16 @@ def get_dataset_from_xyz( logging.info( "Using random %s%% of training set for validation", 100 * valid_fraction ) - train_configs, valid_configs = data.random_train_valid_split( - all_train_configs, valid_fraction, seed - ) + train_configs, valid_configs = [], [] + for theory in theories: + all_train_configs_theory = [ + config for config in all_train_configs if config.theory == theory + ] + train_configs_theory, valid_configs_theory = data.random_train_valid_split( + all_train_configs_theory, valid_fraction, seed + ) + train_configs.extend(train_configs_theory) + valid_configs.extend(valid_configs_theory) test_configs = [] if test_path is not None: diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 3343c2be..376739bd 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -349,3 +349,87 @@ def test_run_train_foundation(tmp_path, fitting_configs): 0.5659750699996948, ] assert np.allclose(Es, ref_Es) + + +def test_run_train_foundation_multihead(tmp_path, fitting_configs): + fitting_configs_ = [] + for i, c in enumerate(fitting_configs): + if i % 2 == 0: + c.info["theory"] = "DFT" + else: + c.info["theory"] = "MP2" + fitting_configs_.append(c) + ase.io.write(tmp_path / "fit.xyz", fitting_configs_) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float32" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["theories"] = "['MP2','DFT']" + mace_params["batch_size"] = 2 + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + tmp_path / "MACE.model", device="cpu", default_dtype="float32" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 28/03/2023 on repulsion a63434aaab70c84ee016e13e4aca8d57297a0f26 + ref_Es = [ + 1.1737573146820068, + 0.37266889214515686, + 0.3591262996196747, + 0.1222146600484848, + 0.21925662457942963, + 0.30689263343811035, + 0.23039104044437408, + 0.11772646009922028, + 0.2409999519586563, + 0.04042769968509674, + 0.6277227997779846, + 0.13879507780075073, + 0.18997330963611603, + 0.30589431524276733, + 0.34129756689071655, + -0.0034095346927642822, + 0.5614650249481201, + 0.29983872175216675, + 0.3369189500808716, + -0.20579558610916138, + 0.1669044941663742, + 0.119053915143013, + ] + assert np.allclose(Es, ref_Es) From 
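# Standalone sketch of the per-theory train/validation split introduced above:
# each theory contributes its own validation share instead of one global split.
# The toy configurations below are placeholders.
import random
from collections import defaultdict


def split_per_theory(configs, valid_fraction, seed=123):
    by_theory = defaultdict(list)
    for theory, config in configs:
        by_theory[theory].append((theory, config))
    rng = random.Random(seed)
    train, valid = [], []
    for items in by_theory.values():
        rng.shuffle(items)
        n_valid = max(1, int(len(items) * valid_fraction))
        valid.extend(items[:n_valid])
        train.extend(items[n_valid:])
    return train, valid


configs = [("DFT", i) for i in range(10)] + [("MP2", i) for i in range(10)]
train, valid = split_per_theory(configs, valid_fraction=0.1)
print(len(train), len(valid))  # 18 2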
5a66ca442391a73e81bf4cadcfec22d6d31c3c96 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 22 Apr 2024 16:34:24 +0100 Subject: [PATCH 023/101] fix logging --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index ad3b625f..7f94d70f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -323,6 +323,7 @@ def main() -> None: drop_last=False, pin_memory=args.pin_memory, num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), ) # valid_loader = torch_geometric.dataloader.DataLoader( # dataset=valid_set, @@ -575,7 +576,6 @@ def main() -> None: max_L=args.max_L, ) model.to(device) - logging.info(model) # Optimizer decay_interactions = {} From 81ad801a71072969d92b227800a135b2791cbfc7 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Tue, 23 Apr 2024 16:48:28 -0400 Subject: [PATCH 024/101] Fix logic when testing for config_type=IsolatedAtom and/or len(atoms) == 1 --- mace/data/utils.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index 702e15d9..d8b2c8bc 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -220,27 +220,26 @@ def load_from_xyz( atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): - if len(atoms) == 1: - isolated_atom_config = atoms.info.get("config_type") == "IsolatedAtom" - if isolated_atom_config: - if energy_key in atoms.info.keys(): - theory = atoms.info.get(theory_key, "Default") - if theory not in atomic_energies_dict: - atomic_energies_dict[theory] = {} - atomic_energies_dict[theory][atoms.get_atomic_numbers()[0]] = ( - atoms.info[energy_key] - ) - else: - logging.warning( - f"Configuration '{idx}' is marked as 'IsolatedAtom' " - "but does not contain an energy. Zero energy will be used." - ) - theory = atoms.info.get(theory_key, "Default") - if theory not in atomic_energies_dict: - atomic_energies_dict[theory] = {} - atomic_energies_dict[theory][atoms.get_atomic_numbers()[0]] = ( - np.zeros(1) - ) + if atoms.info.get("config_type") == "IsolatedAtom": + assert len(atoms) == 1, f"Got config_type=IsolatedAtom for a config with len {len(atoms)}" + if energy_key in atoms.info.keys(): + theory = atoms.info.get(theory_key, "Default") + if theory not in atomic_energies_dict: + atomic_energies_dict[theory] = {} + atomic_energies_dict[theory][atoms.get_atomic_numbers()[0]] = ( + atoms.info[energy_key] + ) + else: + logging.warning( + f"Configuration '{idx}' is marked as 'IsolatedAtom' " + "but does not contain an energy. Zero energy will be used." 
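# Plain-dict sketch of the IsolatedAtom bookkeeping above: a frame tagged
# config_type=IsolatedAtom must contain exactly one atom, and its energy (or
# zero when none is given) becomes that element's E0 for the frame's theory.
# The frames below are made up.
import numpy as np

frames = [
    {"config_type": "IsolatedAtom", "theory": "DFT", "Z": 1, "energy": -13.6},
    {"config_type": "IsolatedAtom", "theory": "MP2", "Z": 1, "energy": None},
    {"config_type": "Default", "theory": "DFT", "Z": 8, "energy": -432.0},
]

atomic_energies_dict = {}
for frame in frames:
    if frame["config_type"] != "IsolatedAtom":
        continue
    theory = frame["theory"]
    atomic_energies_dict.setdefault(theory, {})
    energy = frame["energy"] if frame["energy"] is not None else np.zeros(1)
    atomic_energies_dict[theory][frame["Z"]] = energy

print(atomic_energies_dict)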
+ ) + theory = atoms.info.get(theory_key, "Default") + if theory not in atomic_energies_dict: + atomic_energies_dict[theory] = {} + atomic_energies_dict[theory][atoms.get_atomic_numbers()[0]] = ( + np.zeros(1) + ) else: atoms_without_iso_atoms.append(atoms) From 00cc869f8315da6a4d80ef6a360f018f2dfba2c3 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 24 Apr 2024 09:48:25 +0100 Subject: [PATCH 025/101] fix calc none check --- .gitignore | 1 + mace/cli/run_train.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index b72a270a..3817d9f3 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ dist/ /wandb *.xyz /checkpoints +*.model diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 7f94d70f..20be018a 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -431,7 +431,7 @@ def main() -> None: model_config["atomic_numbers"] = z_table.zs model_config["num_elements"] = len(z_table) args.max_L = model_config["hidden_irreps"].lmax - if args.model == "MACE" and calc.models[0].__class__.__name__ == "MACE": + if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": model_config["atomic_inter_shift"] = [0.0] * len(theories) else: model_config["atomic_inter_shift"] = [args.mean] * len(theories) From 41d77c480eb7ee6ac937e298638d8a9447f35d8e Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:43:48 +0100 Subject: [PATCH 026/101] change interface for automatic multihead finetuning --- .github/workflows/lint.yml | 46 +++++++++ mace/calculators/mace.py | 12 +-- mace/cli/fine_tuning_select.py | 32 +++---- mace/cli/run_train.py | 166 +++++++++++++++++++++++---------- mace/data/atomic_data.py | 16 ++-- mace/data/hdf5_dataset.py | 4 +- mace/data/utils.py | 108 ++++++++++++++------- mace/modules/blocks.py | 27 +++--- mace/modules/irreps_tools.py | 10 +- mace/modules/models.py | 48 +++++----- mace/modules/utils.py | 58 ++++++------ mace/tools/arg_parser.py | 10 +- mace/tools/finetuning_utils.py | 12 +-- mace/tools/scripts_utils.py | 31 +++--- mace/tools/train.py | 26 +++--- tests/test_foundations.py | 6 +- tests/test_models.py | 10 +- tests/test_run_train.py | 8 +- 18 files changed, 396 insertions(+), 234 deletions(-) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..998ae250 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,46 @@ +name: Linting and code formatting + +on: [] + # Trigger the workflow on push or pull request, + # but only for the main branch + # push: + # branches: [] + # pull_request: + # branches: [] + + +jobs: + build-linux: + runs-on: ubuntu-latest + + steps: + # Setup + - name: Checkout + uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8.10 + - name: Get cache + uses: actions/cache@v2 + with: + path: /opt/hostedtoolcache/Python/3.8.10/x64/lib/python3.8/site-packages + # Look to see if there is a cache hit for the corresponding requirements file + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + + # Install packages + - name: Install packages required for installation + run: python -m pip install --upgrade pip setuptools wheel + - name: Install dependencies + run: pip install -r requirements.txt + + # Check code + - name: Check formatting with yapf + run: python -m yapf --style=.style.yapf --diff --recursive 
. +# - name: Lint with flake8 +# run: flake8 --config=.flake8 . +# - name: Check type annotations with mypy +# run: mypy --config-file=.mypy.ini . + + - name: Test with pytest + run: python -m pytest tests diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index da11a7db..b33bd5c5 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -145,9 +145,9 @@ def __init__( ) self.charges_key = charges_key try: - self.theories = self.models[0].theories + self.heads = self.models[0].heads except: - self.theories = ["Default"] + self.heads = ["Default"] model_dtype = get_model_dtype(self.models[0]) if default_dtype == "": print( @@ -223,7 +223,7 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): config, z_table=self.z_table, cutoff=self.r_max, - theories=self.theories, + heads=self.heads, ) ], batch_size=1, @@ -233,10 +233,10 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if self.model_type in ["MACE", "EnergyDipoleMACE"]: batch = next(iter(data_loader)).to(self.device) - node_theories = batch["theory"][batch["batch"]] + node_heads = batch["head"][batch["batch"]] num_atoms_arange = torch.arange(batch["positions"].shape[0]) node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ - num_atoms_arange, node_theories + num_atoms_arange, node_heads ] compute_stress = not self.use_compile else: @@ -339,7 +339,7 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): config, z_table=self.z_table, cutoff=self.r_max, - theories=self.theories, + heads=self.heads, ) ], batch_size=1, diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 1b76cb47..b90b432c 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -62,20 +62,14 @@ def parse_args() -> argparse.Namespace: default="float64", ) parser.add_argument( - "--info_prefix", - help="prefix for energy, forces and stress keys", - type=str, - default="MACE_", - ) - parser.add_argument( - "--theory_pt", - help="level of theory for the pretraining set", + "--head_pt", + help="level of head for the pretraining set", type=str, default=None, ) parser.add_argument( - "--theory_ft", - help="level of theory for the finetuning set", + "--head_ft", + help="level of head for the finetuning set", type=str, default=None, ) @@ -199,8 +193,9 @@ def assemble_descriptors(self) -> np.ndarray: ) -def main(): - args = parse_args() +def select_samples( + args: argparse.Namespace, +) -> None: if args.model in ["small", "medium", "large"]: calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype) else: @@ -272,8 +267,8 @@ def main(): # del atoms.info["mace_descriptors"] atoms.info["pretrained"] = True atoms.info["config_weight"] = args.weight_pt - if args.theory_pt is not None: - atoms.info["theory"] = args.theory_pt + if args.head_pt is not None: + atoms.info["head"] = args.head_pt print("Saving the selected configurations") ase.io.write(args.output, atoms_list_pt, format="extxyz") @@ -281,13 +276,18 @@ def main(): for atoms in atoms_list_ft: atoms.info["pretrained"] = False atoms.info["config_weight"] = args.weight_ft - if args.theory_ft is not None: - atoms.info["theory"] = args.theory_ft + if args.head_ft is not None: + atoms.info["head"] = args.head_ft atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft ase.io.write( args.output.replace(".xyz", "_combined.xyz"), atoms_fps_pt_ft, format="extxyz" ) +def main(): + args = parse_args() + select_samples(args) + + if __name__ == "__main__": 
main() diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 20be018a..c39170cb 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -11,6 +11,8 @@ import os from pathlib import Path from typing import Optional +import urllib.request + import numpy as np import torch.distributed @@ -23,10 +25,12 @@ import mace from mace import data, modules, tools from mace.calculators.foundations_models import mace_mp, mace_off +from mace.cli.fine_tuning_select import select_samples from mace.tools import torch_geometric from mace.tools.scripts_utils import ( LRScheduler, create_error_table, + dict_to_namespace, get_atomic_energies, get_config_type_weights, get_dataset_from_xyz, @@ -141,7 +145,7 @@ def main() -> None: ".xyz" ), "valid_file if given must be same format as train_file" config_type_weights = get_config_type_weights(args.config_type_weights) - collections, atomic_energies_dict, theories = get_dataset_from_xyz( + collections, atomic_energies_dict, heads = get_dataset_from_xyz( train_path=args.train_file, valid_path=args.valid_file, valid_fraction=args.valid_fraction, @@ -156,23 +160,87 @@ def main() -> None: charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, ) - if args.theories is not None: - args.theories = ast.literal_eval(args.theories) - assert set(theories) == set(args.theories), ( - "Theories from command line and data do not match," - f"{set(theories)} != {set(args.theories)}" + if args.heads is not None: + args.heads = ast.literal_eval(args.heads) + assert set(heads) == set(args.heads), ( + "heads from command line and data do not match," + f"{set(heads)} != {set(args.heads)}" ) logging.info( - "Using theories from command line argument," - f" theories used: {args.theories}" + "Using heads from command line argument," f" heads used: {args.heads}" ) - theories = args.theories + heads = args.heads else: logging.info( - "Using theories extracted from data files," - f" theories used: {theories}" + "Using heads extracted from data files," f" heads used: {heads}" ) + if args.multiheads_finetuning: + logging.info("Using multiheads finetuning mode") + heads = list(set(["pbe_mp"] + heads)) + try: + checkpoint_url = "https://tinyurl.com/mw2wetc5" + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c + for c in os.path.basename(checkpoint_url) + if c.isalnum() or c in "_" + ) + cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" + if not os.path.isfile(cached_dataset_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP structures for finetuning") + _, http_msg = urllib.request.urlretrieve( + checkpoint_url, cached_dataset_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Dataset download failed, please check the URL {checkpoint_url}" + ) + logging.info(f"Materials Project dataset to {cached_dataset_path}") + dataset_mp = cached_dataset_path + msg = f"Using Materials Project dataset with {dataset_mp}" + logging.info(msg) + args_samples = { + "configs_pt": dataset_mp, + "configs_ft": args.train_file, + "num_samples": 1000, + "seed": args.seed, + "model": args.foundation_model, + "head_pt": "pbe_mp", + "head_ft": "Default", + "weight_pt": 1.0, + "weight_ft": 1.0, + "filtering_type": "combination", + "output": f"{cache_dir}/mp_finetuning.xyz", + "descriptors": None, + "device": args.device, + "default_dtype": args.default_dtype, + } + select_samples(dict_to_namespace(args_samples)) + collections_mp, _, _ = get_dataset_from_xyz( + 
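# Sketch of the download-and-cache pattern used above for the Materials Project
# finetuning structures.  The URL below is a placeholder, not the real
# checkpoint_url, and the content-type check is a simplified variant.
import os
import urllib.request


def fetch_cached(url: str, cache_dir: str = "~/.cache/mace") -> str:
    cache_dir = os.path.expanduser(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    name = "".join(c for c in os.path.basename(url) if c.isalnum() or c == "_")
    local_path = os.path.join(cache_dir, name)
    if not os.path.isfile(local_path):
        _, headers = urllib.request.urlretrieve(url, local_path)
        if "text/html" in headers.get("Content-Type", ""):
            raise RuntimeError(f"Dataset download failed, please check the URL {url}")
    return local_path


# Example (placeholder URL): fetch_cached("https://example.com/mp_finetuning_set")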
train_path=dataset_mp, + valid_path=None, + valid_fraction=args.valid_fraction, + config_type_weights=config_type_weights, + test_path=None, + seed=args.seed, + energy_key="energy", + forces_key="forces", + stress_key="stress", + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + keep_isolated_atoms=args.keep_isolated_atoms, + ) + collections.train += collections_mp.train + collections.valid += collections_mp.valid + except Exception as exc: + raise RuntimeError( + "Model download failed and no local model found" + ) from exc + logging.info( f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," @@ -204,12 +272,14 @@ def main() -> None: if atomic_energies_dict is None or len(atomic_energies_dict) == 0: if args.train_file.endswith(".xyz"): atomic_energies_dict = get_atomic_energies( - args.E0s, collections.train, z_table, theories + args.E0s, collections.train, z_table, heads ) else: - atomic_energies_dict = get_atomic_energies( - args.E0s, None, z_table, theories - ) + atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table, heads) + if args.multiheads_finetuning: + with open("mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file: + E0s_mp = json.load(file) + atomic_energies_dict["pbe_mp"] = {E0s_mp["pbe"][z] for z in z_table.zs} if args.model == "AtomicDipolesMACE": atomic_energies = None @@ -239,44 +309,44 @@ def main() -> None: if args.train_file.endswith(".xyz"): train_set = [ data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, theories=theories + config, z_table=z_table, cutoff=args.r_max, heads=heads ) for config in collections.train ] - valid_sets = {theory: [] for theory in theories} - for theory in theories: - valid_sets[theory] = [ + valid_sets = {head: [] for head in heads} + for head in heads: + valid_sets[head] = [ data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, theories=theories + config, z_table=z_table, cutoff=args.r_max, heads=heads ) for config in collections.valid - if config.theory == theory + if config.head == head ] elif args.train_file.endswith(".h5"): train_set = data.HDF5Dataset( - args.train_file, r_max=args.r_max, z_table=z_table, theories=theories + args.train_file, r_max=args.r_max, z_table=z_table, heads=heads ) valid_set = data.HDF5Dataset( - args.valid_file, r_max=args.r_max, z_table=z_table, theories=theories + args.valid_file, r_max=args.r_max, z_table=z_table, heads=heads ) else: # This case would be for when the file path is to a directory of multiple .h5 files train_set = data.dataset_from_sharded_hdf5( - args.train_file, r_max=args.r_max, z_table=z_table, theories=theories + args.train_file, r_max=args.r_max, z_table=z_table, heads=heads ) - # check if the folder has subfolders for each theory by opening args.valid_file folder + # check if the folder has subfolders for each head by opening args.valid_file folder if check_folder_subfolder(args.valid_file): valid_sets = {} - for theory in theories: - valid_sets[theory] = data.dataset_from_sharded_hdf5( - os.path.join(args.valid_file, theory), + for head in heads: + valid_sets[head] = data.dataset_from_sharded_hdf5( + os.path.join(args.valid_file, head), r_max=args.r_max, z_table=z_table, - theories=theories, + heads=heads, ) else: valid_set = data.dataset_from_sharded_hdf5( - args.valid_file, r_max=args.r_max, z_table=z_table, 
theories=theories + args.valid_file, r_max=args.r_max, z_table=z_table, heads=heads ) valid_sets = {"Default": valid_set} @@ -291,7 +361,7 @@ def main() -> None: seed=args.seed, ) valid_samplers = {} - for theory, valid_set in valid_sets.items(): + for head, valid_set in valid_sets.items(): valid_sampler = torch.utils.data.distributed.DistributedSampler( valid_set, num_replicas=world_size, @@ -300,7 +370,7 @@ def main() -> None: drop_last=True, seed=args.seed, ) - valid_samplers[theory] = valid_sampler + valid_samplers[head] = valid_sampler train_loader = torch_geometric.dataloader.DataLoader( dataset=train_set, batch_size=args.batch_size, @@ -311,14 +381,14 @@ def main() -> None: num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), ) - valid_loaders = {theories[i]: None for i in range(len(theories))} + valid_loaders = {heads[i]: None for i in range(len(heads))} if not isinstance(valid_sets, dict): valid_sets = {"Default": valid_sets} - for theory, valid_set in valid_sets.items(): - valid_loaders[theory] = torch_geometric.dataloader.DataLoader( + for head, valid_set in valid_sets.items(): + valid_loaders[head] = torch_geometric.dataloader.DataLoader( dataset=valid_set, batch_size=args.valid_batch_size, - sampler=valid_samplers[theory] if args.distributed else None, + sampler=valid_samplers[head] if args.distributed else None, shuffle=(valid_sampler is None), drop_last=False, pin_memory=args.pin_memory, @@ -432,12 +502,12 @@ def main() -> None: model_config["num_elements"] = len(z_table) args.max_L = model_config["hidden_irreps"].lmax if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": - model_config["atomic_inter_shift"] = [0.0] * len(theories) + model_config["atomic_inter_shift"] = [0.0] * len(heads) else: - model_config["atomic_inter_shift"] = [args.mean] * len(theories) - model_config["atomic_inter_scale"] = [1.0] * len(theories) + model_config["atomic_inter_shift"] = [args.mean] * len(heads) + model_config["atomic_inter_scale"] = [1.0] * len(heads) args.model = "FoundationMACE" - model_config["theories"] = args.theories + model_config["heads"] = args.heads else: logging.info("Building model") if args.num_channels is not None and args.max_L is not None: @@ -483,10 +553,10 @@ def main() -> None: ], MLP_irreps=o3.Irreps(args.MLP_irreps), atomic_inter_scale=args.std, - atomic_inter_shift=[0.0] * len(theories), + atomic_inter_shift=[0.0] * len(heads), radial_MLP=ast.literal_eval(args.radial_MLP), radial_type=args.radial_type, - theories=theories, + heads=heads, ) elif args.model == "ScaleShiftMACE": model = modules.ScaleShiftMACE( @@ -501,7 +571,7 @@ def main() -> None: atomic_inter_shift=args.mean, radial_MLP=ast.literal_eval(args.radial_MLP), radial_type=args.radial_type, - theories=theories, + heads=heads, ) elif args.model == "FoundationMACE": model = modules.ScaleShiftMACE(**model_config) @@ -783,15 +853,15 @@ def main() -> None: all_data_loaders = { "train": train_loader, } - for theory, valid_loader in valid_loaders.items(): - all_data_loaders[theory] = valid_loader + for head, valid_loader in valid_loaders.items(): + all_data_loaders[head] = valid_loader test_sets = {} if args.train_file.endswith(".xyz"): for name, subset in collections.tests: test_sets[name] = [ data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, theories=theories + config, z_table=z_table, cutoff=args.r_max, heads=heads ) for config in subset ] @@ -800,14 +870,14 @@ def main() -> None: for test_file in test_files: name = 
os.path.splitext(os.path.basename(test_file))[0] test_sets[name] = data.HDF5Dataset( - test_file, r_max=args.r_max, z_table=z_table, theories=theories + test_file, r_max=args.r_max, z_table=z_table, heads=heads ) else: test_folders = glob(args.test_dir + "/*") for folder in test_folders: name = os.path.splitext(os.path.basename(test_file))[0] test_sets[name] = data.dataset_from_sharded_hdf5( - folder, r_max=args.r_max, z_table=z_table, theories=theories + folder, r_max=args.r_max, z_table=z_table, heads=heads ) for test_name, test_set in test_sets.items(): diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 45bbc969..f157cecd 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -52,7 +52,7 @@ def __init__( unit_shifts: torch.Tensor, # [n_edges, 3] cell: Optional[torch.Tensor], # [3,3] weight: Optional[torch.Tensor], # [,] - theory: Optional[torch.Tensor], # [,] + head: Optional[torch.Tensor], # [,] energy_weight: Optional[torch.Tensor], # [,] forces_weight: Optional[torch.Tensor], # [,] stress_weight: Optional[torch.Tensor], # [,] @@ -73,7 +73,7 @@ def __init__( assert unit_shifts.shape[1] == 3 assert len(node_attrs.shape) == 2 assert weight is None or len(weight.shape) == 0 - assert theory is None or len(theory.shape) == 0 + assert head is None or len(head.shape) == 0 assert energy_weight is None or len(energy_weight.shape) == 0 assert forces_weight is None or len(forces_weight.shape) == 0 assert stress_weight is None or len(stress_weight.shape) == 0 @@ -95,7 +95,7 @@ def __init__( "cell": cell, "node_attrs": node_attrs, "weight": weight, - "theory": theory, + "head": head, "energy_weight": energy_weight, "forces_weight": forces_weight, "stress_weight": stress_weight, @@ -115,7 +115,7 @@ def from_config( config: Configuration, z_table: AtomicNumberTable, cutoff: float, - theories: Optional[list] = ["Default"], + heads: Optional[list] = ["Default"], ) -> "AtomicData": edge_index, shifts, unit_shifts = get_neighborhood( positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell @@ -126,10 +126,10 @@ def from_config( num_classes=len(z_table), ) try: - theory = torch.tensor(theories.index(config.theory), dtype=torch.long) + head = torch.tensor(heads.index(config.head), dtype=torch.long) except: - print(f"Theory {config.theory} not found in {theories}") - theory = torch.tensor(0, dtype=torch.long) + print(f"head {config.head} not found in {heads}") + head = torch.tensor(0, dtype=torch.long) cell = ( torch.tensor(config.cell, dtype=torch.get_default_dtype()) @@ -212,7 +212,7 @@ def from_config( cell=cell, node_attrs=one_hot, weight=weight, - theory=theory, + head=head, energy_weight=energy_weight, forces_weight=forces_weight, stress_weight=stress_weight, diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 6193c409..ce3a9b83 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -58,7 +58,7 @@ def __getitem__(self, index): dipole=unpack_value(subgrp["dipole"][()]), charges=unpack_value(subgrp["charges"][()]), weight=unpack_value(subgrp["weight"][()]), - theory=unpack_value(subgrp["theory"][()]), + head=unpack_value(subgrp["head"][()]), energy_weight=unpack_value(subgrp["energy_weight"][()]), forces_weight=unpack_value(subgrp["forces_weight"][()]), stress_weight=unpack_value(subgrp["stress_weight"][()]), @@ -71,7 +71,7 @@ def __getitem__(self, index): config, z_table=self.z_table, cutoff=self.r_max, - theories=self.kwargs.get("theories", ["Default"]), + heads=self.kwargs.get("heads", ["Default"]), ) 
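[Editorial note, not part of the patch series] The head bookkeeping introduced in AtomicData.from_config above reduces to mapping a per-configuration head label onto an integer index into the list of known heads, with a fallback when the label is missing. A minimal standalone sketch, with made-up head names:

    import torch

    heads = ["MP2", "DFT"]          # hypothetical list of known heads
    label = "CCSD"                  # a label that is not in the list

    try:
        head = torch.tensor(heads.index(label), dtype=torch.long)
    except ValueError:
        # this patch falls back to index 0; patch 031 later in the series
        # changes the fallback to len(heads) - 1 (the last head)
        head = torch.tensor(0, dtype=torch.long)

    assert head.item() == 0
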
return atomic_data diff --git a/mace/data/utils.py b/mace/data/utils.py index d8b2c8bc..b12368bf 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -47,7 +47,7 @@ class Configuration: stress_weight: float = 1.0 # weight of config stress in loss virials_weight: float = 1.0 # weight of config virial in loss config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config - theory: Optional[str] = "Default" # theory used to compute the config + head: Optional[str] = "Default" # head used to compute the config Configurations = List[Configuration] @@ -79,7 +79,7 @@ def config_from_atoms_list( virials_key="virials", dipole_key="dipole", charges_key="charges", - theory_key="theory", + head_key="head", config_type_weights: Dict[str, float] = None, ) -> Configurations: """Convert list of ase.Atoms into Configurations""" @@ -97,7 +97,7 @@ def config_from_atoms_list( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, - theory_key=theory_key, + head_key=head_key, config_type_weights=config_type_weights, ) ) @@ -112,7 +112,7 @@ def config_from_atoms( virials_key="virials", dipole_key="dipole", charges_key="charges", - theory_key="theory", + head_key="head", config_type_weights: Dict[str, float] = None, ) -> Configuration: """Convert ase.Atoms to Configuration""" @@ -140,7 +140,7 @@ def config_from_atoms( stress_weight = atoms.info.get("config_stress_weight", 1.0) virials_weight = atoms.info.get("config_virials_weight", 1.0) - theory = atoms.info.get(theory_key, "Default") + head = atoms.info.get(head_key, "Default") # fill in missing quantities but set their weight to 0.0 if energy is None: @@ -169,7 +169,7 @@ def config_from_atoms( dipole=dipole, charges=charges, weight=weight, - theory=theory, + head=head, energy_weight=energy_weight, forces_weight=forces_weight, stress_weight=stress_weight, @@ -187,7 +187,7 @@ def test_config_types( test_by_ct = [] all_cts = [] for conf in test_configs: - config_type_name = conf.config_type + "_" + conf.theory + config_type_name = conf.config_type + "_" + conf.head if config_type_name not in all_cts: all_cts.append(config_type_name) test_by_ct.append((config_type_name, [conf])) @@ -206,27 +206,69 @@ def load_from_xyz( virials_key: str = "virials", dipole_key: str = "dipole", charges_key: str = "charges", - theory_key: str = "theory", + head_key: str = "head", extract_atomic_energies: bool = False, keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") + # Perform initial checks and log warnings + if energy_key == "energy": + logging.info( + "Using energy_key 'energy' is unsafe, consider using a different key, rewriting energies to 'REF_energy'" + ) + energy_key = "REF_energy" + + if forces_key == "forces": + logging.info( + "Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to 'REF_forces'" + ) + forces_key = "REF_forces" + + if stress_key == "stress": + logging.info( + "Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to 'REF_stress'" + ) + stress_key = "REF_stress" + + # Process each atom only once + for atoms in atoms_list: + if energy_key == "REF_energy": + try: + atoms.info["REF_energy"] = atoms.get_potential_energy() + except Exception as e: # pylint: disable=W0703 + logging.warning(f"Failed to extract energy: {e}") + atoms.info["REF_energy"] = None + + if forces_key == "REF_forces": + try: + atoms.info["REF_forces"] = atoms.get_forces() + except Exception as e: # pylint: 
disable=W0703 + logging.warning(f"Failed to extract forces: {e}") + atoms.info["REF_forces"] = None + + if stress_key == "REF_stress": + try: + atoms.info["REF_stress"] = atoms.get_stress() + except Exception as e: # pylint: disable=W0703 + atoms.info["REF_stress"] = None + if not isinstance(atoms_list, list): atoms_list = [atoms_list] - atomic_energies_dict = {} if extract_atomic_energies: atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): if atoms.info.get("config_type") == "IsolatedAtom": - assert len(atoms) == 1, f"Got config_type=IsolatedAtom for a config with len {len(atoms)}" + assert ( + len(atoms) == 1 + ), f"Got config_type=IsolatedAtom for a config with len {len(atoms)}" if energy_key in atoms.info.keys(): - theory = atoms.info.get(theory_key, "Default") - if theory not in atomic_energies_dict: - atomic_energies_dict[theory] = {} - atomic_energies_dict[theory][atoms.get_atomic_numbers()[0]] = ( + head = atoms.info.get(head_key, "Default") + if head not in atomic_energies_dict: + atomic_energies_dict[head] = {} + atomic_energies_dict[head][atoms.get_atomic_numbers()[0]] = ( atoms.info[energy_key] ) else: @@ -234,10 +276,10 @@ def load_from_xyz( f"Configuration '{idx}' is marked as 'IsolatedAtom' " "but does not contain an energy. Zero energy will be used." ) - theory = atoms.info.get(theory_key, "Default") - if theory not in atomic_energies_dict: - atomic_energies_dict[theory] = {} - atomic_energies_dict[theory][atoms.get_atomic_numbers()[0]] = ( + head = atoms.info.get(head_key, "Default") + if head not in atomic_energies_dict: + atomic_energies_dict[head] = {} + atomic_energies_dict[head][atoms.get_atomic_numbers()[0]] = ( np.zeros(1) ) else: @@ -247,10 +289,10 @@ def load_from_xyz( logging.info("Using isolated atom energies from training file") if not keep_isolated_atoms: atoms_list = atoms_without_iso_atoms - theories = set() + heads = set() for atoms in atoms_list: - theories.add(atoms.info.get(theory_key, "Default")) - theories = list(theories) + heads.add(atoms.info.get(head_key, "Default")) + heads = list(heads) configs = config_from_atoms_list( atoms_list, config_type_weights=config_type_weights, @@ -260,13 +302,13 @@ def load_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, - theory_key=theory_key, + head_key=head_key, ) - return atomic_energies_dict, configs, theories + return atomic_energies_dict, configs, heads def compute_average_E0s( - collections_train: Configurations, z_table: AtomicNumberTable, theories: List[str] + collections_train: Configurations, z_table: AtomicNumberTable, heads: List[str] ) -> Dict[int, float]: """ Function to compute the average interaction energy of each chemical element @@ -275,13 +317,13 @@ def compute_average_E0s( len_train = len(collections_train) len_zs = len(z_table) atomic_energies_dict = {} - for theory in theories: + for head in heads: A = np.zeros((len_train, len_zs)) B = np.zeros(len_train) - if theory not in atomic_energies_dict: - atomic_energies_dict[theory] = {} + if head not in atomic_energies_dict: + atomic_energies_dict[head] = {} for i in range(len_train): - if collections_train[i].theory != theory: + if collections_train[i].head != head: continue B[i] = collections_train[i].energy for j, z in enumerate(z_table.zs): @@ -289,13 +331,13 @@ def compute_average_E0s( try: E0s = np.linalg.lstsq(A, B, rcond=None)[0] for i, z in enumerate(z_table.zs): - atomic_energies_dict[theory][z] = E0s[i] + atomic_energies_dict[head][z] = E0s[i] except np.linalg.LinAlgError: 
logging.warning( "Failed to compute E0s using least squares regression, using the same for all atoms" ) for i, z in enumerate(z_table.zs): - atomic_energies_dict[theory][z] = 0.0 + atomic_energies_dict[head][z] = 0.0 return atomic_energies_dict @@ -321,7 +363,7 @@ def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: grp["virials"] = data.virials grp["dipole"] = data.dipole grp["charges"] = data.charges - grp["theory"] = data.theory + grp["head"] = data.head def save_AtomicData_to_HDF5(data, i, h5_file) -> None: @@ -344,7 +386,7 @@ def save_AtomicData_to_HDF5(data, i, h5_file) -> None: grp["virials"] = data.virials grp["dipole"] = data.dipole grp["charges"] = data.charges - grp["theory"] = data.theory + grp["head"] = data.head def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: @@ -358,7 +400,7 @@ def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> N subgroup["forces"] = write_value(config.forces) subgroup["stress"] = write_value(config.stress) subgroup["virials"] = write_value(config.virials) - subgroup["theory"] = write_value(config.theory) + subgroup["head"] = write_value(config.head) subgroup["dipole"] = write_value(config.dipole) subgroup["charges"] = write_value(config.charges) subgroup["cell"] = write_value(config.cell) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 62be7e95..7a162146 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -17,7 +17,7 @@ from .irreps_tools import ( linear_out_irreps, - mask_theory, + mask_head, reshape_irreps, tp_out_irreps_with_instructions, ) @@ -52,7 +52,7 @@ def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")) self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out) def forward( - self, x: torch.Tensor, theories: Optional[torch.Tensor] = None + self, x: torch.Tensor, heads: Optional[torch.Tensor] = None ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] return self.linear(x) # [n_nodes, 1] @@ -66,26 +66,22 @@ def __init__( MLP_irreps: o3.Irreps, gate: Optional[Callable], irrep_out: o3.Irreps = o3.Irreps("0e"), - num_theories: int = 1, + num_heads: int = 1, ): super().__init__() self.hidden_irreps = MLP_irreps - self.num_theories = num_theories + self.num_heads = num_heads self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out) def forward( - self, x: torch.Tensor, theories: Optional[torch.Tensor] = None + self, x: torch.Tensor, heads: Optional[torch.Tensor] = None ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] x = self.non_linearity(self.linear_1(x)) - if ( - hasattr(self, "num_theories") - and self.num_theories > 1 - and theories is not None - ): - x = mask_theory(x, theories, self.num_theories) - return self.linear_2(x) # [n_nodes, len(theories)] + if hasattr(self, "num_heads") and self.num_heads > 1 and heads is not None: + x = mask_head(x, heads, self.num_heads) + return self.linear_2(x) # [n_nodes, len(heads)] @compile_mode("script") @@ -153,7 +149,7 @@ def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): self.register_buffer( "atomic_energies", torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), - ) # [n_elements, n_theories] + ) # [n_elements, n_heads] def forward( self, x: torch.Tensor # one-hot of elements [..., n_elements] @@ -771,10 +767,9 @@ def __init__(self, scale: float, shift: float): 
torch.atleast_1d(torch.tensor(shift, dtype=torch.get_default_dtype())), ) - def forward(self, x: torch.Tensor, theory: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor: return ( - torch.atleast_1d(self.scale)[theory] * x - + torch.atleast_1d(self.shift)[theory] + torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head] ) def __repr__(self): diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index 4e7ce9f1..b0960193 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -86,13 +86,9 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: return torch.cat(out, dim=-1) -def mask_theory( - x: torch.Tensor, theory: torch.Tensor, num_theories: int -) -> torch.Tensor: - mask = torch.zeros( - x.shape[0], x.shape[1] // num_theories, num_theories, device=x.device - ) +def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: + mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) idx = torch.arange(mask.shape[0], device=x.device) - mask[idx, :, theory] = 1 + mask[idx, :, head] = 1 mask = mask.permute(0, 2, 1).reshape(x.shape) return x * mask diff --git a/mace/modules/models.py b/mace/modules/models.py index c6cba12c..b6018892 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -61,7 +61,7 @@ def __init__( distance_transform: str = "None", radial_MLP: Optional[List[int]] = None, radial_type: Optional[str] = "bessel", - theories: Optional[List[str]] = ["Default"], + heads: Optional[List[str]] = ["Default"], ): super().__init__() self.register_buffer( @@ -73,7 +73,7 @@ def __init__( self.register_buffer( "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) ) - self.theories = theories + self.heads = heads if isinstance(correlation, int): correlation = [correlation] * num_interactions # Embedding @@ -134,7 +134,7 @@ def __init__( self.readouts = torch.nn.ModuleList() self.readouts.append( - LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(theories)}x0e")) + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) ) for i in range(num_interactions - 1): @@ -167,15 +167,15 @@ def __init__( self.readouts.append( NonLinearReadoutBlock( hidden_irreps_out, - (len(theories) * MLP_irreps).simplify(), + (len(heads) * MLP_irreps).simplify(), gate, - o3.Irreps(f"{len(theories)}x0e"), - len(theories), + o3.Irreps(f"{len(heads)}x0e"), + len(heads), ) ) else: self.readouts.append( - LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(theories)}x0e")) + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) ) def forward( @@ -190,10 +190,10 @@ def forward( # Setup data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) - print("theory", data["theory"]) + print("head", data["head"]) num_atoms_arange = torch.arange(data["positions"].shape[0]) num_graphs = data["ptr"].numel() - 1 - node_theories = data["theory"][data["batch"]] + node_heads = data["head"][data["batch"]] displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -215,12 +215,12 @@ def forward( # Atomic energies node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, node_theories + num_atoms_arange, node_heads ] # print("node e0", node_e0.shape) e0 = scatter_sum( src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs - ) # [n_graphs, n_theories] + ) # [n_graphs, n_heads] # Embeddings node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = 
get_edge_vectors_and_lengths( @@ -265,9 +265,9 @@ def forward( node_attrs=data["node_attrs"], ) node_feats_list.append(node_feats) - node_energies = readout(node_feats, node_theories)[ - num_atoms_arange, node_theories - ] # [n_nodes, len(theories)] + node_energies = readout(node_feats, node_heads)[ + num_atoms_arange, node_heads + ] # [n_nodes, len(heads)] energy = scatter_sum( src=node_energies, index=data["batch"], @@ -337,7 +337,7 @@ def forward( data["node_attrs"].requires_grad_(True) num_graphs = data["ptr"].numel() - 1 num_atoms_arange = torch.arange(data["positions"].shape[0]) - node_theories = data["theory"][data["batch"]] + node_heads = data["head"][data["batch"]] displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -359,11 +359,11 @@ def forward( # Atomic energies node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, node_theories + num_atoms_arange, node_heads ] e0 = scatter_sum( src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs - ) # [n_graphs, num_theories] + ) # [n_graphs, num_heads] # Embeddings node_feats = self.node_embedding(data["node_attrs"]) @@ -400,7 +400,7 @@ def forward( ) node_feats_list.append(node_feats) node_es_list.append( - readout(node_feats, node_theories)[num_atoms_arange, node_theories] + readout(node_feats, node_heads)[num_atoms_arange, node_heads] ) # {[n_nodes, ], } # Concatenate node features @@ -411,7 +411,7 @@ def forward( torch.stack(node_es_list, dim=0), dim=0 ) # [n_nodes, ] # print("node_inter_es", node_inter_es.shape) - node_inter_es = self.scale_shift(node_inter_es, node_theories) + node_inter_es = self.scale_shift(node_inter_es, node_heads) # Sum over nodes in graph inter_e = scatter_sum( @@ -526,11 +526,11 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: # Atomic energies node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, data["theory"][data["batch"]] + num_atoms_arange, data["head"][data["batch"]] ] e0 = scatter_sum( src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs, n_theories] + ) # [n_graphs, n_heads] # Embeddings node_feats = self.node_embedding(data.node_attrs) @@ -591,7 +591,7 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: num_atoms_arange = torch.arange(data.positions.shape[0]) # Atomic energies node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, data["theory"][data["batch"]] + num_atoms_arange, data["head"][data["batch"]] ] e0 = scatter_sum( src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs @@ -624,7 +624,7 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: node_inter_es = torch.sum( torch.stack(node_es_list, dim=0), dim=0 ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es, data["theory"][data["batch"]]) + node_inter_es = self.scale_shift(node_inter_es, data["head"][data["batch"]]) # Sum over nodes in graph inter_e = scatter_sum( @@ -1006,7 +1006,7 @@ def forward( # Atomic energies node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, data["theory"][data["batch"]] + num_atoms_arange, data["head"][data["batch"]] ] e0 = scatter_sum( src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 4e5fec6d..5e0ec72b 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -192,27 +192,25 @@ def compute_mean_std_atomic_inter_energy( atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) 
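[Editorial note, not part of the patch series] The per-head masking used by NonLinearReadoutBlock above keeps, for every node, only the hidden channels that belong to that node's head. Reproduced here in slightly simplified, standalone form with toy tensors (shapes and values are illustrative):

    import torch

    def mask_head(x, head, num_heads):
        # x: [n_nodes, channels], channels laid out as num_heads contiguous blocks
        # head: [n_nodes] integer head index per node
        mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads)
        idx = torch.arange(mask.shape[0])
        mask[idx, :, head] = 1
        mask = mask.permute(0, 2, 1).reshape(x.shape)
        return x * mask

    x = torch.arange(12.0).reshape(3, 4)   # 3 nodes, 2 heads x 2 channels each
    head = torch.tensor([0, 1, 0])
    print(mask_head(x, head, 2))
    # nodes 0 and 2 keep channels 0-1 (head 0), node 1 keeps channels 2-3 (head 1)
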
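[Editorial note, not part of the patch series] The same per-graph head index drives both the energy readout and the statistics below: it is broadcast to nodes via the batch vector and then used either to pick one column of an [n_nodes, n_heads] table or to group reductions by head (which is what the scatter_mean / scatter_std calls do). A small self-contained sketch, with made-up numbers:

    import torch

    # two graphs assigned to heads 0 and 1; three nodes in total
    head = torch.tensor([0, 1])          # per-graph head index
    batch = torch.tensor([0, 0, 1])      # node -> graph assignment
    node_heads = head[batch]             # per-node head index: [0, 0, 1]

    # per-element reference energies for each head: [n_heads, n_elements]
    atomic_energies = torch.tensor([[-0.5, -75.0],
                                    [-0.6, -75.2]])
    node_attrs = torch.tensor([[1.0, 0.0],   # one-hot element per node
                               [0.0, 1.0],
                               [1.0, 0.0]])

    node_e0_all = node_attrs @ atomic_energies.t()        # [n_nodes, n_heads]
    node_e0 = node_e0_all[torch.arange(3), node_heads]    # each node's own head
    print(node_e0)                       # tensor([ -0.5000, -75.0000,  -0.6000])

    # per-head mean of a per-graph quantity (conceptually what scatter_mean does)
    avg_e = torch.tensor([0.10, 0.50])   # one value per graph
    mean_per_head = torch.stack([avg_e[head == h].mean() for h in range(2)])
    print(mean_per_head)                 # tensor([0.1000, 0.5000])
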
avg_atom_inter_es_list = [] - theory_list = [] + head_list = [] for batch in data_loader: node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), batch.theory] + )[torch.arange(batch.num_graphs), batch.head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] avg_atom_inter_es_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } - theory_list.append(batch.theory) + head_list.append(batch.head) avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] - theory = torch.cat(theory_list, dim=0) # [total_n_graphs] + head = torch.cat(head_list, dim=0) # [total_n_graphs] # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() # std = to_numpy(torch.std(avg_atom_inter_es)).item() - mean = to_numpy( - scatter_mean(src=avg_atom_inter_es, index=theory, dim=0).squeeze(-1) - ) - std = to_numpy(scatter_std(src=avg_atom_inter_es, index=theory, dim=0).squeeze(-1)) + mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) + std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) std = _check_non_zero(std) return mean, std @@ -222,11 +220,11 @@ def _compute_mean_std_atomic_inter_energy( batch: Batch, atomic_energies_fn: AtomicEnergiesBlock, ) -> Tuple[torch.Tensor, torch.Tensor]: - theory = batch.theory + head = batch.head node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), theory] + )[torch.arange(batch.num_graphs), head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energies = (batch.energy - graph_e0s) / graph_sizes return atom_energies @@ -240,34 +238,34 @@ def compute_mean_rms_energy_forces( atom_energy_list = [] forces_list = [] - theory_list = [] - theory_batch = [] + head_list = [] + head_batch = [] for batch in data_loader: - theory = batch.theory + head = batch.head node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), theory] + )[torch.arange(batch.num_graphs), head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energy_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - theory_list.append(theory) - theory_batch.append(theory[batch.batch]) + head_list.append(head) + head_batch.append(head[batch.batch]) atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - theory = torch.cat(theory_list, dim=0) # [total_n_graphs] - theory_batch = torch.cat(theory_batch, dim=0) # [total_n_graphs] + head = torch.cat(head_list, dim=0) # [total_n_graphs] + head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] # mean = to_numpy(torch.mean(atom_energies)).item() # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() - mean = to_numpy(scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1)) + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) rms = to_numpy( torch.sqrt( - scatter_mean(src=torch.square(forces), index=theory_batch, dim=0).mean(-1) + scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) ) ) rms = _check_non_zero(rms) @@ -279,11 +277,11 @@ def _compute_mean_rms_energy_forces( batch: Batch, atomic_energies_fn: 
AtomicEnergiesBlock, ) -> Tuple[torch.Tensor, torch.Tensor]: - theory = batch.theory + head = batch.head node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), theory] + )[torch.arange(batch.num_graphs), head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } forces = batch.forces # {[n_graphs*n_atoms,3], } @@ -314,20 +312,20 @@ def compute_statistics( atom_energy_list = [] forces_list = [] num_neighbors = [] - theory_list = [] + head_list = [] for batch in data_loader: - theory = batch.theory + head = batch.head node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), theory] + )[torch.arange(batch.num_graphs), head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energy_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - theory_list.append(theory) # {[n_graphs], } + head_list.append(head) # {[n_graphs], } _, receivers = batch.edge_index _, counts = torch.unique(receivers, return_counts=True) @@ -335,14 +333,14 @@ def compute_statistics( atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - theory = torch.cat(theory_list, dim=0) # [total_n_graphs] + head = torch.cat(head_list, dim=0) # [total_n_graphs] # mean = to_numpy(torch.mean(atom_energies)).item() - mean = to_numpy(scatter_mean(src=atom_energies, index=theory, dim=0).squeeze(-1)) - # do the mean for each theory + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + # do the mean for each head # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() rms = to_numpy( - torch.sqrt(scatter_mean(src=torch.square(forces), index=theory, dim=0)) + torch.sqrt(scatter_mean(src=torch.square(forces), index=head, dim=0)) ) avg_num_neighbors = torch.mean( diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 9afa4b12..fa646fbd 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -332,12 +332,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser: required=False, ) parser.add_argument( - "--theories", - help="List of theories in the training set", + "--heads", + help="List of heads in the training set", type=str, default=None, required=False, ) + parser.add_argument( + "--multiheads_finetuning", + help="Boolean value for whether the model is multiheaded", + type=bool, + default=True, + ) parser.add_argument( "--keep_isolated_atoms", help="Keep isolated atoms in the dataset, useful for transfer learning", diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 8e25accc..18617760 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -76,7 +76,7 @@ def load_foundations_elements( """ assert model_foundations.r_max == model.r_max z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) - model_theories = model.theories + model_heads = model.heads new_z_table = table num_species_foundations = len(z_table.zs) num_channels_foundation = ( @@ -193,7 +193,7 @@ def load_foundations_elements( model_readouts_zero_linear_weight = ( model_foundations.readouts[0] .linear.weight.view(num_channels_foundation, -1) - .repeat(1, 
len(model_theories)) + .repeat(1, len(model_heads)) .flatten() .clone() ) @@ -209,7 +209,7 @@ def load_foundations_elements( model_readouts_one_linear_1_weight = ( model_foundations.readouts[1] .linear_1.weight.view(num_channels_foundation, -1) - .repeat(1, len(model_theories)) + .repeat(1, len(model_heads)) .flatten() .clone() ) @@ -220,7 +220,7 @@ def load_foundations_elements( model_readouts_one_linear_2_weight = model_foundations.readouts[ 1 ].linear_2.weight.view(shape_input_1, -1).repeat( - len(model_theories), len(model_theories) + len(model_heads), len(model_heads) ).flatten().clone() / ( ((shape_input_1) / (shape_output_1)) ** 0.5 ) @@ -230,11 +230,11 @@ def load_foundations_elements( if model_foundations.scale_shift is not None: if use_scale: model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( - len(model_theories) + len(model_heads) ).clone() if use_shift: model.scale_shift.shift = model_foundations.scale_shift.shift.repeat( - len(model_theories) + len(model_heads) ).clone() return model diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 092edef5..2811ab6d 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -4,6 +4,7 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +import argparse import ast import dataclasses import json @@ -43,7 +44,7 @@ def get_dataset_from_xyz( charges_key: str = "charges", ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" - atomic_energies_dict, all_train_configs, theories = data.load_from_xyz( + atomic_energies_dict, all_train_configs, heads = data.load_from_xyz( file_path=train_path, config_type_weights=config_type_weights, energy_key=energy_key, @@ -79,15 +80,15 @@ def get_dataset_from_xyz( "Using random %s%% of training set for validation", 100 * valid_fraction ) train_configs, valid_configs = [], [] - for theory in theories: - all_train_configs_theory = [ - config for config in all_train_configs if config.theory == theory + for head in heads: + all_train_configs_head = [ + config for config in all_train_configs if config.head == head ] - train_configs_theory, valid_configs_theory = data.random_train_valid_split( - all_train_configs_theory, valid_fraction, seed + train_configs_head, valid_configs_head = data.random_train_valid_split( + all_train_configs_head, valid_fraction, seed ) - train_configs.extend(train_configs_theory) - valid_configs.extend(valid_configs_theory) + train_configs.extend(train_configs_head) + valid_configs.extend(valid_configs_head) test_configs = [] if test_path is not None: @@ -108,7 +109,7 @@ def get_dataset_from_xyz( return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), atomic_energies_dict, - theories, + heads, ) @@ -127,7 +128,7 @@ def get_config_type_weights(ct_weights): return config_type_weights -def get_atomic_energies(E0s, train_collection, z_table, theories) -> dict: +def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict: if E0s is not None: logging.info( "Atomic Energies not in training file, using command line argument E0s" @@ -140,7 +141,7 @@ def get_atomic_energies(E0s, train_collection, z_table, theories) -> dict: try: assert train_collection is not None atomic_energies_dict = data.compute_average_E0s( - train_collection, z_table, theories + train_collection, z_table, heads ) except Exception as e: raise RuntimeError( @@ -473,3 
+474,11 @@ def check_folder_subfolder(folder_path): if os.path.isdir(full_path): return True return False + + +def dict_to_namespace(dictionary): + # Convert the dictionary into an argparse.Namespace + namespace = argparse.Namespace() + for key, value in dictionary.items(): + setattr(namespace, key, value) + return namespace diff --git a/mace/tools/train.py b/mace/tools/train.py index ae5a0877..62013e9d 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -56,7 +56,7 @@ def valid_err_log( error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -66,7 +66,7 @@ def valid_err_log( error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 logging.info( - f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress_per_atom={error_stress:.1f} meV / A^3" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -76,37 +76,37 @@ def valid_err_log( error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 logging.info( - f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_virials_per_atom={error_virials:.1f} meV" ) elif log_errors == "TotalRMSE": error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A" ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, MAE_E_per_atom={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, MAE_E={error_e:.1f} meV, MAE_F={error_f:.1f} meV / A" ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_MU_per_atom={error_mu:.2f} mDebye" ) elif log_errors == "EnergyDipoleRMSE": error_e = 
eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"Theory: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_Mu_per_atom={error_mu:.2f} mDebye" ) @@ -152,14 +152,14 @@ def train( # log validation loss before _any_ training valid_loss = 0.0 for valid_loader_name, valid_loader in valid_loaders.items(): - valid_loss_theory, eval_metrics = evaluate( + valid_loss_head, eval_metrics = evaluate( model=model, loss_fn=loss_fn, data_loader=valid_loader, output_args=output_args, device=device, ) - valid_loss += valid_loss_theory + valid_loss += valid_loss_head valid_err_log( valid_loss, eval_metrics, logger, log_errors, None, valid_loader_name ) @@ -215,14 +215,14 @@ def train( valid_loss = 0.0 wandb_log_dict = {} for valid_loader_name, valid_loader in valid_loaders.items(): - valid_loss_theory, eval_metrics = evaluate( + valid_loss_head, eval_metrics = evaluate( model=model_to_evaluate, loss_fn=loss_fn, data_loader=valid_loader, output_args=output_args, device=device, ) - valid_loss += valid_loss_theory + valid_loss += valid_loss_head valid_err_log( valid_loss, eval_metrics, @@ -234,7 +234,7 @@ def train( if log_wandb: wandb_log_dict[valid_loader_name] = { "epoch": epoch, - "valid_loss": valid_loss_theory, + "valid_loss": valid_loss_head, "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], "valid_rmse_f": eval_metrics["rmse_f"], } diff --git a/tests/test_foundations.py b/tests/test_foundations.py index 22f0f2f8..4f6909eb 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -108,7 +108,7 @@ def test_multi_reference(): energy=-1.5, charges=molecule("H2COH").numbers, dipole=np.array([-1.5, 1.5, 2.0]), - theory="MP2", + head="MP2", ) table = tools.AtomicNumberTable([1, 6, 8]) atomic_energies = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) @@ -137,7 +137,7 @@ def test_multi_reference(): radial_type="bessel", atomic_inter_scale=[1.0, 1.0], atomic_inter_shift=[0.0, 0.0], - theories=["MP2", "DFT"], + heads=["MP2", "DFT"], ) model = modules.ScaleShiftMACE(**model_config) calc_foundation = mace_mp(device="cpu", default_dtype="float64") @@ -150,7 +150,7 @@ def test_multi_reference(): max_L=1, ) atomic_data = data.AtomicData.from_config( - config, z_table=table, cutoff=6.0, theories=["MP2", "DFT"] + config, z_table=table, cutoff=6.0, heads=["MP2", "DFT"] ) data_loader = torch_geometric.dataloader.DataLoader( dataset=[atomic_data, atomic_data], diff --git a/tests/test_models.py b/tests/test_models.py index 5eb9f516..ca3e69b5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -222,20 +222,20 @@ def test_mace_multi_reference(): distance_transform=True, pair_repulsion=True, correlation=3, - theories=["Default", "dft"], + heads=["Default", "dft"], # radial_type="chebyshev", atomic_inter_scale=[1.0, 1.0], atomic_inter_shift=[0.0, 0.1], ) model = modules.ScaleShiftMACE(**model_config) model_compiled = jit.compile(model) - config.theory = "Default" - config_rotated.theory = "dft" + config.head = "Default" + config_rotated.head = "dft" atomic_data = data.AtomicData.from_config( - config, z_table=table, cutoff=3.0, theories=["Default", "dft"] + config, z_table=table, cutoff=3.0, heads=["Default", "dft"] ) atomic_data2 = 
data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=3.0, theories=["Default", "dft"] + config_rotated, z_table=table, cutoff=3.0, heads=["Default", "dft"] ) data_loader = torch_geometric.dataloader.DataLoader( diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 376739bd..7dca8919 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -355,9 +355,9 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): fitting_configs_ = [] for i, c in enumerate(fitting_configs): if i % 2 == 0: - c.info["theory"] = "DFT" + c.info["head"] = "DFT" else: - c.info["theory"] = "MP2" + c.info["head"] = "MP2" fitting_configs_.append(c) ase.io.write(tmp_path / "fit.xyz", fitting_configs_) @@ -373,7 +373,7 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): mace_params["default_dtype"] = "float32" mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["theories"] = "['MP2','DFT']" + mace_params["heads"] = "['MP2','DFT']" mace_params["batch_size"] = 2 # make sure run_train.py is using the mace that is currently being tested run_env = os.environ.copy() @@ -432,4 +432,4 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): 0.1669044941663742, 0.119053915143013, ] - assert np.allclose(Es, ref_Es) + assert np.allclose(Es, ref_Es, tol=1e-2) From cbdc546d0c2419aa0917309cc118856f5ede321a Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:22:13 +0100 Subject: [PATCH 027/101] fix the interface for multihead --- mace/cli/fine_tuning_select.py | 9 +++------ mace/cli/run_train.py | 16 +++++++++++----- mace/data/utils.py | 1 - mace/tools/arg_parser.py | 2 +- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index b90b432c..5fe1f7d0 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -99,7 +99,7 @@ def calculate_descriptors( atoms: t.List[ase.Atoms | ase.Atom], calc: MACECalculator, cutoffs: None | dict ) -> None: print("Calculating descriptors") - for mol in tqdm(atoms): + for mol in atoms: descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) # average descriptors over atoms for each element descriptors_dict = { @@ -182,7 +182,8 @@ def assemble_descriptors(self) -> np.ndarray: len(self.atoms_list), len(self.species), len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), - ) + ), + dtype=np.float32, ) ) for i, atoms in enumerate(self.atoms_list): @@ -216,10 +217,6 @@ def select_samples( atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): atoms.info["mace_descriptors"] = descriptors[i] - print( - "Filtering configurations based on the finetuning set," - f"filtering type: combinations, elements: {all_species_ft}" - ) atoms_list_pt = [ x for x in atoms_list_pt diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index c39170cb..7c1d84df 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -177,7 +177,13 @@ def main() -> None: if args.multiheads_finetuning: logging.info("Using multiheads finetuning mode") - heads = list(set(["pbe_mp"] + heads)) + if heads is not None: + heads = list(set(["pbe_mp"] + heads)) + args.heads = heads + else: + heads = ["pbe_mp", "Default"] + args.heads = heads + logging.info(f"Using heads: {heads}") try: checkpoint_url = "https://tinyurl.com/mw2wetc5" cache_dir = 
os.path.expanduser("~/.cache/mace") @@ -214,13 +220,13 @@ def main() -> None: "weight_ft": 1.0, "filtering_type": "combination", "output": f"{cache_dir}/mp_finetuning.xyz", - "descriptors": None, + "descriptors": r"D:\Work\mace_mp\descriptors.npy", "device": args.device, "default_dtype": args.default_dtype, } select_samples(dict_to_namespace(args_samples)) collections_mp, _, _ = get_dataset_from_xyz( - train_path=dataset_mp, + train_path=f"{cache_dir}/mp_finetuning.xyz", valid_path=None, valid_fraction=args.valid_fraction, config_type_weights=config_type_weights, @@ -277,9 +283,9 @@ def main() -> None: else: atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table, heads) if args.multiheads_finetuning: - with open("mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file: + with open(r"mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file: E0s_mp = json.load(file) - atomic_energies_dict["pbe_mp"] = {E0s_mp["pbe"][z] for z in z_table.zs} + atomic_energies_dict["pbe_mp"] = {z: E0s_mp["pbe"][f"{z}"] for z in z_table.zs} if args.model == "AtomicDipolesMACE": atomic_energies = None diff --git a/mace/data/utils.py b/mace/data/utils.py index b12368bf..c55ad86b 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -231,7 +231,6 @@ def load_from_xyz( ) stress_key = "REF_stress" - # Process each atom only once for atoms in atoms_list: if energy_key == "REF_energy": try: diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index fa646fbd..57249044 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -372,7 +372,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--stress_key", help="Key of reference stress in training xyz", type=str, - default="stress", + default="REF_stress", ) parser.add_argument( "--dipole_key", From 757385a39afa5cf61b0a3e73523e3ff73e748874 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:52:41 +0100 Subject: [PATCH 028/101] automatic download of the descriptors --- mace/cli/fine_tuning_select.py | 29 +++++++++++++++-------------- mace/cli/run_train.py | 26 +++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 5fe1f7d0..2cbbb522 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -98,7 +98,7 @@ def parse_args() -> argparse.Namespace: def calculate_descriptors( atoms: t.List[ase.Atoms | ase.Atom], calc: MACECalculator, cutoffs: None | dict ) -> None: - print("Calculating descriptors") + logging.info("Calculating descriptors") for mol in atoms: descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) # average descriptors over atoms for each element @@ -164,8 +164,8 @@ def run( """ Run the farthest point sampling algorithm. 
""" - print(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape) - print("n_samples", self.n_samples) + logging.info(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape) + logging.info("n_samples", self.n_samples) self.list_index = fpsample.fps_npdu_kdtree_sampling( self.descriptors_dataset.reshape(len(self.atoms_list), -1), self.n_samples ) @@ -207,12 +207,12 @@ def select_samples( if args.filtering_type != None: all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) - print( - "Filtering configurations based on the finetuning set," + logging.info( + "Filtering configurations based on the finetuning set, " f"filtering type: combinations, elements: {all_species_ft}" ) if args.descriptors is not None: - print("Loading descriptors") + logging.info("Loading descriptors") descriptors = np.load(args.descriptors, allow_pickle=True) atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): @@ -222,7 +222,6 @@ def select_samples( for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") ] - else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") atoms_list_pt = [ @@ -233,7 +232,7 @@ def select_samples( else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") if args.descriptors is not None: - print( + logging.info( "Loading descriptors for the pretraining set from {}".format( args.descriptors ) @@ -244,35 +243,37 @@ def select_samples( if args.num_samples is not None and args.num_samples < len(atoms_list_pt): if args.descriptors is None: - print("Calculating descriptors for the pretraining set") + logging.info("Calculating descriptors for the pretraining set") calculate_descriptors(atoms_list_pt, calc, None) descriptors_list = [ atoms.info["mace_descriptors"] for atoms in atoms_list_pt ] - print( + logging.info( "Saving descriptors at {}".format( args.output.replace(".xyz", "descriptors.npy") ) ) np.save(args.output.replace(".xyz", "descriptors.npy"), descriptors_list) - print("Selecting configurations using Farthest Point Sampling") + logging.info("Selecting configurations using Farthest Point Sampling") fps_pt = FPS(atoms_list_pt, args.num_samples) idx_pt = fps_pt.run() - print(f"Selected {len(idx_pt)} configurations") + logging.info(f"Selected {len(idx_pt)} configurations") atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] for atoms in atoms_list_pt: # del atoms.info["mace_descriptors"] atoms.info["pretrained"] = True atoms.info["config_weight"] = args.weight_pt + atoms.info["mace_descriptors"] = None if args.head_pt is not None: atoms.info["head"] = args.head_pt - print("Saving the selected configurations") + logging.info("Saving the selected configurations") ase.io.write(args.output, atoms_list_pt, format="extxyz") - print("Saving a combined XYZ file") + logging.info("Saving a combined XYZ file") for atoms in atoms_list_ft: atoms.info["pretrained"] = False atoms.info["config_weight"] = args.weight_ft + atoms.info["mace_descriptors"] = None if args.head_ft is not None: atoms.info["head"] = args.head_ft atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 7c1d84df..b509ea1c 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -186,6 +186,7 @@ def main() -> None: logging.info(f"Using heads: {heads}") try: checkpoint_url = "https://tinyurl.com/mw2wetc5" + descriptors_url = "https://tinyurl.com/mpe7br4d" cache_dir = os.path.expanduser("~/.cache/mace") checkpoint_url_name = "".join( c @@ -193,6 +194,12 @@ def 
main() -> None: if c.isalnum() or c in "_" ) cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" + descriptors_url_name = "".join( + c + for c in os.path.basename(descriptors_url) + if c.isalnum() or c in "_" + ) + cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}" if not os.path.isfile(cached_dataset_path): os.makedirs(cache_dir, exist_ok=True) # download and save to disk @@ -205,9 +212,26 @@ def main() -> None: f"Dataset download failed, please check the URL {checkpoint_url}" ) logging.info(f"Materials Project dataset to {cached_dataset_path}") + if not os.path.isfile(cached_descriptors_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP descriptors for finetuning") + _, http_msg = urllib.request.urlretrieve( + descriptors_url, cached_descriptors_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Descriptors download failed, please check the URL {descriptors_url}" + ) + logging.info( + f"Materials Project descriptors to {cached_descriptors_path}" + ) dataset_mp = cached_dataset_path + descriptors_mp = cached_descriptors_path msg = f"Using Materials Project dataset with {dataset_mp}" logging.info(msg) + msg = f"Using Materials Project descriptors with {descriptors_mp}" + logging.info(msg) args_samples = { "configs_pt": dataset_mp, "configs_ft": args.train_file, @@ -220,7 +244,7 @@ def main() -> None: "weight_ft": 1.0, "filtering_type": "combination", "output": f"{cache_dir}/mp_finetuning.xyz", - "descriptors": r"D:\Work\mace_mp\descriptors.npy", + "descriptors": descriptors_mp, "device": args.device, "default_dtype": args.default_dtype, } From 4bef9937ed58b39a7d0aeebca10a9e113df6d419 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 30 Apr 2024 12:03:20 +0100 Subject: [PATCH 029/101] fix vanilla training --- mace/cli/run_train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index b509ea1c..2a85a320 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -119,6 +119,8 @@ def main() -> None: f"Using foundation model {args.foundation_model} as initial checkpoint." 
) args.r_max = model_foundation.r_max.item() + else: + args.multiheads_finetuning = False if args.statistics_file is not None: with open(args.statistics_file, "r") as f: # pylint: disable=W1514 From b2fb45859146a4a4103657f852370f9a92bb9aca Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 30 Apr 2024 19:36:12 +0100 Subject: [PATCH 030/101] add weight pt head to argparser --- mace/cli/run_train.py | 2 +- mace/tools/arg_parser.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 2a85a320..6b4d5290 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -242,7 +242,7 @@ def main() -> None: "model": args.foundation_model, "head_pt": "pbe_mp", "head_ft": "Default", - "weight_pt": 1.0, + "weight_pt": args.weight_pt_head, "weight_ft": 1.0, "filtering_type": "combination", "output": f"{cache_dir}/mp_finetuning.xyz", diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 57249044..b05816dd 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -344,6 +344,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=bool, default=True, ) + parser.add_argument( + "--weight_pt_head", + help="Weight of the pretrained head in the loss function", + type=float, + default=1.0, + ) parser.add_argument( "--keep_isolated_atoms", help="Keep isolated atoms in the dataset, useful for transfer learning", From d69006391595af6165ec7e1282f655fd3ef52ef2 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 30 Apr 2024 19:57:17 +0100 Subject: [PATCH 031/101] Update atomic_data.py --- mace/data/atomic_data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index f157cecd..9344546e 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -128,8 +128,7 @@ def from_config( try: head = torch.tensor(heads.index(config.head), dtype=torch.long) except: - print(f"head {config.head} not found in {heads}") - head = torch.tensor(0, dtype=torch.long) + head = torch.tensor(len(heads) - 1, dtype=torch.long) cell = ( torch.tensor(config.cell, dtype=torch.get_default_dtype()) From 7697261679c0041a4ebb3741c093fde3cff5940c Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 30 Apr 2024 20:48:37 +0100 Subject: [PATCH 032/101] remove non compatible typing hint --- mace/cli/fine_tuning_select.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 2cbbb522..922caf25 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -4,7 +4,7 @@ import argparse import logging -import typing as t +from typing import List import ase.data @@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace: def calculate_descriptors( - atoms: t.List[ase.Atoms | ase.Atom], calc: MACECalculator, cutoffs: None | dict + atoms: List[ase.Atoms], calc: MACECalculator, cutoffs: None | dict ) -> None: logging.info("Calculating descriptors") for mol in atoms: From 7cd71d82611178901819cea7d87c5711cb3e1284 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 1 May 2024 16:48:36 -0400 Subject: [PATCH 033/101] Fix bug that overwrote REF_* keys when those were the explicitly specified keys for the training reference quantities --- mace/data/utils.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 
insertions(+), 15 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index c55ad86b..4dd96287 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -212,45 +212,52 @@ def load_from_xyz( ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") + energy_from_calc = False + forces_from_calc = False + stress_from_calc = False + # Perform initial checks and log warnings if energy_key == "energy": logging.info( - "Using energy_key 'energy' is unsafe, consider using a different key, rewriting energies to 'REF_energy'" + "Using energy_key 'energy' is unsafe, consider using a different key, rewriting energies to '_REF_energy'" ) - energy_key = "REF_energy" + energy_from_calc = True + energy_key = "_REF_energy" if forces_key == "forces": logging.info( - "Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to 'REF_forces'" + "Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to '_REF_forces'" ) - forces_key = "REF_forces" + forces_from_calc = True + forces_key = "_REF_forces" if stress_key == "stress": logging.info( - "Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to 'REF_stress'" + "Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to '_REF_stress'" ) - stress_key = "REF_stress" + stress_from_calc = True + stress_key = "_REF_stress" for atoms in atoms_list: - if energy_key == "REF_energy": + if energy_from_calc: try: - atoms.info["REF_energy"] = atoms.get_potential_energy() + atoms.info["_REF_energy"] = atoms.get_potential_energy() except Exception as e: # pylint: disable=W0703 logging.warning(f"Failed to extract energy: {e}") - atoms.info["REF_energy"] = None + atoms.info["_REF_energy"] = None - if forces_key == "REF_forces": + if forces_from_calc: try: - atoms.info["REF_forces"] = atoms.get_forces() + atoms.info["_REF_forces"] = atoms.get_forces() except Exception as e: # pylint: disable=W0703 logging.warning(f"Failed to extract forces: {e}") - atoms.info["REF_forces"] = None + atoms.info["_REF_forces"] = None - if stress_key == "REF_stress": + if stress_from_calc: try: - atoms.info["REF_stress"] = atoms.get_stress() + atoms.info["_REF_stress"] = atoms.get_stress() except Exception as e: # pylint: disable=W0703 - atoms.info["REF_stress"] = None + atoms.info["_REF_stress"] = None if not isinstance(atoms_list, list): atoms_list = [atoms_list] From 5016abfe4458192675f0983e7f5b52332de21fbe Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 2 May 2024 17:45:57 +0100 Subject: [PATCH 034/101] fix urls and extract E0s --- mace/cli/run_train.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 6b4d5290..cb098eb6 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -43,6 +43,7 @@ load_foundations_elements, extract_config_mace_model, ) +from mace.tools.utils import AtomicNumberTable def main() -> None: @@ -187,8 +188,8 @@ def main() -> None: args.heads = heads logging.info(f"Using heads: {heads}") try: - checkpoint_url = "https://tinyurl.com/mw2wetc5" - descriptors_url = "https://tinyurl.com/mpe7br4d" + checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" + descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy" cache_dir = os.path.expanduser("~/.cache/mace") 
checkpoint_url_name = "".join( c @@ -309,9 +310,21 @@ def main() -> None: else: atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table, heads) if args.multiheads_finetuning: - with open(r"mace\calculators\foundations_models\mp_vasp_e0.json", "r") as file: - E0s_mp = json.load(file) - atomic_energies_dict["pbe_mp"] = {z: E0s_mp["pbe"][f"{z}"] for z in z_table.zs} + assert ( + model_foundation is not None + ), "Model foundation must be provided for multiheads finetuning" + logging.info( + "Using atomic energies from foundation model for multiheads finetuning" + ) + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + atomic_energies_dict["pbe_mp"] = { + z: model_foundation.atomic_energies_fn.atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } if args.model == "AtomicDipolesMACE": atomic_energies = None From 7e482573fa98620a99109c68336404c2d26f9267 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 2 May 2024 18:30:55 +0100 Subject: [PATCH 035/101] remove non standard type hints --- mace/cli/fine_tuning_select.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 922caf25..a9349a6f 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace: def calculate_descriptors( - atoms: List[ase.Atoms], calc: MACECalculator, cutoffs: None | dict + atoms: List[ase.Atoms], calc: MACECalculator, cutoffs: None ) -> None: logging.info("Calculating descriptors") for mol in atoms: From 1bc9077d11fec64c108c612dec1851d24ab6180e Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 2 May 2024 19:59:23 +0100 Subject: [PATCH 036/101] re order the dict for E0s --- mace/cli/run_train.py | 2 +- mace/tools/scripts_utils.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index cb098eb6..a0242afc 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -348,7 +348,7 @@ def main() -> None: # atomic_energies: np.ndarray = np.array( # [atomic_energies_dict[z] for z in z_table.zs] # ) - atomic_energies = dict_to_array(atomic_energies_dict) + atomic_energies = dict_to_array(atomic_energies_dict, args.heads) logging.info(f"Atomic energies: {atomic_energies.tolist()}") if args.train_file.endswith(".xyz"): diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 2811ab6d..e03d754b 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -234,7 +234,7 @@ def custom_key(key): return (2, key) -def dict_to_array(data): +def dict_to_array(data, heads): if not all(isinstance(value, dict) for value in data.values()): return np.array(list(data.values())) unique_keys = set() @@ -243,10 +243,11 @@ def dict_to_array(data): unique_keys = list(unique_keys) sorted_keys = sorted([int(key) for key in unique_keys]) result_array = np.zeros((len(data), len(sorted_keys))) - for default_index, (_, inner_dict) in enumerate(data.items()): + for _, (head_name, inner_dict) in enumerate(data.items()): for key, value in inner_dict.items(): key_index = sorted_keys.index(int(key)) - result_array[default_index][key_index] = value + head_index = heads.index(head_name) + result_array[head_index][key_index] = value return np.squeeze(result_array) From 44e7a2c5644b850f2bbc1a871148ee04e830f3bd 
Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Thu, 2 May 2024 15:42:07 -0400 Subject: [PATCH 037/101] Make sure check for patience does not try to use swa.start if it is None --- mace/tools/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 62013e9d..87a958f9 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -244,12 +244,12 @@ def train( if valid_loss >= lowest_loss: patience_counter += 1 - if patience_counter >= patience and epoch < swa.start: + if patience_counter >= patience and (swa.start is not None and epoch < swa.start): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" ) epoch = swa.start - elif patience_counter >= patience and epoch >= swa.start: + elif patience_counter >= patience and (swa.start is None or epoch >= swa.start): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement" ) From fb35950bea8d720c2dde923c9b4fbd153d6b92ab Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 2 May 2024 21:21:53 +0100 Subject: [PATCH 038/101] fix ordering of heads --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index a0242afc..fa548bf5 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -181,7 +181,7 @@ def main() -> None: if args.multiheads_finetuning: logging.info("Using multiheads finetuning mode") if heads is not None: - heads = list(set(["pbe_mp"] + heads)) + heads = list(dict.fromkeys(["pbe_mp"] + heads)) args.heads = heads else: heads = ["pbe_mp", "Default"] From a208f196e0d4b26e8b34c6a85d83d6e76884dc53 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Sun, 5 May 2024 11:55:14 -0400 Subject: [PATCH 039/101] Actually store multihead fine-tuning pbe mp forces in atoms.arrays, rather than incorrect current storage in atoms.info store configs selected when multihead fine-tuning in local, tag dependent filename, rather than fixed filename in ~/.cache/mace --- mace/cli/run_train.py | 4 ++-- mace/data/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index fa548bf5..6497f7cf 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -246,14 +246,14 @@ def main() -> None: "weight_pt": args.weight_pt_head, "weight_ft": 1.0, "filtering_type": "combination", - "output": f"{cache_dir}/mp_finetuning.xyz", + "output": f"mp_finetuning-{tag}.xyz", "descriptors": descriptors_mp, "device": args.device, "default_dtype": args.default_dtype, } select_samples(dict_to_namespace(args_samples)) collections_mp, _, _ = get_dataset_from_xyz( - train_path=f"{cache_dir}/mp_finetuning.xyz", + train_path=f"mp_finetuning-{tag}.xyz", valid_path=None, valid_fraction=args.valid_fraction, config_type_weights=config_type_weights, diff --git a/mace/data/utils.py b/mace/data/utils.py index 4dd96287..22321017 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -248,10 +248,10 @@ def load_from_xyz( if forces_from_calc: try: - atoms.info["_REF_forces"] = atoms.get_forces() + atoms.arrays["_REF_forces"] = atoms.get_forces() except Exception as e: # pylint: disable=W0703 logging.warning(f"Failed to extract forces: {e}") - atoms.info["_REF_forces"] = None + atoms.arrays["_REF_forces"] = None if stress_from_calc: try: From bb10932cd5fffdfb6a465b4a587bacdab9b2bd87 Mon Sep 17 00:00:00 2001 From: Noam 
Bernstein Date: Mon, 6 May 2024 17:12:28 -0400 Subject: [PATCH 040/101] Better fix for PR #405, fix patience check when swa is not active --- mace/tools/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 87a958f9..32231acf 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -244,12 +244,12 @@ def train( if valid_loss >= lowest_loss: patience_counter += 1 - if patience_counter >= patience and (swa.start is not None and epoch < swa.start): + if patience_counter >= patience and (swa is not None and epoch < swa.start): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" ) epoch = swa.start - elif patience_counter >= patience and (swa.start is None or epoch >= swa.start): + elif patience_counter >= patience and (swa is None or epoch >= swa.start): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement" ) From f1be21e1fbd742ab53448b3851b5e1d50abfc04f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 16 May 2024 23:04:33 +0100 Subject: [PATCH 041/101] improving the pt selection and fix E0 bug --- mace/cli/fine_tuning_select.py | 23 +++++++++++++++++++++-- mace/cli/run_train.py | 2 +- mace/tools/arg_parser.py | 6 ++++++ mace/tools/scripts_utils.py | 9 ++++++++- 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index a9349a6f..19e44d7d 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -92,6 +92,7 @@ def parse_args() -> argparse.Namespace: type=float, default=1.0, ) + parser.add_argument("--seed", help="random seed", type=int, default=42) return parser.parse_args() @@ -197,6 +198,8 @@ def assemble_descriptors(self) -> np.ndarray: def select_samples( args: argparse.Namespace, ) -> None: + np.random.seed(args.seed) + torch.manual_seed(args.seed) if args.model in ["small", "medium", "large"]: calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype) else: @@ -217,18 +220,34 @@ def select_samples( atoms_list_pt = ase.io.read(args.configs_pt, index=":") for i, atoms in enumerate(atoms_list_pt): atoms.info["mace_descriptors"] = descriptors[i] - atoms_list_pt = [ + atoms_list_pt_filtered = [ x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") ] else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") - atoms_list_pt = [ + atoms_list_pt_filtered = [ x for x in atoms_list_pt if filter_atoms(x, all_species_ft, "combinations") ] + if len(atoms_list_pt_filtered) <= args.num_samples: + logging.info( + "Number of configurations after filtering is less than the number of samples, " + "selecting random configurations, for the rest." 
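An aside on PATCH 038 above ("fix ordering of heads"): the point of switching from set() to dict.fromkeys() is that a set gives no guarantee about iteration order, whereas dict.fromkeys() deduplicates while preserving insertion order, so the head-to-index mapping stays stable between runs. A small illustration (the extra head names are made up):

    heads = ["dft_d3", "pbe_mp", "ccsd"]
    list(set(["pbe_mp"] + heads))            # order depends on string hashing, may change per run
    list(dict.fromkeys(["pbe_mp"] + heads))  # always ['pbe_mp', 'dft_d3', 'ccsd']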
+ ) + atoms_list_pt_minus_filtered = [ + x for x in atoms_list_pt if x not in atoms_list_pt_filtered + ] + atoms_list_pt_random = np.random.choice( + atoms_list_pt_minus_filtered, + args.num_samples - len(atoms_list_pt_filtered), + ).tolist() + atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_random + else: + atoms_list_pt = atoms_list_pt_filtered + else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") if args.descriptors is not None: diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 6497f7cf..462966d7 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -238,7 +238,7 @@ def main() -> None: args_samples = { "configs_pt": dataset_mp, "configs_ft": args.train_file, - "num_samples": 1000, + "num_samples": args.num_samples_pt, "seed": args.seed, "model": args.foundation_model, "head_pt": "pbe_mp", diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index b05816dd..abb550e1 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -350,6 +350,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=float, default=1.0, ) + parser.add_argument( + "--num_samples_pt", + help="Number of samples in the pretrained head", + type=int, + default=1000, + ) parser.add_argument( "--keep_isolated_atoms", help="Keep isolated atoms in the dataset, useful for transfer learning", diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index e03d754b..9254abb3 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -153,7 +153,14 @@ def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict: atomic_energies_dict = json.load(open(E0s, "r")) else: try: - atomic_energies_dict = ast.literal_eval(E0s) + atomic_energies_eval = ast.literal_eval(E0s) + if not all( + isinstance(value, dict) + for value in atomic_energies_eval.values() + ): + atomic_energies_dict = {"Default": atomic_energies_eval} + else: + atomic_energies_dict = atomic_energies_eval assert isinstance(atomic_energies_dict, dict) except Exception as e: raise RuntimeError( From e4ac49818c9dc581c5167155b7b659bca0c9064e Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 4 Jun 2024 18:55:07 +0100 Subject: [PATCH 042/101] add tqdm and fpsample dep --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index d3899885..41379100 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,8 @@ install_requires = torchmetrics python-hostlist configargparse + tqdm + fpsample # for plotting: matplotlib pandas From 0842e7c5cea983e51e58e3275ce83a40c67539ce Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Tue, 4 Jun 2024 16:51:24 -0400 Subject: [PATCH 043/101] get rid of all stress/n_atoms --- mace/tools/train.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 32231acf..9306a171 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -60,13 +60,13 @@ def valid_err_log( ) elif ( log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_stress_per_atom"] is not None + and eval_metrics["rmse_stress"] is not None ): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 + error_stress = eval_metrics["rmse_stress"] * 1e3 logging.info( - f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, 
RMSE_stress_per_atom={error_stress:.1f} meV / A^3" + f"head: {valid_loader_name}, Epoch {epoch}: loss={valid_loss:.4f}, RMSE_E_per_atom={error_e:.1f} meV, RMSE_F={error_f:.1f} meV / A, RMSE_stress={error_stress:.1f} meV / A^3" ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -405,7 +405,6 @@ def __init__(self, loss_fn: torch.nn.Module): "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" ) self.add_state("delta_stress", default=[], dist_reduce_fx="cat") - self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat") self.add_state( "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" ) @@ -434,10 +433,6 @@ def update(self, batch, output): # pylint: disable=arguments-differ if output.get("stress") is not None and batch.stress is not None: self.stress_computed += 1.0 self.delta_stress.append(batch.stress - output["stress"]) - self.delta_stress_per_atom.append( - (batch.stress - output["stress"]) - / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) - ) if output.get("virials") is not None and batch.virials is not None: self.virials_computed += 1.0 self.delta_virials.append(batch.virials - output["virials"]) @@ -480,10 +475,8 @@ def compute(self): aux["q95_f"] = compute_q95(delta_fs) if self.stress_computed: delta_stress = self.convert(self.delta_stress) - delta_stress_per_atom = self.convert(self.delta_stress_per_atom) aux["mae_stress"] = compute_mae(delta_stress) aux["rmse_stress"] = compute_rmse(delta_stress) - aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom) aux["q95_stress"] = compute_q95(delta_stress) if self.virials_computed: delta_virials = self.convert(self.delta_virials) From a7b32da7a2fa8da54f62c280d8c51a02fe3abc71 Mon Sep 17 00:00:00 2001 From: JamesDarby Date: Mon, 3 Jun 2024 22:46:28 +0000 Subject: [PATCH 044/101] remove n_atoms factor --- mace/modules/loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mace/modules/loss.py b/mace/modules/loss.py index d1f8becd..b3421ef5 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -31,11 +31,10 @@ def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor: # energy: [n_graphs, ] configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] - num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) # [n_graphs,] return torch.mean( configs_weight * configs_stress_weight - * torch.square((ref["stress"] - pred["stress"]) / num_atoms) + * torch.square(ref["stress"] - pred["stress"]) ) # [] From d52374616dc0e7ab61c140a72b58bea0f8211155 Mon Sep 17 00:00:00 2001 From: James Darby Date: Tue, 4 Jun 2024 11:23:55 +0100 Subject: [PATCH 045/101] updated tests --- tests/test_run_train.py | 84 ++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 7dca8919..7109744d 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -107,30 +107,30 @@ def test_run_train(tmp_path, fitting_configs): Es.append(at.get_potential_energy()) print("Es", Es) - # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3 + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 ref_Es = [ 0.0, 0.0, - -0.03911274694160493, - -0.0913651377675312, - -0.14973695873658766, - -0.0664839502025434, - -0.09968814898703926, - 0.1248460531971883, - -0.0647495831154953, - -0.14589298347245963, - 0.12918668431788108, - -0.13996496272772996, - 
-0.053211348522482806, - 0.07845141245421094, - -0.08901520083723416, - -0.15467129065263446, - 0.007727727865546765, - -0.04502061132025605, - -0.035848783030374, - -0.24410687104937906, - -0.0839034724949955, - -0.14756571357354326, + -0.039181344585828524, + -0.0915223395136733, + -0.14953484236456582, + -0.06662480820063998, + -0.09983737353050133, + 0.12477442296789745, + -0.06486086271762856, + -0.1460607988519944, + 0.12886334908465508, + -0.14000990081920373, + -0.05319886578958313, + 0.07780520158391, + -0.08895480281886901, + -0.15474719614734422, + 0.007756765146527644, + -0.044879267197498685, + -0.036065736712447574, + -0.24413743841886623, + -0.0838104612106429, + -0.14751978636626545 ] assert np.allclose(Es, ref_Es) @@ -178,30 +178,30 @@ def test_run_train_missing_data(tmp_path, fitting_configs): Es.append(at.get_potential_energy()) print("Es", Es) - # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3 + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 ref_Es = [ 0.0, 0.0, - -0.05449966431966507, - -0.11237663925685797, - 0.03914539466246801, - -0.07500800414261456, - -0.13471106701173396, - 0.02937255038020199, - -0.0652196693921633, - -0.14946129637190012, - 0.19412338220281133, - -0.13546947741234333, - -0.05235148626886153, - -0.04957190959243316, - -0.07081384032242896, - -0.24575839901841345, - -0.0020512332640394916, - -0.038630330106902526, - -0.13621347044601181, - -0.2338465954158298, - -0.11777474787291177, - -0.14895508008918812, + -0.05464025113696155, + -0.11272131295940478, + 0.039200919331076826, + -0.07517990972827505, + -0.13504202474582666, + 0.0292022872055344, + -0.06541099574579018, + -0.1497824717832886, + 0.19397709360828813, + -0.13587609467143014, + -0.05242956276828463, + -0.0504862057364953, + -0.07095795959430119, + -0.2463753796753703, + -0.002031543147676121, + -0.03864918790300681, + -0.13680153117705554, + -0.23418951968636786, + -0.11790833839379238, + -0.14930562311066484 ] assert np.allclose(Es, ref_Es) From 21e67164878bcfe9747b8933be206dfb415dc795 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Wed, 5 Jun 2024 12:55:42 -0400 Subject: [PATCH 046/101] Don't apply np.random.choice to list(Atoms) since it thinks it's multidimensional --- mace/cli/fine_tuning_select.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 19e44d7d..0435e5c3 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -234,17 +234,19 @@ def select_samples( ] if len(atoms_list_pt_filtered) <= args.num_samples: logging.info( - "Number of configurations after filtering is less than the number of samples, " - "selecting random configurations, for the rest." + f"Number of configurations after filtering {len(atoms_list_pt_filtered} " + f"is less than the number of samples {args.num_samples}, " + "selecting random configurations for the rest." 
) atoms_list_pt_minus_filtered = [ x for x in atoms_list_pt if x not in atoms_list_pt_filtered ] - atoms_list_pt_random = np.random.choice( - atoms_list_pt_minus_filtered, + atoms_list_pt_random_inds = np.random.choice( + list(range(len(atoms_list_pt_minus_filtered))), args.num_samples - len(atoms_list_pt_filtered), - ).tolist() - atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_random + replace=False + ) + atoms_list_pt = atoms_list_pt_filtered + [atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds] else: atoms_list_pt = atoms_list_pt_filtered From 9616a412b1f43219ec7d60ae3607dbb7f8b314fb Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Thu, 6 Jun 2024 08:20:49 -0400 Subject: [PATCH 047/101] When printing validation loss during training, pass loss for head, not partial sum over heads calculated so far --- mace/tools/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 32231acf..441df428 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -161,7 +161,7 @@ def train( ) valid_loss += valid_loss_head valid_err_log( - valid_loss, eval_metrics, logger, log_errors, None, valid_loader_name + valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name ) while epoch < max_num_epochs: @@ -224,7 +224,7 @@ def train( ) valid_loss += valid_loss_head valid_err_log( - valid_loss, + valid_loss_head, eval_metrics, logger, log_errors, From 54636560feaa4c83872d366378f8549954cf44dc Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:47:21 +0100 Subject: [PATCH 048/101] fix deterministic valid --- mace/cli/fine_tuning_select.py | 8 +++++--- mace/cli/run_train.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 0435e5c3..fa2ed78f 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -234,7 +234,7 @@ def select_samples( ] if len(atoms_list_pt_filtered) <= args.num_samples: logging.info( - f"Number of configurations after filtering {len(atoms_list_pt_filtered} " + f"Number of configurations after filtering {len(atoms_list_pt_filtered)} " f"is less than the number of samples {args.num_samples}, " "selecting random configurations for the rest." 
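The reasoning behind PATCH 046 and PATCH 048 above: np.random.choice first converts its first argument to an array, and a list of ase.Atoms objects does not survive that conversion cleanly (NumPy treats it as multidimensional), so the robust pattern is to sample integer indices without replacement and index the Python list afterwards. A minimal sketch of the idea (names are illustrative, not taken from the patch):

    import numpy as np

    def sample_structures(pool, k):
        # Sample indices, not the Atoms objects themselves, so NumPy never
        # has to build an ndarray out of ase.Atoms instances.
        inds = np.random.choice(len(pool), size=min(k, len(pool)), replace=False)
        return [pool[int(i)] for i in inds]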
) @@ -244,9 +244,11 @@ def select_samples( atoms_list_pt_random_inds = np.random.choice( list(range(len(atoms_list_pt_minus_filtered))), args.num_samples - len(atoms_list_pt_filtered), - replace=False + replace=False, ) - atoms_list_pt = atoms_list_pt_filtered + [atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds] + atoms_list_pt = atoms_list_pt_filtered + [ + atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds + ] else: atoms_list_pt = atoms_list_pt_filtered diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 462966d7..b946c78d 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -434,7 +434,7 @@ def main() -> None: dataset=valid_set, batch_size=args.valid_batch_size, sampler=valid_samplers[head] if args.distributed else None, - shuffle=(valid_sampler is None), + shuffle=False drop_last=False, pin_memory=args.pin_memory, num_workers=args.num_workers, From 122a50cecf9d24f6786a7051d68210d8963b8859 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:48:58 +0100 Subject: [PATCH 049/101] fix valid --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index b946c78d..8aec44df 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -434,7 +434,7 @@ def main() -> None: dataset=valid_set, batch_size=args.valid_batch_size, sampler=valid_samplers[head] if args.distributed else None, - shuffle=False + shuffle=False, drop_last=False, pin_memory=args.pin_memory, num_workers=args.num_workers, From bce718a3fc90d49c908329b93105a40144010a70 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 10 Jun 2024 09:56:59 +0100 Subject: [PATCH 050/101] make weighted huber default loss for finetuning --- mace/cli/run_train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 8aec44df..478d0a9b 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -552,7 +552,10 @@ def main() -> None: model_config["atomic_inter_shift"] = [args.mean] * len(heads) model_config["atomic_inter_scale"] = [1.0] * len(heads) args.model = "FoundationMACE" + args.loss = "universal" model_config["heads"] = args.heads + logging.info("Model configuration extracted from foundation model") + logging.info("Using universal loss function for fine-tuning") else: logging.info("Building model") if args.num_channels is not None and args.max_L is not None: From 3f65c29b08ce79bd6de309b016b01870f70d90a8 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 10 Jun 2024 10:37:43 +0100 Subject: [PATCH 051/101] fix order of loss multihead --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 478d0a9b..866ffeb5 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -180,6 +180,7 @@ def main() -> None: if args.multiheads_finetuning: logging.info("Using multiheads finetuning mode") + args.loss = "universal" if heads is not None: heads = list(dict.fromkeys(["pbe_mp"] + heads)) args.heads = heads @@ -552,7 +553,6 @@ def main() -> None: model_config["atomic_inter_shift"] = [args.mean] * len(heads) model_config["atomic_inter_scale"] = [1.0] * len(heads) args.model = "FoundationMACE" - args.loss = "universal" model_config["heads"] = args.heads logging.info("Model configuration extracted from 
foundation model") logging.info("Using universal loss function for fine-tuning") From 0bc5445b566b84fb4e0a189e13e376889f91a31c Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 17 Jun 2024 15:36:26 +0100 Subject: [PATCH 052/101] add fps as optional dep --- mace/cli/fine_tuning_select.py | 18 ++++++++++++++---- setup.cfg | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index fa2ed78f..9b664a46 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -18,7 +18,11 @@ from mace import data import pandas as pd from mace.tools import torch_geometric, torch_tools, utils -import fpsample + +try: + import fpsample +except ImportError: + logging.error("fpsample not found, to use FPS, install using pip install fpsample") def parse_args() -> argparse.Namespace: @@ -278,9 +282,15 @@ def select_samples( ) np.save(args.output.replace(".xyz", "descriptors.npy"), descriptors_list) logging.info("Selecting configurations using Farthest Point Sampling") - fps_pt = FPS(atoms_list_pt, args.num_samples) - idx_pt = fps_pt.run() - logging.info(f"Selected {len(idx_pt)} configurations") + try: + fps_pt = FPS(atoms_list_pt, args.num_samples) + idx_pt = fps_pt.run() + logging.info(f"Selected {len(idx_pt)} configurations") + except Exception as e: + logging.error(f"FPS failed, selecting random configurations instead: {e}") + idx_pt = np.random.choice( + list(range(len(atoms_list_pt)), args.num_samples, replace=False) + ) atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] for atoms in atoms_list_pt: # del atoms.info["mace_descriptors"] diff --git a/setup.cfg b/setup.cfg index 41379100..63bb8bb3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,6 @@ install_requires = python-hostlist configargparse tqdm - fpsample # for plotting: matplotlib pandas @@ -44,6 +43,7 @@ console_scripts = [options.extras_require] wandb = wandb +fpsample = fpsample dev = black isort From 66663e0452e992b8e2bb203126b554122304e5f4 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 17 Jun 2024 15:36:58 +0100 Subject: [PATCH 053/101] clean code import --- mace/cli/fine_tuning_select.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 9b664a46..e9a8f51b 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -13,11 +13,6 @@ import torch from mace.calculators import MACECalculator, mace_mp -from tqdm import tqdm - -from mace import data -import pandas as pd -from mace.tools import torch_geometric, torch_tools, utils try: import fpsample From 254270266683110e816f780d289172c0571f8cc4 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 17 Jun 2024 16:46:31 +0100 Subject: [PATCH 054/101] Update setup.cfg numpy<2.0 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 63bb8bb3..806df650 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ python_requires = >=3.7 install_requires = torch>=1.12 e3nn==0.4.4 - numpy + numpy<2.0 opt_einsum ase torch-ema From b2b053e7d89562a43adcfe70324a87b26594e3a0 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 18 Jun 2024 11:23:28 +0100 Subject: [PATCH 055/101] fix the stress_key in test --- mace/tools/scripts_utils.py | 1 + 1 file changed, 1 insertion(+) 
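Because PATCH 052 above turns fpsample into an optional extra, the farthest-point-sampling step has to degrade gracefully when the package is missing. A condensed sketch of that pattern, reusing the fpsample call that appears later in this series (the (n_structures, n_features) descriptor shape and the helper name are assumptions for illustration):

    import numpy as np

    try:
        import fpsample  # optional: pip install fpsample
    except ImportError:
        fpsample = None

    def select_indices(descriptors: np.ndarray, n_samples: int) -> np.ndarray:
        # descriptors: one flattened feature vector per structure
        if fpsample is not None:
            return fpsample.fps_npdu_kdtree_sampling(descriptors, n_samples)
        # fall back to a random subset when FPS is unavailable
        return np.random.choice(len(descriptors), n_samples, replace=False)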
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 9254abb3..fa430380 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -97,6 +97,7 @@ def get_dataset_from_xyz( config_type_weights=config_type_weights, energy_key=energy_key, forces_key=forces_key, + stress_key=stress_key, dipole_key=dipole_key, charges_key=charges_key, extract_atomic_energies=False, From c00d6d75e226be4fc021b01dded7d44834db92db Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 18 Jun 2024 14:00:59 +0100 Subject: [PATCH 056/101] fix no stress compute universal loss --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 866ffeb5..17aff19c 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -518,7 +518,7 @@ def main() -> None: # Selecting outputs compute_virials = False - if args.loss in ("stress", "virials", "huber"): + if args.loss in ("stress", "virials", "huber", "universal"): compute_virials = True args.compute_stress = True args.error_table = "PerAtomRMSEstressvirials" From 1cd818996a551944f4d6a9f229543c2d2e132c5e Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:09:21 +0100 Subject: [PATCH 057/101] lint fix finetuning select --- mace/cli/fine_tuning_select.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index e9a8f51b..8bc4abd9 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -17,7 +17,7 @@ try: import fpsample except ImportError: - logging.error("fpsample not found, to use FPS, install using pip install fpsample") + pass def parse_args() -> argparse.Namespace: @@ -110,7 +110,7 @@ def calculate_descriptors( def filter_atoms( - atoms: ase.Atoms, element_subset: list[str], filtering_type: str + atoms: ase.Atoms, element_subset: List[str], filtering_type: str ) -> bool: """ Filters atoms based on the provided filtering type and element subset. @@ -149,7 +149,7 @@ def filter_atoms( class FPS: - def __init__(self, atoms_list: list[ase.Atoms], n_samples: int): + def __init__(self, atoms_list: List[ase.Atoms], n_samples: int): self.n_samples = n_samples self.atoms_list = atoms_list self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) @@ -160,7 +160,7 @@ def __init__(self, atoms_list: list[ase.Atoms], n_samples: int): def run( self, - ) -> list[int]: + ) -> List[int]: """ Run the farthest point sampling algorithm. 
""" From f04d8f40d2b01bab7b02545248681351305cdafb Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:47:01 +0100 Subject: [PATCH 058/101] fix inf stress in universal loss nonperiodic --- mace/modules/loss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mace/modules/loss.py b/mace/modules/loss.py index b3421ef5..0357e2e2 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -273,12 +273,13 @@ def __init__( def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: num_atoms = ref.ptr[1:] - ref.ptr[:-1] + pred_stress = torch.nan_to_num(pred["stress"], nan=0.0, posinf=0.0, neginf=0.0) return ( self.energy_weight * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + self.forces_weight * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) - + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) + + self.stress_weight * self.huber_loss(ref["stress"], pred_stress) ) def __repr__(self): From e2d9db9918e6a980a09bcf6c1b42089b75e09a4f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 20 Jun 2024 17:08:07 +0100 Subject: [PATCH 059/101] fix stress inf in universal loss not pbc --- mace/modules/loss.py | 3 +-- mace/modules/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mace/modules/loss.py b/mace/modules/loss.py index 0357e2e2..b3421ef5 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -273,13 +273,12 @@ def __init__( def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: num_atoms = ref.ptr[1:] - ref.ptr[:-1] - pred_stress = torch.nan_to_num(pred["stress"], nan=0.0, posinf=0.0, neginf=0.0) return ( self.energy_weight * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + self.forces_weight * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) - + self.stress_weight * self.huber_loss(ref["stress"], pred_stress) + + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) ) def __repr__(self): diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 5e0ec72b..1f9807c6 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -64,7 +64,8 @@ def compute_forces_virials( cell[:, 0, :], torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), ).unsqueeze(-1) - stress = virials / volume.view(-1, 1, 1) + stress = virials / (volume.view(-1, 1, 1) + 1e-16) + stress = torch.where(torch.abs(stress) > 1e10, stress, torch.zeros_like(stress)) if forces is None: forces = torch.zeros_like(positions) if virials is None: @@ -122,7 +123,6 @@ def get_outputs( compute_stress: bool = True, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: if (compute_virials or compute_stress) and displacement is not None: - # forces come for free forces, virials, stress = compute_forces_virials( energy=energy, positions=positions, From edc6b74a5b2b6259cd636990c5258fe6a52d2fe9 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 20 Jun 2024 17:11:13 +0100 Subject: [PATCH 060/101] consider the last head for ckpt --- mace/tools/train.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 61b1dce0..ae94e169 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -159,10 +159,10 @@ def train( output_args=output_args, device=device, ) - valid_loss += valid_loss_head 
valid_err_log( valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name ) + valid_loss = valid_loss_head # consider only the last head for the checkpoint while epoch < max_num_epochs: # LR scheduler and SWA update @@ -244,12 +244,16 @@ def train( if valid_loss >= lowest_loss: patience_counter += 1 - if patience_counter >= patience and (swa is not None and epoch < swa.start): + if patience_counter >= patience and ( + swa is not None and epoch < swa.start + ): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement and starting swa" ) epoch = swa.start - elif patience_counter >= patience and (swa is None or epoch >= swa.start): + elif patience_counter >= patience and ( + swa is None or epoch >= swa.start + ): logging.info( f"Stopping optimization after {patience_counter} epochs without improvement" ) From 1c10329fb04de3176d9934118f91d4c8f6d9cc56 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 21 Jun 2024 16:54:30 +0100 Subject: [PATCH 061/101] fix stress masking --- mace/modules/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 1f9807c6..1f88aad0 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -65,7 +65,7 @@ def compute_forces_virials( torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), ).unsqueeze(-1) stress = virials / (volume.view(-1, 1, 1) + 1e-16) - stress = torch.where(torch.abs(stress) > 1e10, stress, torch.zeros_like(stress)) + stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) if forces is None: forces = torch.zeros_like(positions) if virials is None: From 2b6cc052aac951282208b52e3517d75fe5acea80 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 8 Jul 2024 16:39:14 +0100 Subject: [PATCH 062/101] throw error if E0s is average and multihead --- mace/cli/run_train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 17aff19c..a0304010 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -91,6 +91,10 @@ def main() -> None: device = tools.init_device(args.device) if args.foundation_model is not None: + if args.multiheads_finetuning: + assert ( + args.E0s != "average" + ), "average atomic energies cannot be used for multiheads finetuning" if args.foundation_model in ["small", "medium", "large"]: logging.info( f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." From 87a14f4f7b742d6f6d2904f8eee4dc3217b1e0b1 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:19:08 +0100 Subject: [PATCH 063/101] fix nn slabs --- mace/data/neighborhood.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mace/data/neighborhood.py b/mace/data/neighborhood.py index b5f2b00f..293576af 100644 --- a/mace/data/neighborhood.py +++ b/mace/data/neighborhood.py @@ -27,17 +27,18 @@ def get_neighborhood( max_positions = np.max(np.absolute(positions)) + 1 # Extend cell in non-periodic directions # For models with more than 5 layers, the multiplicative constant needs to be increased. 
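Stepping back to the stress handling changed in PATCH 058/059/061 above (and finished in PATCH 065 below): the stress tensor is the virial divided by the cell volume, and a configuration without periodic boundary conditions has a degenerate cell whose volume is (near) zero, which is where the infinities came from. A compact sketch of the corrected computation, written against the final form the series converges on (illustrative helper, not the patch itself):

    import torch

    def stress_from_virials(virials: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
        # virials, cell: shape (n_graphs, 3, 3)
        volume = torch.linalg.det(cell).abs().view(-1, 1, 1)  # |det(cell)|, cf. PATCH 065
        stress = virials / (volume + 1e-16)                   # epsilon guards volume == 0
        # zero out the unphysical values a degenerate (non-PBC) cell would produce
        return torch.where(stress.abs() < 1e10, stress, torch.zeros_like(stress))

The PATCH 063 hunk that continues below applies the analogous fix on the neighbour-list side: the padding for non-periodic directions now overwrites rows of a copied cell (cell[i, :] is the i-th lattice vector in the ASE convention) rather than columns, so slab geometries get a sensible fictitious cell without mutating the caller's array.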
+ temp_cell = np.copy(cell) if not pbc_x: - cell[:, 0] = max_positions * 5 * cutoff * identity[:, 0] + temp_cell[0, :] = max_positions * 5 * cutoff * identity[0, :] if not pbc_y: - cell[:, 1] = max_positions * 5 * cutoff * identity[:, 1] + temp_cell[1, :] = max_positions * 5 * cutoff * identity[1, :] if not pbc_z: - cell[:, 2] = max_positions * 5 * cutoff * identity[:, 2] + temp_cell[2, :] = max_positions * 5 * cutoff * identity[2, :] sender, receiver, unit_shifts = neighbour_list( quantities="ijS", pbc=pbc, - cell=cell, + cell=temp_cell, positions=positions, cutoff=cutoff, # self_interaction=True, # we want edges from atom to itself in different periodic images From 3120bc240495cf0d0a724d3729549d9b417aa467 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:12:09 +0100 Subject: [PATCH 064/101] fix swa universal --- mace/cli/run_train.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index a0304010..34d9842c 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -791,6 +791,16 @@ def main() -> None: logging.info( f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" ) + elif args.loss == "universal": + loss_fn_energy = modules.UniversalLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + huber_delta=args.huber_delta, + ) + logging.info( + f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + ) else: loss_fn_energy = modules.WeightedEnergyForcesLoss( energy_weight=args.swa_energy_weight, From a386d997ae5675a9129d87cd53e2ce435871510a Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:42:52 +0100 Subject: [PATCH 065/101] update the volume computation --- mace/modules/utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 1f88aad0..de6fdd49 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -59,11 +59,7 @@ def compute_forces_virials( stress = torch.zeros_like(displacement) if compute_stress and virials is not None: cell = cell.view(-1, 3, 3) - volume = torch.einsum( - "zi,zi->z", - cell[:, 0, :], - torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), - ).unsqueeze(-1) + volume = torch.linalg.det(cell).abs().unsqueeze(-1) stress = virials / (volume.view(-1, 1, 1) + 1e-16) stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) if forces is None: From 436778d26105355043b7d2b8fb089fe3c5928705 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:16:30 +0100 Subject: [PATCH 066/101] fix merge bugs --- mace/cli/run_train.py | 18 +++++++++--------- mace/modules/models.py | 7 ------- mace/tools/scripts_utils.py | 2 +- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index a4130454..f9bd80e1 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -576,18 +576,18 @@ def run(args: argparse.Namespace) -> None: # Build model if 
args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: logging.info("Building model") - model_config = extract_config_mace_model(model_foundation) - model_config["atomic_energies"] = atomic_energies - model_config["atomic_numbers"] = z_table.zs - model_config["num_elements"] = len(z_table) - args.max_L = model_config["hidden_irreps"].lmax + model_config_foundation = extract_config_mace_model(model_foundation) + model_config_foundation["atomic_energies"] = atomic_energies + model_config_foundation["atomic_numbers"] = z_table.zs + model_config_foundation["num_elements"] = len(z_table) + args.max_L = model_config_foundation["hidden_irreps"].lmax if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": - model_config["atomic_inter_shift"] = [0.0] * len(heads) + model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) else: - model_config["atomic_inter_shift"] = [args.mean] * len(heads) - model_config["atomic_inter_scale"] = [1.0] * len(heads) + model_config_foundation["atomic_inter_shift"] = [args.mean] * len(heads) + model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) args.model = "FoundationMACE" - model_config["heads"] = args.heads + model_config_foundation["heads"] = args.heads logging.info("Model configuration extracted from foundation model") logging.info("Using universal loss function for fine-tuning") else: diff --git a/mace/modules/models.py b/mace/modules/models.py index f4154a40..6dd3348f 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -191,7 +191,6 @@ def forward( # Setup data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) - print("head", data["head"]) num_atoms_arange = torch.arange(data["positions"].shape[0]) num_graphs = data["ptr"].numel() - 1 node_heads = data["head"][data["batch"]] @@ -218,7 +217,6 @@ def forward( node_e0 = self.atomic_energies_fn(data["node_attrs"])[ num_atoms_arange, node_heads ] - # print("node e0", node_e0.shape) e0 = scatter_sum( src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs ) # [n_graphs, n_heads] @@ -245,8 +243,6 @@ def forward( pair_energy = torch.zeros_like(e0) # Interactions - print("pair_energy", pair_energy) - print("pair_node_energy", pair_node_energy) energies = [e0, pair_energy] node_energies_list = [node_e0, pair_node_energy] node_feats_list = [] @@ -277,7 +273,6 @@ def forward( ) # [n_graphs,] energies.append(energy) node_energies_list.append(node_energies) - print("node_energies", node_energies) # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) @@ -409,12 +404,10 @@ def forward( # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) - # print("node_es_list", node_es_list) # Sum over interactions node_inter_es = torch.sum( torch.stack(node_es_list, dim=0), dim=0 ) # [n_nodes, ] - # print("node_inter_es", node_inter_es.shape) node_inter_es = self.scale_shift(node_inter_es, node_heads) # Sum over nodes in graph diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index aa367709..24810ba1 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -294,7 +294,7 @@ def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module: return model_load_yaml.to(map_location) -def get_atomic_energies(E0s, train_collection, z_table) -> dict: +def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict: if E0s is not None: logging.info( "Atomic Energies not in training file, using command line argument E0s" From 
e565b2cf0b77849a7b7decbb51ce6b8bd361075f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:23:40 +0100 Subject: [PATCH 067/101] bump version number --- mace/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/__version__.py b/mace/__version__.py index d7b30e12..8879c6c7 100644 --- a/mace/__version__.py +++ b/mace/__version__.py @@ -1 +1 @@ -__version__ = "0.3.6" +__version__ = "0.3.7" From 387f7df419ebd2d66cb533c779a51d6651f7ad94 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:25:15 +0100 Subject: [PATCH 068/101] formatting --- mace/cli/active_learning_md.py | 1 - mace/cli/eval_configs.py | 1 - mace/cli/fine_tuning_select.py | 1 - mace/cli/run_train.py | 14 +++----- mace/data/utils.py | 12 +++---- mace/modules/utils.py | 1 - mace/tools/arg_parser.py | 28 +++++++++++----- mace/tools/finetuning_utils.py | 58 +++++++++++++++++----------------- mace/tools/scripts_utils.py | 18 +++++------ mace/tools/torch_tools.py | 1 - tests/test_foundations.py | 5 +-- 11 files changed, 69 insertions(+), 71 deletions(-) diff --git a/mace/cli/active_learning_md.py b/mace/cli/active_learning_md.py index 52e3879e..648a30b2 100644 --- a/mace/cli/active_learning_md.py +++ b/mace/cli/active_learning_md.py @@ -144,7 +144,6 @@ def main() -> None: def run(args: argparse.Namespace) -> None: - mace_fname = args.model atoms_fname = args.config atoms_index = args.config_index diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index bf53ef88..b5700bc4 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -62,7 +62,6 @@ def main() -> None: def run(args: argparse.Namespace) -> None: - torch_tools.set_default_dtype(args.default_dtype) device = torch_tools.init_device(args.device) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 8bc4abd9..a638b588 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -6,7 +6,6 @@ import logging from typing import List - import ase.data import ase.io import numpy as np diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index f9bd80e1..121eba12 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -10,11 +10,10 @@ import json import logging import os +import urllib.request from copy import deepcopy from pathlib import Path from typing import Optional -import urllib.request - import numpy as np import torch.distributed @@ -30,26 +29,21 @@ from mace.calculators.foundations_models import mace_mp, mace_off from mace.cli.fine_tuning_select import select_samples from mace.tools import torch_geometric -from mace.tools.finetuning_utils import load_foundations +from mace.tools.finetuning_utils import load_foundations, load_foundations_elements from mace.tools.scripts_utils import ( LRScheduler, + check_folder_subfolder, convert_to_json_format, create_error_table, + dict_to_array, dict_to_namespace, extract_config_mace_model, get_atomic_energies, get_config_type_weights, get_dataset_from_xyz, get_files_with_suffix, - dict_to_array, - check_folder_subfolder, print_git_commit, ) -from mace.tools.slurm_distributed import DistributedEnvironment -from mace.tools.finetuning_utils import ( - load_foundations_elements, -) - from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable diff --git a/mace/data/utils.py b/mace/data/utils.py index 31680dee..b93c2514 100644 --- 
a/mace/data/utils.py +++ b/mace/data/utils.py @@ -274,9 +274,9 @@ def load_from_xyz( head = atoms.info.get(head_key, "Default") if head not in atomic_energies_dict: atomic_energies_dict[head] = {} - atomic_energies_dict[head][atoms.get_atomic_numbers()[0]] = ( - atoms.info[energy_key] - ) + atomic_energies_dict[head][ + atoms.get_atomic_numbers()[0] + ] = atoms.info[energy_key] else: logging.warning( f"Configuration '{idx}' is marked as 'IsolatedAtom' " @@ -285,9 +285,9 @@ def load_from_xyz( head = atoms.info.get(head_key, "Default") if head not in atomic_energies_dict: atomic_energies_dict[head] = {} - atomic_energies_dict[head][atoms.get_atomic_numbers()[0]] = ( - np.zeros(1) - ) + atomic_energies_dict[head][ + atoms.get_atomic_numbers()[0] + ] = np.zeros(1) else: atoms_without_iso_atoms.append(atoms) diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 92bfd952..36921b11 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -144,7 +144,6 @@ def compute_hessians_loop( forces: torch.Tensor, positions: torch.Tensor, ) -> torch.Tensor: - hessian = [] for grad_elem in forces.view(-1): hess_row = torch.autograd.grad( diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 3e59fe37..a4af0a2e 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -420,7 +420,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--forces_weight", help="weight of forces loss", type=float, default=100.0 ) parser.add_argument( - "--swa_forces_weight","--stage_two_forces_weight", + "--swa_forces_weight", + "--stage_two_forces_weight", help="weight of forces loss after starting Stage Two (previously called swa)", type=float, default=100.0, @@ -430,7 +431,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--energy_weight", help="weight of energy loss", type=float, default=1.0 ) parser.add_argument( - "--swa_energy_weight","--stage_two_energy_weight", + "--swa_energy_weight", + "--stage_two_energy_weight", help="weight of energy loss after starting Stage Two (previously called swa)", type=float, default=1000.0, @@ -440,7 +442,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--virials_weight", help="weight of virials loss", type=float, default=1.0 ) parser.add_argument( - "--swa_virials_weight", "--stage_two_virials_weight", + "--swa_virials_weight", + "--stage_two_virials_weight", help="weight of virials loss after starting Stage Two (previously called swa)", type=float, default=10.0, @@ -450,7 +453,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--stress_weight", help="weight of virials loss", type=float, default=1.0 ) parser.add_argument( - "--swa_stress_weight", "--stage_two_stress_weight", + "--swa_stress_weight", + "--stage_two_stress_weight", help="weight of stress loss after starting Stage Two (previously called swa)", type=float, default=10.0, @@ -460,7 +464,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 ) parser.add_argument( - "--swa_dipole_weight","--stage_two_dipole_weight", + "--swa_dipole_weight", + "--stage_two_dipole_weight", help="weight of dipoles after starting Stage Two (previously called swa)", type=float, default=1.0, @@ -499,7 +504,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--lr", help="Learning rate of optimizer", type=float, default=0.01 ) parser.add_argument( - "--swa_lr", "--stage_two_lr", help="Learning rate of optimizer in Stage Two (previously called swa)", 
type=float, default=1e-3, dest="swa_lr" + "--swa_lr", + "--stage_two_lr", + help="Learning rate of optimizer in Stage Two (previously called swa)", + type=float, + default=1e-3, + dest="swa_lr", ) parser.add_argument( "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 @@ -526,14 +536,16 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=0.9993, ) parser.add_argument( - "--swa", "--stage_two", + "--swa", + "--stage_two", help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", action="store_true", default=False, dest="swa", ) parser.add_argument( - "--start_swa","--start_stage_two", + "--start_swa", + "--start_stage_two", help="Number of epochs before changing to Stage Two loss weights", type=int, default=None, diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 0d4e2f52..ab5bf139 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -50,24 +50,24 @@ def load_foundations_elements( for j in range(4): # Assuming 4 layers in conv_tp_weights, layer_name = f"layer{j}" if j == 0: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ) - .weight[:num_radial, :] - .clone() + getattr( + model.interactions[i].conv_tp_weights, layer_name + ).weight = torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, ) + .weight[:num_radial, :] + .clone() ) else: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ).weight.clone() - ) + getattr( + model.interactions[i].conv_tp_weights, layer_name + ).weight = torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() ) model.interactions[i].linear.weight = torch.nn.Parameter( @@ -106,23 +106,23 @@ def load_foundations_elements( for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[j].weights_max = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) + model.products[i].symmetric_contractions.contractions[ + j + ].weights_max = torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights_max[indices_weights, :, :] + .clone() ) for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[k] = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, :] - .clone() - ) + model.products[i].symmetric_contractions.contractions[j].weights[ + k + ] = torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() ) model.products[i].linear.weight = torch.nn.Parameter( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 24810ba1..40a24f31 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -237,9 +237,9 @@ def convert_from_json_format(dict_input): 
dict_input["interaction_cls"] == "" ): - dict_output["interaction_cls"] = ( - modules.blocks.RealAgnosticResidualInteractionBlock - ) + dict_output[ + "interaction_cls" + ] = modules.blocks.RealAgnosticResidualInteractionBlock if ( dict_input["interaction_cls"] == "" @@ -249,16 +249,16 @@ def convert_from_json_format(dict_input): dict_input["interaction_cls_first"] == "" ): - dict_output["interaction_cls_first"] = ( - modules.blocks.RealAgnosticResidualInteractionBlock - ) + dict_output[ + "interaction_cls_first" + ] = modules.blocks.RealAgnosticResidualInteractionBlock if ( dict_input["interaction_cls_first"] == "" ): - dict_output["interaction_cls_first"] = ( - modules.blocks.RealAgnosticInteractionBlock - ) + dict_output[ + "interaction_cls_first" + ] = modules.blocks.RealAgnosticInteractionBlock dict_output["r_max"] = float(dict_input["r_max"]) dict_output["num_bessel"] = int(dict_input["num_bessel"]) dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"]) diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py index f3e471c1..e42a74f8 100644 --- a/mace/tools/torch_tools.py +++ b/mace/tools/torch_tools.py @@ -68,7 +68,6 @@ def init_device(device_str: str) -> torch.device: torch.xpu.is_available() return torch.device("xpu") - logging.info("Using CPU") return torch.device("cpu") diff --git a/tests/test_foundations.py b/tests/test_foundations.py index f425ab81..975d0586 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -9,12 +9,9 @@ from mace import data, modules, tools from mace.calculators import mace_mp, mace_off from mace.tools import torch_geometric -from mace.tools.utils import ( - AtomicNumberTable, -) from mace.tools.finetuning_utils import ( - load_foundations_elements, extract_config_mace_model, + load_foundations_elements, ) from mace.tools.scripts_utils import extract_config_mace_model from mace.tools.utils import AtomicNumberTable From 7c52013b1c42e91f70c71f127d50c6006ecee54e Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:22:26 +0100 Subject: [PATCH 069/101] start fix linter --- mace/__version__.py | 2 ++ mace/calculators/mace.py | 2 +- mace/cli/fine_tuning_select.py | 27 ++++++++++++--------------- mace/tools/scripts_utils.py | 20 ++++++++++---------- pyproject.toml | 1 + 5 files changed, 26 insertions(+), 26 deletions(-) diff --git a/mace/__version__.py b/mace/__version__.py index 8879c6c7..47e8e016 100644 --- a/mace/__version__.py +++ b/mace/__version__.py @@ -1 +1,3 @@ __version__ = "0.3.7" + +__all__ = ["__version__"] diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index c4962576..6140f715 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -147,7 +147,7 @@ def __init__( self.charges_key = charges_key try: self.heads = self.models[0].heads - except: + except AttributeError: self.heads = ["Default"] model_dtype = get_model_dtype(self.models[0]) if default_dtype == "": diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index a638b588..169e9a5f 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -94,9 +94,7 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def calculate_descriptors( - atoms: List[ase.Atoms], calc: MACECalculator, cutoffs: None -) -> None: +def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: logging.info("Calculating descriptors") for mol in atoms: descriptors = 
calc.get_descriptors(mol.copy(), invariants_only=True) @@ -128,23 +126,22 @@ def filter_atoms( """ if filtering_type == "none": return True - elif filtering_type == "combinations": + if filtering_type == "combinations": atom_symbols = np.unique(atoms.symbols) return all( - [x in element_subset for x in atom_symbols] + x in element_subset for x in atom_symbols ) # atoms must *only* contain elements in the subset - elif filtering_type == "exclusive": - atom_symbols = set([x for x in atoms.symbols]) + if filtering_type == "exclusive": + atom_symbols = set(list(atoms.symbols)) return atom_symbols == set(element_subset) - elif filtering_type == "inclusive": + if filtering_type == "inclusive": atom_symbols = np.unique(atoms.symbols) return all( - [x in atom_symbols for x in element_subset] + x in atom_symbols for x in element_subset ) # atoms must *at least* contain elements in the subset - else: - raise ValueError( - f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'exclusive', or 'inclusive'." - ) + raise ValueError( + f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'exclusive', or 'inclusive'." + ) class FPS: @@ -206,7 +203,7 @@ def select_samples( ) atoms_list_ft = ase.io.read(args.configs_ft, index=":") - if args.filtering_type != None: + if args.filtering_type is not None: all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) logging.info( "Filtering configurations based on the finetuning set, " @@ -265,7 +262,7 @@ def select_samples( if args.num_samples is not None and args.num_samples < len(atoms_list_pt): if args.descriptors is None: logging.info("Calculating descriptors for the pretraining set") - calculate_descriptors(atoms_list_pt, calc, None) + calculate_descriptors(atoms_list_pt, calc) descriptors_list = [ atoms.info["mace_descriptors"] for atoms in atoms_list_pt ] diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 40a24f31..e9c7c7d8 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -31,7 +31,7 @@ class SubsetCollection: def get_dataset_from_xyz( train_path: str, - valid_path: str, + valid_path: Optional[str], valid_fraction: float, config_type_weights: Dict, test_path: str = None, @@ -237,9 +237,9 @@ def convert_from_json_format(dict_input): dict_input["interaction_cls"] == "" ): - dict_output[ - "interaction_cls" - ] = modules.blocks.RealAgnosticResidualInteractionBlock + dict_output["interaction_cls"] = ( + modules.blocks.RealAgnosticResidualInteractionBlock + ) if ( dict_input["interaction_cls"] == "" @@ -249,16 +249,16 @@ def convert_from_json_format(dict_input): dict_input["interaction_cls_first"] == "" ): - dict_output[ - "interaction_cls_first" - ] = modules.blocks.RealAgnosticResidualInteractionBlock + dict_output["interaction_cls_first"] = ( + modules.blocks.RealAgnosticResidualInteractionBlock + ) if ( dict_input["interaction_cls_first"] == "" ): - dict_output[ - "interaction_cls_first" - ] = modules.blocks.RealAgnosticInteractionBlock + dict_output["interaction_cls_first"] = ( + modules.blocks.RealAgnosticInteractionBlock + ) dict_output["r_max"] = float(dict_input["r_max"]) dict_output["num_bessel"] = int(dict_input["num_bessel"]) dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"]) diff --git a/pyproject.toml b/pyproject.toml index 05037dc7..489bc6e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ disable = [ "not-callable", "logging-fstring-interpolation", "logging-not-lazy", + 
"logging-too-many-args", "invalid-name", "too-few-public-methods", "too-many-instance-attributes", From 4a1beb9ac2c024fa60f12f7e242deea353f5ac64 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 13 Aug 2024 15:51:41 +0100 Subject: [PATCH 070/101] linter --- mace/cli/fine_tuning_select.py | 50 ++++++++++++++++------------------ 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 169e9a5f..f06878c0 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -160,10 +160,16 @@ def run( """ Run the farthest point sampling algorithm. """ - logging.info(self.descriptors_dataset.reshape(len(self.atoms_list), -1).shape) - logging.info("n_samples", self.n_samples) + descriptor_dataset_reshaped = ( + self.descriptors_dataset.reshape( # pylint: disable=E1121 + (len(self.atoms_list), -1) + ) + ) + logging.info(f"{descriptor_dataset_reshaped.shape}") + logging.info(f"n_samples: {self.n_samples}") self.list_index = fpsample.fps_npdu_kdtree_sampling( - self.descriptors_dataset.reshape(len(self.atoms_list), -1), self.n_samples + descriptor_dataset_reshaped, + self.n_samples, ) return self.list_index @@ -171,23 +177,19 @@ def assemble_descriptors(self) -> np.ndarray: """ Assemble the descriptors for all the configurations. """ - self.descriptors_dataset = np.float32( - 10e10 - * np.ones( - ( - len(self.atoms_list), - len(self.species), - len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), - ), - dtype=np.float32, - ) - ) + self.descriptors_dataset: np.ndarray = 10e10 * np.ones( + ( + len(self.atoms_list), + len(self.species), + len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), + ), + dtype=np.float32, + ).astype(np.float32) + for i, atoms in enumerate(self.atoms_list): - descriptors = atoms.info["mace_descriptors"] + descriptors = np.array(atoms.info["mace_descriptors"]).astype(np.float32) for z in descriptors: - self.descriptors_dataset[i, self.species_dict[z]] = np.float32( - descriptors[z] - ) + self.descriptors_dataset[i, self.species_dict[z]] = descriptors[z] def select_samples( @@ -251,9 +253,7 @@ def select_samples( atoms_list_pt = ase.io.read(args.configs_pt, index=":") if args.descriptors is not None: logging.info( - "Loading descriptors for the pretraining set from {}".format( - args.descriptors - ) + f"Loading descriptors for the pretraining set from {args.descriptors}" ) descriptors = np.load(args.descriptors, allow_pickle=True) for i, atoms in enumerate(atoms_list_pt): @@ -267,17 +267,15 @@ def select_samples( atoms.info["mace_descriptors"] for atoms in atoms_list_pt ] logging.info( - "Saving descriptors at {}".format( - args.output.replace(".xyz", "descriptors.npy") - ) + f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}" ) - np.save(args.output.replace(".xyz", "descriptors.npy"), descriptors_list) + np.save(args.output.replace(".xyz", "_descriptors.npy"), descriptors_list) logging.info("Selecting configurations using Farthest Point Sampling") try: fps_pt = FPS(atoms_list_pt, args.num_samples) idx_pt = fps_pt.run() logging.info(f"Selected {len(idx_pt)} configurations") - except Exception as e: + except Exception as e: # pylint: disable=W0703 logging.error(f"FPS failed, selecting random configurations instead: {e}") idx_pt = np.random.choice( list(range(len(atoms_list_pt)), args.num_samples, replace=False) From d17ad39e81ecafda161d84369750a0086ec8e38a Mon Sep 17 00:00:00 2001 From: 
Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:17:58 +0100 Subject: [PATCH 071/101] linter fixes --- mace/cli/preprocess_data.py | 2 +- mace/cli/run_train.py | 132 ++-------------------------- mace/data/atomic_data.py | 6 +- mace/data/utils.py | 12 +-- mace/modules/blocks.py | 4 +- mace/modules/models.py | 4 +- mace/tools/finetuning_utils.py | 58 ++++++------- mace/tools/scripts_utils.py | 151 +++++++++++++++++++++++++-------- mace/tools/utils.py | 2 +- tests/test_foundations.py | 18 ++-- tests/test_models.py | 4 +- 11 files changed, 179 insertions(+), 214 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 8ea345f6..8b2f00f1 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -207,7 +207,7 @@ def run(args: argparse.Namespace): if args.compute_statistics: logging.info("Computing statistics") if len(atomic_energies_dict) == 0: - atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) + atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table, ["Default"]) atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 121eba12..05a71335 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -21,7 +21,6 @@ from e3nn import o3 from e3nn.util import jit from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.swa_utils import SWALR, AveragedModel from torch_ema import ExponentialMovingAverage import mace @@ -29,7 +28,7 @@ from mace.calculators.foundations_models import mace_mp, mace_off from mace.cli.fine_tuning_select import select_samples from mace.tools import torch_geometric -from mace.tools.finetuning_utils import load_foundations, load_foundations_elements +from mace.tools.finetuning_utils import load_foundations_elements from mace.tools.scripts_utils import ( LRScheduler, check_folder_subfolder, @@ -42,6 +41,8 @@ get_config_type_weights, get_dataset_from_xyz, get_files_with_suffix, + get_loss_fn, + get_swa, print_git_commit, ) from mace.tools.slurm_distributed import DistributedEnvironment @@ -65,10 +66,10 @@ def run(args: argparse.Namespace) -> None: if args.device == "xpu": try: import intel_extension_for_pytorch as ipex - except ImportError: + except ImportError as e: raise ImportError( "Error: Intel extension for PyTorch not found, but XPU device was specified" - ) + ) from e if args.distributed: try: distr_env = DistributedEnvironment() @@ -469,65 +470,8 @@ def run(args: argparse.Namespace) -> None: num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), ) - # valid_loader = torch_geometric.dataloader.DataLoader( - # dataset=valid_set, - # batch_size=args.valid_batch_size, - # sampler=valid_sampler, - # shuffle=(valid_sampler is None), - # drop_last=False, - # pin_memory=args.pin_memory, - # num_workers=args.num_workers, - # ) - if args.loss == "weighted": - loss_fn = modules.WeightedEnergyForcesLoss( - energy_weight=args.energy_weight, forces_weight=args.forces_weight - ) - elif args.loss == "forces_only": - loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) - elif args.loss == "virials": - loss_fn = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - virials_weight=args.virials_weight, - ) - elif args.loss == "stress": - loss_fn = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.energy_weight, 
- forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - ) - elif args.loss == "huber": - loss_fn = modules.WeightedHuberEnergyForcesStressLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - huber_delta=args.huber_delta, - ) - elif args.loss == "universal": - loss_fn = modules.UniversalLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - huber_delta=args.huber_delta, - ) - elif args.loss == "dipole": - assert ( - dipole_only is True - ), "dipole loss can only be used with AtomicDipolesMACE model" - loss_fn = modules.DipoleSingleLoss( - dipole_weight=args.dipole_weight, - ) - elif args.loss == "energy_forces_dipole": - assert dipole_only is False and compute_dipole is True - loss_fn = modules.WeightedEnergyForcesDipoleLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - dipole_weight=args.dipole_weight, - ) - else: - # Unweighted Energy and Forces loss by default - loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) + loss_fn = get_loss_fn(args.loss, dipole_only, compute_dipole) logging.info(loss_fn) if args.compute_avg_num_neighbors: @@ -791,69 +735,7 @@ def run(args: argparse.Namespace) -> None: swa: Optional[tools.SWAContainer] = None swas = [False] if args.swa: - assert dipole_only is False, "Stage Two for dipole fitting not implemented" - swas.append(True) - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - else: - if args.start_swa > args.max_num_epochs: - logging.info( - f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" - ) - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - logging.info(f"Setting start Stage Two to {args.start_swa}") - if args.loss == "forces_only": - raise ValueError("Can not select Stage Two with forces only loss.") - if args.loss == "virials": - loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - virials_weight=args.swa_virials_weight, - ) - elif args.loss == "stress": - loss_fn_energy = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - stress_weight=args.swa_stress_weight, - ) - elif args.loss == "energy_forces_dipole": - loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( - args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - dipole_weight=args.swa_dipole_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" - ) - elif args.loss == "universal": - loss_fn_energy = modules.UniversalLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - stress_weight=args.swa_stress_weight, - huber_delta=args.huber_delta, - ) - logging.info( - f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" - ) - else: - loss_fn_energy = modules.WeightedEnergyForcesLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with energy weight : 
{args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" - ) - swa = tools.SWAContainer( - model=AveragedModel(model), - scheduler=SWALR( - optimizer=optimizer, - swa_lr=args.swa_lr, - anneal_epochs=1, - anneal_strategy="linear", - ), - start=args.start_swa, - loss_fn=loss_fn_energy, - ) + swa, swas = get_swa(args, model, optimizer, swas, dipole_only) checkpoint_handler = tools.CheckpointHandler( directory=args.checkpoints_dir, diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 9344546e..814a23e0 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -115,8 +115,10 @@ def from_config( config: Configuration, z_table: AtomicNumberTable, cutoff: float, - heads: Optional[list] = ["Default"], + heads: Optional[list] = None, ) -> "AtomicData": + if heads is None: + heads = ["default"] edge_index, shifts, unit_shifts = get_neighborhood( positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell ) @@ -127,7 +129,7 @@ def from_config( ) try: head = torch.tensor(heads.index(config.head), dtype=torch.long) - except: + except ValueError: head = torch.tensor(len(heads) - 1, dtype=torch.long) cell = ( diff --git a/mace/data/utils.py b/mace/data/utils.py index b93c2514..31680dee 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -274,9 +274,9 @@ def load_from_xyz( head = atoms.info.get(head_key, "Default") if head not in atomic_energies_dict: atomic_energies_dict[head] = {} - atomic_energies_dict[head][ - atoms.get_atomic_numbers()[0] - ] = atoms.info[energy_key] + atomic_energies_dict[head][atoms.get_atomic_numbers()[0]] = ( + atoms.info[energy_key] + ) else: logging.warning( f"Configuration '{idx}' is marked as 'IsolatedAtom' " @@ -285,9 +285,9 @@ def load_from_xyz( head = atoms.info.get(head_key, "Default") if head not in atomic_energies_dict: atomic_energies_dict[head] = {} - atomic_energies_dict[head][ - atoms.get_atomic_numbers()[0] - ] = np.zeros(1) + atomic_energies_dict[head][atoms.get_atomic_numbers()[0]] = ( + np.zeros(1) + ) else: atoms_without_iso_atoms.append(atoms) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 7a162146..fa63f604 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -52,7 +52,9 @@ def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")) self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out) def forward( - self, x: torch.Tensor, heads: Optional[torch.Tensor] = None + self, + x: torch.Tensor, + heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] return self.linear(x) # [n_nodes, 1] diff --git a/mace/modules/models.py b/mace/modules/models.py index 6dd3348f..f3481ce0 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -61,7 +61,7 @@ def __init__( distance_transform: str = "None", radial_MLP: Optional[List[int]] = None, radial_type: Optional[str] = "bessel", - heads: Optional[List[str]] = ["Default"], + heads: Optional[List[str]] = None, ): super().__init__() self.register_buffer( @@ -73,6 +73,8 @@ def __init__( self.register_buffer( "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) ) + if heads is None: + heads = ["default"] self.heads = heads if isinstance(correlation, int): correlation = [correlation] * num_interactions diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index ab5bf139..0d4e2f52 100644 --- a/mace/tools/finetuning_utils.py +++ 
b/mace/tools/finetuning_utils.py @@ -50,24 +50,24 @@ def load_foundations_elements( for j in range(4): # Assuming 4 layers in conv_tp_weights, layer_name = f"layer{j}" if j == 0: - getattr( - model.interactions[i].conv_tp_weights, layer_name - ).weight = torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ) + .weight[:num_radial, :] + .clone() ) - .weight[:num_radial, :] - .clone() ) else: - getattr( - model.interactions[i].conv_tp_weights, layer_name - ).weight = torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ).weight.clone() + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() + ) ) model.interactions[i].linear.weight = torch.nn.Parameter( @@ -106,24 +106,24 @@ def load_foundations_elements( for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[ - j - ].weights_max = torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) - - for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[ - k - ] = torch.nn.Parameter( + model.products[i].symmetric_contractions.contractions[j].weights_max = ( + torch.nn.Parameter( model_foundations.products[i] .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, :] + .weights_max[indices_weights, :, :] .clone() ) + ) + + for k in range(2): # Assuming 2 weights in each contraction + model.products[i].symmetric_contractions.contractions[j].weights[k] = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() + ) + ) model.products[i].linear.weight = torch.nn.Parameter( model_foundations.products[i].linear.weight.clone() diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index e9c7c7d8..6053f2ea 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -17,9 +17,11 @@ import torch.distributed from e3nn import o3 from prettytable import PrettyTable +from torch.optim.swa_utils import SWALR, AveragedModel from mace import data, modules from mace.tools import evaluate +from mace.tools.train import SWAContainer @dataclasses.dataclass @@ -316,7 +318,8 @@ def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict: else: if E0s.endswith(".json"): logging.info(f"Loading atomic energies from {E0s}") - atomic_energies_dict = json.load(open(E0s, "r")) + with open(E0s, "r", encoding="utf-8") as f: + atomic_energies_dict = json.load(f) else: try: atomic_energies_eval = ast.literal_eval(E0s) @@ -340,54 +343,134 @@ def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict: def get_loss_fn( - loss: str, - energy_weight: float, - forces_weight: float, - stress_weight: float, - virials_weight: float, - dipole_weight: float, + args: argparse.Namespace, dipole_only: bool, compute_dipole: bool, ) -> torch.nn.Module: - if loss == "weighted": + if args.loss == "weighted": loss_fn = 
modules.WeightedEnergyForcesLoss( - energy_weight=energy_weight, forces_weight=forces_weight + energy_weight=args.energy_weight, forces_weight=args.forces_weight ) - elif loss == "forces_only": - loss_fn = modules.WeightedForcesLoss(forces_weight=forces_weight) - elif loss == "virials": + elif args.loss == "forces_only": + loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) + elif args.loss == "virials": loss_fn = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - virials_weight=virials_weight, + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + virials_weight=args.virials_weight, ) - elif loss == "stress": + elif args.loss == "stress": loss_fn = modules.WeightedEnergyForcesStressLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, ) - elif loss == "dipole": + elif args.loss == "huber": + loss_fn = modules.WeightedHuberEnergyForcesStressLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "universal": + loss_fn = modules.UniversalLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "dipole": assert ( dipole_only is True ), "dipole loss can only be used with AtomicDipolesMACE model" loss_fn = modules.DipoleSingleLoss( - dipole_weight=dipole_weight, + dipole_weight=args.dipole_weight, ) - elif loss == "energy_forces_dipole": + elif args.loss == "energy_forces_dipole": assert dipole_only is False and compute_dipole is True loss_fn = modules.WeightedEnergyForcesDipoleLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - dipole_weight=dipole_weight, + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + dipole_weight=args.dipole_weight, ) else: - loss_fn = modules.EnergyForcesLoss( - energy_weight=energy_weight, forces_weight=forces_weight - ) + loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) return loss_fn +def get_swa( + args: argparse.Namespace, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + swas: List[bool], + dipole_only: bool = False, +): + assert dipole_only is False, "Stage Two for dipole fitting not implemented" + swas.append(True) + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + else: + if args.start_swa > args.max_num_epochs: + logging.info( + f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" + ) + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + logging.info(f"Setting start Stage Two to {args.start_swa}") + if args.loss == "forces_only": + raise ValueError("Can not select Stage Two with forces only loss.") + if args.loss == "virials": + loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + virials_weight=args.swa_virials_weight, + ) + elif args.loss == "stress": + loss_fn_energy = modules.WeightedEnergyForcesStressLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + ) + elif args.loss == "energy_forces_dipole": + loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( + 
args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + dipole_weight=args.swa_dipole_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "universal": + loss_fn_energy = modules.UniversalLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + huber_delta=args.huber_delta, + ) + logging.info( + f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + ) + else: + loss_fn_energy = modules.WeightedEnergyForcesLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" + ) + swa = SWAContainer( + model=AveragedModel(model), + scheduler=SWALR( + optimizer=optimizer, + swa_lr=args.swa_lr, + anneal_epochs=1, + anneal_strategy="linear", + ), + start=args.start_swa, + loss_fn=loss_fn_energy, + ) + return swa, swas + + def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: return [ os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) @@ -407,16 +490,16 @@ def custom_key(key): return (2, key) -def dict_to_array(data, heads): - if not all(isinstance(value, dict) for value in data.values()): - return np.array(list(data.values())) +def dict_to_array(input_data, heads): + if not all(isinstance(value, dict) for value in input_data.values()): + return np.array(list(input_data.values())) unique_keys = set() - for inner_dict in data.values(): + for inner_dict in input_data.values(): unique_keys.update(inner_dict.keys()) unique_keys = list(unique_keys) sorted_keys = sorted([int(key) for key in unique_keys]) - result_array = np.zeros((len(data), len(sorted_keys))) - for _, (head_name, inner_dict) in enumerate(data.items()): + result_array = np.zeros((len(input_data), len(sorted_keys))) + for _, (head_name, inner_dict) in enumerate(input_data.items()): for key, value in inner_dict.items(): key_index = sorted_keys.index(int(key)) head_index = heads.index(head_name) diff --git a/mace/tools/utils.py b/mace/tools/utils.py index c33b7b3b..65190108 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -8,7 +8,7 @@ import logging import os import sys -from typing import Any, Dict, Iterable, List, Optional, Sequence, Union +from typing import Any, Dict, Iterable, Optional, Sequence, Union import numpy as np import torch diff --git a/tests/test_foundations.py b/tests/test_foundations.py index 975d0586..b1724629 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -9,10 +9,7 @@ from mace import data, modules, tools from mace.calculators import mace_mp, mace_off from mace.tools import torch_geometric -from mace.tools.finetuning_utils import ( - extract_config_mace_model, - load_foundations_elements, -) +from mace.tools.finetuning_utils import load_foundations_elements from mace.tools.scripts_utils import extract_config_mace_model from mace.tools.utils import AtomicNumberTable @@ -100,7 +97,7 @@ def test_foundations(): def test_multi_reference(): - config = data.Configuration( + 
config_multi = data.Configuration( atomic_numbers=molecule("H2COH").numbers, positions=molecule("H2COH").positions, forces=molecule("H2COH").positions, @@ -109,8 +106,8 @@ def test_multi_reference(): dipole=np.array([-1.5, 1.5, 2.0]), head="MP2", ) - table = tools.AtomicNumberTable([1, 6, 8]) - atomic_energies = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) + table_multi = tools.AtomicNumberTable([1, 6, 8]) + atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) # Create MACE model model_config = dict( @@ -129,7 +126,7 @@ def test_multi_reference(): hidden_irreps=o3.Irreps("128x0e + 128x1o"), MLP_irreps=o3.Irreps("16x0e"), gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, + atomic_energies=atomic_energies_multi, avg_num_neighbors=61, atomic_numbers=table.zs, correlation=3, @@ -149,7 +146,7 @@ def test_multi_reference(): max_L=1, ) atomic_data = data.AtomicData.from_config( - config, z_table=table, cutoff=6.0, heads=["MP2", "DFT"] + config_multi, z_table=table_multi, cutoff=6.0, heads=["MP2", "DFT"] ) data_loader = torch_geometric.dataloader.DataLoader( dataset=[atomic_data, atomic_data], @@ -185,9 +182,6 @@ def test_extract_config(model): model_copy.load_state_dict(model.state_dict()) z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=6.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=z_table, cutoff=6.0 - ) data_loader = torch_geometric.dataloader.DataLoader( dataset=[atomic_data, atomic_data], batch_size=2, diff --git a/tests/test_models.py b/tests/test_models.py index ca3e69b5..8e8c60da 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -199,7 +199,7 @@ def test_energy_dipole_mace(): def test_mace_multi_reference(): - atomic_energies = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) + atomic_energies_multi = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) model_config = dict( r_max=5, num_bessel=8, @@ -216,7 +216,7 @@ def test_mace_multi_reference(): hidden_irreps=o3.Irreps("96x0e + 96x1o"), MLP_irreps=o3.Irreps("16x0e"), gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, + atomic_energies=atomic_energies_multi, avg_num_neighbors=8, atomic_numbers=table.zs, distance_transform=True, From f2bf7a1764f85d5a096216a4c6e0446c28d1d1e0 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 21 Aug 2024 18:14:52 +0100 Subject: [PATCH 072/101] start new parsing of head --- mace/calculators/mace.py | 6 +- mace/cli/fine_tuning_select.py | 63 ++++--- mace/cli/run_train.py | 313 +++++++++++++++++---------------- mace/data/utils.py | 134 ++++++-------- mace/modules/blocks.py | 4 +- mace/modules/models.py | 12 +- mace/tools/arg_parser.py | 37 ++-- mace/tools/multihead_tools.py | 66 +++++++ mace/tools/scripts_utils.py | 41 ++--- tests/test_calculator.py | 9 +- tests/test_run_train.py | 58 +++--- 11 files changed, 419 insertions(+), 324 deletions(-) create mode 100644 mace/tools/multihead_tools.py diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 6140f715..8afa034e 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -235,8 +235,12 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if self.model_type in ["MACE", "EnergyDipoleMACE"]: batch = self._clone_batch(batch_base) + print("head", batch["head"]) node_heads = batch["head"][batch["batch"]] - node_e0 = 
self.models[0].atomic_energies_fn(batch["node_attrs"], node_heads) + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + num_atoms_arange, node_heads + ] compute_stress = not self.use_compile else: compute_stress = False diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index f06878c0..373095c9 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -38,6 +38,13 @@ def parse_args() -> argparse.Namespace: required=False, default=None, ) + parser.add_argument( + "--subselect", + help="method to subselect the configurations of the pretraining set", + type=str, + choices=["fps", "random"], + default="fps", + ) parser.add_argument( "--model", help="path to model", default="small", required=False ) @@ -187,9 +194,11 @@ def assemble_descriptors(self) -> np.ndarray: ).astype(np.float32) for i, atoms in enumerate(self.atoms_list): - descriptors = np.array(atoms.info["mace_descriptors"]).astype(np.float32) + descriptors = atoms.info["mace_descriptors"] for z in descriptors: - self.descriptors_dataset[i, self.species_dict[z]] = descriptors[z] + self.descriptors_dataset[i, self.species_dict[z]] = np.array( + descriptors[z] + ).astype(np.float32) def select_samples( @@ -260,27 +269,39 @@ def select_samples( atoms.info["mace_descriptors"] = descriptors[i] if args.num_samples is not None and args.num_samples < len(atoms_list_pt): - if args.descriptors is None: - logging.info("Calculating descriptors for the pretraining set") - calculate_descriptors(atoms_list_pt, calc) - descriptors_list = [ - atoms.info["mace_descriptors"] for atoms in atoms_list_pt - ] - logging.info( - f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}" - ) - np.save(args.output.replace(".xyz", "_descriptors.npy"), descriptors_list) - logging.info("Selecting configurations using Farthest Point Sampling") - try: - fps_pt = FPS(atoms_list_pt, args.num_samples) - idx_pt = fps_pt.run() - logging.info(f"Selected {len(idx_pt)} configurations") - except Exception as e: # pylint: disable=W0703 - logging.error(f"FPS failed, selecting random configurations instead: {e}") + if args.subselect == "fps": + if args.descriptors is None: + logging.info("Calculating descriptors for the pretraining set") + calculate_descriptors(atoms_list_pt, calc) + descriptors_list = [ + atoms.info["mace_descriptors"] for atoms in atoms_list_pt + ] + logging.info( + f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}" + ) + np.save( + args.output.replace(".xyz", "_descriptors.npy"), descriptors_list + ) + logging.info("Selecting configurations using Farthest Point Sampling") + try: + fps_pt = FPS(atoms_list_pt, args.num_samples) + idx_pt = fps_pt.run() + logging.info(f"Selected {len(idx_pt)} configurations") + except Exception as e: # pylint: disable=W0703 + logging.error( + f"FPS failed, selecting random configurations instead: {e}" + ) + idx_pt = np.random.choice( + list(range(len(atoms_list_pt))), args.num_samples, replace=False + ) + atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + else: + logging.info("Selecting random configurations") idx_pt = np.random.choice( - list(range(len(atoms_list_pt)), args.num_samples, replace=False) + list(range(len(atoms_list_pt))), args.num_samples, replace=False ) - atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + print("idx_pt", idx_pt) + atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] for atoms in atoms_list_pt: # del 
atoms.info["mace_descriptors"] atoms.info["pretrained"] = True diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 05a71335..146374d3 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -29,6 +29,7 @@ from mace.cli.fine_tuning_select import select_samples from mace.tools import torch_geometric from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.multihead_tools import dict_head_to_dataclass, prepare_default_head from mace.tools.scripts_utils import ( LRScheduler, check_folder_subfolder, @@ -140,158 +141,55 @@ def run(args: argparse.Namespace) -> None: else: args.multiheads_finetuning = False - if args.statistics_file is not None: - with open(args.statistics_file, "r") as f: # pylint: disable=W1514 - statistics = json.load(f) - logging.info("Using statistics json file") - args.r_max = ( - statistics["r_max"] if args.foundation_model is None else args.r_max - ) - args.atomic_numbers = statistics["atomic_numbers"] - args.mean = statistics["mean"] - args.std = statistics["std"] - args.avg_num_neighbors = statistics["avg_num_neighbors"] - args.compute_avg_num_neighbors = False - args.E0s = ( - statistics["atomic_energies"] - if not args.E0s.endswith(".json") - else args.E0s - ) - - # Data preparation - if args.train_file.endswith(".xyz"): - if args.valid_file is not None: - assert args.valid_file.endswith( - ".xyz" - ), "valid_file if given must be same format as train_file" - config_type_weights = get_config_type_weights(args.config_type_weights) - collections, atomic_energies_dict, heads = get_dataset_from_xyz( - train_path=args.train_file, - valid_path=args.valid_file, - valid_fraction=args.valid_fraction, - config_type_weights=config_type_weights, - test_path=args.test_file, - seed=args.seed, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, - keep_isolated_atoms=args.keep_isolated_atoms, - ) - if args.heads is not None: - args.heads = ast.literal_eval(args.heads) - assert set(heads) == set(args.heads), ( - "heads from command line and data do not match," - f"{set(heads)} != {set(args.heads)}" - ) - logging.info( - "Using heads from command line argument," f" heads used: {args.heads}" + if args.heads is not None: + args.heads = ast.literal_eval(args.heads) + else: + args.heads = prepare_default_head(args) + for head, head_args in args.heads.items(): + logging.info(f"============= Processing head {head} ===========") + head_config = dict_head_to_dataclass(head_args) + if head_config.statistics_file is not None: + with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 + statistics = json.load(f) + logging.info("Using statistics json file") + head_config.r_max = ( + statistics["r_max"] if args.foundation_model is None else args.r_max ) - heads = args.heads - else: - logging.info( - "Using heads extracted from data files," f" heads used: {heads}" + head_config.atomic_numbers = statistics["atomic_numbers"] + head_config.mean = statistics["mean"] + head_config.std = statistics["std"] + head_config.avg_num_neighbors = statistics["avg_num_neighbors"] + head_config.compute_avg_num_neighbors = False + head_config.E0s = ( + statistics["atomic_energies"] + if not head_config.E0s.endswith(".json") + else head_config.E0s ) - if args.multiheads_finetuning: - logging.info("Using multiheads finetuning mode") - args.loss = "universal" - if heads is not None: - heads = list(dict.fromkeys(["pbe_mp"] + 
heads)) - args.heads = heads - else: - heads = ["pbe_mp", "Default"] - args.heads = heads - logging.info(f"Using heads: {heads}") - try: - checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" - descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy" - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = "".join( - c - for c in os.path.basename(checkpoint_url) - if c.isalnum() or c in "_" - ) - cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" - descriptors_url_name = "".join( - c - for c in os.path.basename(descriptors_url) - if c.isalnum() or c in "_" - ) - cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}" - if not os.path.isfile(cached_dataset_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - logging.info("Downloading MP structures for finetuning") - _, http_msg = urllib.request.urlretrieve( - checkpoint_url, cached_dataset_path - ) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Dataset download failed, please check the URL {checkpoint_url}" - ) - logging.info(f"Materials Project dataset to {cached_dataset_path}") - if not os.path.isfile(cached_descriptors_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - logging.info("Downloading MP descriptors for finetuning") - _, http_msg = urllib.request.urlretrieve( - descriptors_url, cached_descriptors_path - ) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Descriptors download failed, please check the URL {descriptors_url}" - ) - logging.info( - f"Materials Project descriptors to {cached_descriptors_path}" - ) - dataset_mp = cached_dataset_path - descriptors_mp = cached_descriptors_path - msg = f"Using Materials Project dataset with {dataset_mp}" - logging.info(msg) - msg = f"Using Materials Project descriptors with {descriptors_mp}" - logging.info(msg) - args_samples = { - "configs_pt": dataset_mp, - "configs_ft": args.train_file, - "num_samples": args.num_samples_pt, - "seed": args.seed, - "model": args.foundation_model, - "head_pt": "pbe_mp", - "head_ft": "Default", - "weight_pt": args.weight_pt_head, - "weight_ft": 1.0, - "filtering_type": "combination", - "output": f"mp_finetuning-{tag}.xyz", - "descriptors": descriptors_mp, - "device": args.device, - "default_dtype": args.default_dtype, - } - select_samples(dict_to_namespace(args_samples)) - collections_mp, _, _ = get_dataset_from_xyz( - train_path=f"mp_finetuning-{tag}.xyz", - valid_path=None, - valid_fraction=args.valid_fraction, - config_type_weights=config_type_weights, - test_path=None, - seed=args.seed, - energy_key="energy", - forces_key="forces", - stress_key="stress", - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, - keep_isolated_atoms=args.keep_isolated_atoms, - ) - collections.train += collections_mp.train - collections.valid += collections_mp.valid - except Exception as exc: - raise RuntimeError( - "Model download failed and no local model found" - ) from exc - + # Data preparation + if head_config.train_file.endswith(".xyz"): + if head_config.valid_file is not None: + assert head_config.valid_file.endswith( + ".xyz" + ), "valid_file if given must be same format as train_file" + config_type_weights = get_config_type_weights( + head_config.config_type_weights + ) + collections, atomic_energies_dict = get_dataset_from_xyz( + train_path=head_config.train_file, + valid_path=head_config.valid_file, + 
valid_fraction=head_config.valid_fraction, + config_type_weights=config_type_weights, + test_path=head_config.test_file, + seed=args.seed, + energy_key=head_config.energy_key, + forces_key=head_config.forces_key, + stress_key=head_config.stress_key, + virials_key=head_config.virials_key, + dipole_key=head_config.dipole_key, + charges_key=head_config.charges_key, + keep_isolated_atoms=head_config.keep_isolated_atoms, + ) logging.info( f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," @@ -299,6 +197,102 @@ def run(args: argparse.Namespace) -> None: else: atomic_energies_dict = None + if args.multiheads_finetuning: + logging.info( + "==================Using multiheads finetuning mode==================" + ) + args.loss = "universal" + if heads is not None: + heads = list(dict.fromkeys(["pbe_mp"] + heads)) + args.heads = heads + else: + heads = ["pbe_mp", "Default"] + args.heads = heads + logging.info(f"Using heads: {heads}") + try: + checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" + descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy" + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" + descriptors_url_name = "".join( + c for c in os.path.basename(descriptors_url) if c.isalnum() or c in "_" + ) + cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}" + if not os.path.isfile(cached_dataset_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP structures for finetuning") + _, http_msg = urllib.request.urlretrieve( + checkpoint_url, cached_dataset_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Dataset download failed, please check the URL {checkpoint_url}" + ) + logging.info(f"Materials Project dataset to {cached_dataset_path}") + if not os.path.isfile(cached_descriptors_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP descriptors for finetuning") + _, http_msg = urllib.request.urlretrieve( + descriptors_url, cached_descriptors_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Descriptors download failed, please check the URL {descriptors_url}" + ) + logging.info( + f"Materials Project descriptors to {cached_descriptors_path}" + ) + dataset_mp = cached_dataset_path + descriptors_mp = cached_descriptors_path + msg = f"Using Materials Project dataset with {dataset_mp}" + logging.info(msg) + msg = f"Using Materials Project descriptors with {descriptors_mp}" + logging.info(msg) + args_samples = { + "configs_pt": dataset_mp, + "configs_ft": args.train_file, + "num_samples": args.num_samples_pt, + "seed": args.seed, + "model": args.foundation_model, + "head_pt": "pbe_mp", + "head_ft": "Default", + "weight_pt": args.weight_pt_head, + "weight_ft": 1.0, + "filtering_type": "combination", + "output": f"mp_finetuning-{tag}.xyz", + "descriptors": descriptors_mp, + "subselect": args.subselect_pt, + "device": args.device, + "default_dtype": args.default_dtype, + } + select_samples(dict_to_namespace(args_samples)) + collections_mp, _, _ = get_dataset_from_xyz( + train_path=f"mp_finetuning-{tag}.xyz", + 
valid_path=None, + valid_fraction=args.valid_fraction, + config_type_weights=config_type_weights, + test_path=None, + seed=args.seed, + energy_key="energy", + forces_key="forces", + stress_key="stress", + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + keep_isolated_atoms=args.keep_isolated_atoms, + ) + collections.train += collections_mp.train + collections.valid += collections_mp.valid + except Exception as exc: + raise RuntimeError( + "Model download failed and no local model found" + ) from exc + # Atomic number table # yapf: disable if args.atomic_numbers is None: @@ -471,7 +465,7 @@ def run(args: argparse.Namespace) -> None: generator=torch.Generator().manual_seed(args.seed), ) - loss_fn = get_loss_fn(args.loss, dipole_only, compute_dipole) + loss_fn = get_loss_fn(args, dipole_only, compute_dipole) logging.info(loss_fn) if args.compute_avg_num_neighbors: @@ -522,7 +516,24 @@ def run(args: argparse.Namespace) -> None: if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) else: - model_config_foundation["atomic_inter_shift"] = [args.mean] * len(heads) + if isinstance(args.mean, np.ndarray): + if args.mean.size == 1: + model_config_foundation["atomic_inter_shift"] = args.mean.item() + elif args.mean.size == len(heads): + model_config_foundation["atomic_inter_shift"] = args.mean.tolist() + else: + logging.info( + "Mean not in correct format, using default value of 0.0" + ) + model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) + elif isinstance(args.mean, list) and len(args.mean) == len(heads): + model_config_foundation["atomic_inter_shift"] = args.mean + elif isinstance(args.mean, float): + model_config_foundation["atomic_inter_shift"] = [args.mean] * len(heads) + else: + logging.info("Mean not in correct format, using default value of 0.0") + model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) + model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) args.model = "FoundationMACE" model_config_foundation["heads"] = args.heads diff --git a/mace/data/utils.py b/mace/data/utils.py index 31680dee..b9b2c950 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -206,88 +206,64 @@ def load_from_xyz( virials_key: str = "virials", dipole_key: str = "dipole", charges_key: str = "charges", - head_key: str = "head", extract_atomic_energies: bool = False, keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: atoms_list = ase.io.read(file_path, index=":") - - energy_from_calc = False - forces_from_calc = False - stress_from_calc = False - - # Perform initial checks and log warnings if energy_key == "energy": logging.info( - "Using energy_key 'energy' is unsafe, consider using a different key, rewriting energies to '_REF_energy'" - ) - energy_from_calc = True - energy_key = "_REF_energy" - - if forces_key == "forces": - logging.info( - "Using forces_key 'forces' is unsafe, consider using a different key, rewriting forces to '_REF_forces'" + "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_energy'. You need to use --energy_key='REF_energy', to tell the key name chosen." 
) - forces_from_calc = True - forces_key = "_REF_forces" - - if stress_key == "stress": - logging.info( - "Using stress_key 'stress' is unsafe, consider using a different key, rewriting stress to '_REF_stress'" - ) - stress_from_calc = True - stress_key = "_REF_stress" - - for atoms in atoms_list: - if energy_from_calc: + energy_key = "REF_energy" + for atoms in atoms_list: try: - atoms.info["_REF_energy"] = atoms.get_potential_energy() + atoms.info["REF_energy"] = atoms.get_potential_energy() except Exception as e: # pylint: disable=W0703 logging.warning(f"Failed to extract energy: {e}") - atoms.info["_REF_energy"] = None - - if forces_from_calc: + atoms.info["REF_energy"] = None + if forces_key == "forces": + logging.info( + "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_forces'. You need to use --forces_key='REF_forces', to tell the key name chosen." + ) + forces_key = "REF_forces" + for atoms in atoms_list: try: - atoms.arrays["_REF_forces"] = atoms.get_forces() + atoms.arrays["REF_forces"] = atoms.get_forces() except Exception as e: # pylint: disable=W0703 logging.warning(f"Failed to extract forces: {e}") - atoms.arrays["_REF_forces"] = None - - if stress_from_calc: + atoms.arrays["REF_forces"] = None + if stress_key == "stress": + logging.info( + "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting energies to 'REF_stress'. You need to use --stress_key='REF_stress', to tell the key name chosen." + ) + stress_key = "REF_stress" + for atoms in atoms_list: try: - atoms.info["_REF_stress"] = atoms.get_stress() + atoms.info["REF_stress"] = atoms.get_stress() except Exception as e: # pylint: disable=W0703 - atoms.info["_REF_stress"] = None - + atoms.info["REF_stress"] = None if not isinstance(atoms_list, list): atoms_list = [atoms_list] + atomic_energies_dict = {} if extract_atomic_energies: atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): - if atoms.info.get("config_type") == "IsolatedAtom": - assert ( - len(atoms) == 1 - ), f"Got config_type=IsolatedAtom for a config with len {len(atoms)}" + isolated_atom_config = ( + len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" + ) + if isolated_atom_config: if energy_key in atoms.info.keys(): - head = atoms.info.get(head_key, "Default") - if head not in atomic_energies_dict: - atomic_energies_dict[head] = {} - atomic_energies_dict[head][atoms.get_atomic_numbers()[0]] = ( - atoms.info[energy_key] - ) + atomic_energies_dict[atoms.get_atomic_numbers()[0]] = atoms.info[ + energy_key + ] else: logging.warning( f"Configuration '{idx}' is marked as 'IsolatedAtom' " "but does not contain an energy. Zero energy will be used." 
) - head = atoms.info.get(head_key, "Default") - if head not in atomic_energies_dict: - atomic_energies_dict[head] = {} - atomic_energies_dict[head][atoms.get_atomic_numbers()[0]] = ( - np.zeros(1) - ) + atomic_energies_dict[atoms.get_atomic_numbers()[0]] = np.zeros(1) else: atoms_without_iso_atoms.append(atoms) @@ -295,10 +271,7 @@ def load_from_xyz( logging.info("Using isolated atom energies from training file") if not keep_isolated_atoms: atoms_list = atoms_without_iso_atoms - heads = set() - for atoms in atoms_list: - heads.add(atoms.info.get(head_key, "Default")) - heads = list(heads) + configs = config_from_atoms_list( atoms_list, config_type_weights=config_type_weights, @@ -308,13 +281,12 @@ def load_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, - head_key=head_key, ) - return atomic_energies_dict, configs, heads + return atomic_energies_dict, configs def compute_average_E0s( - collections_train: Configurations, z_table: AtomicNumberTable, heads: List[str] + collections_train: Configurations, z_table: AtomicNumberTable ) -> Dict[int, float]: """ Function to compute the average interaction energy of each chemical element @@ -322,28 +294,24 @@ def compute_average_E0s( """ len_train = len(collections_train) len_zs = len(z_table) - atomic_energies_dict = {} - for head in heads: - A = np.zeros((len_train, len_zs)) - B = np.zeros(len_train) - if head not in atomic_energies_dict: - atomic_energies_dict[head] = {} - for i in range(len_train): - if collections_train[i].head != head: - continue - B[i] = collections_train[i].energy - for j, z in enumerate(z_table.zs): - A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) - try: - E0s = np.linalg.lstsq(A, B, rcond=None)[0] - for i, z in enumerate(z_table.zs): - atomic_energies_dict[head][z] = E0s[i] - except np.linalg.LinAlgError: - logging.warning( - "Failed to compute E0s using least squares regression, using the same for all atoms" - ) - for i, z in enumerate(z_table.zs): - atomic_energies_dict[head][z] = 0.0 + A = np.zeros((len_train, len_zs)) + B = np.zeros(len_train) + for i in range(len_train): + B[i] = collections_train[i].energy + for j, z in enumerate(z_table.zs): + A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) + try: + E0s = np.linalg.lstsq(A, B, rcond=None)[0] + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = E0s[i] + except np.linalg.LinAlgError: + logging.warning( + "Failed to compute E0s using least squares regression, using the same for all atoms" + ) + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = 0.0 return atomic_energies_dict diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index fa63f604..48a9d22c 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -762,11 +762,11 @@ def __init__(self, scale: float, shift: float): super().__init__() self.register_buffer( "scale", - torch.atleast_1d(torch.tensor(scale, dtype=torch.get_default_dtype())), + torch.tensor(scale, dtype=torch.get_default_dtype()), ) self.register_buffer( "shift", - torch.atleast_1d(torch.tensor(shift, dtype=torch.get_default_dtype())), + torch.tensor(shift, dtype=torch.get_default_dtype()), ) def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor: diff --git a/mace/modules/models.py b/mace/modules/models.py index f3481ce0..c0d8ab43 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -195,7 +195,11 @@ def forward( 
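# (Reviewer note, not part of the diff) The lines added later in this hunk make the
# per-configuration "head" field optional: when a batch carries no "head" entry,
# node_heads falls back to torch.zeros_like(data["batch"]), i.e. every atom is mapped
# to head index 0, so pre-multihead datasets keep working with the new models unchanged.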
data["positions"].requires_grad_(True) num_atoms_arange = torch.arange(data["positions"].shape[0]) num_graphs = data["ptr"].numel() - 1 - node_heads = data["head"][data["batch"]] + node_heads = ( + data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -338,7 +342,11 @@ def forward( data["node_attrs"].requires_grad_(True) num_graphs = data["ptr"].numel() - 1 num_atoms_arange = torch.arange(data["positions"].shape[0]) - node_heads = data["head"][data["batch"]] + node_heads = ( + data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index a4af0a2e..107ea0dc 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -221,19 +221,19 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--compute_avg_num_neighbors", help="normalization factor for the message", - type=bool, + type=str2bool, default=True, ) parser.add_argument( "--compute_stress", help="Select True to compute stress", - type=bool, + type=str2bool, default=False, ) parser.add_argument( "--compute_forces", help="Select True to compute forces", - type=bool, + type=str2bool, default=True, ) @@ -273,7 +273,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--multi_processed_test", help="Boolean value for whether the test data was multiprocessed", - type=bool, + type=str2bool, default=False, required=False, ) @@ -287,7 +287,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--pin_memory", help="Pin memory for data loading", default=True, - type=bool, + type=str2bool, ) parser.add_argument( "--atomic_numbers", @@ -327,7 +327,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--foundation_filter_elements", help="Filter element during fine-tuning", - type=bool, + type=str2bool, default=True, required=False, ) @@ -335,13 +335,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--heads", help="List of heads in the training set", type=str, - default=None, + default='["Default"]', required=False, ) parser.add_argument( "--multiheads_finetuning", help="Boolean value for whether the model is multiheaded", - type=bool, + type=str2bool, default=True, ) parser.add_argument( @@ -356,10 +356,16 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=int, default=1000, ) + parser.add_argument( + "--subselect_pt", + help="Method to subselect the configurations of the pretraining set", + choices=["fps", "random"], + default="random", + ) parser.add_argument( "--keep_isolated_atoms", help="Keep isolated atoms in the dataset, useful for transfer learning", - type=bool, + type=str2bool, default=False, ) parser.add_argument( @@ -794,7 +800,7 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--shuffle", help="Shuffle the training dataset", - type=bool, + type=str2bool, default=True, ) parser.add_argument( @@ -815,3 +821,14 @@ def check_float_or_none(value: str) -> Optional[float]: f"{value} is an invalid value (float or None)" ) from None return None + + +def str2bool(value): + if isinstance(value, bool): + return value + if value.lower() in ("yes", "true", "t", "y", "1"): + return True + elif value.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise 
argparse.ArgumentTypeError("Boolean value expected.") diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py new file mode 100644 index 00000000..66681829 --- /dev/null +++ b/mace/tools/multihead_tools.py @@ -0,0 +1,66 @@ +import argparse +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +@dataclass +class HeadConfig: + train_file: Optional[str] + valid_file: Optional[str] + test_file: Optional[str] + E0s: Optional[Any] + statistics_file: Optional[str] + valid_fraction: Optional[float] + config_type_weights: Optional[Dict[str, float]] + energy_key: Optional[str] + forces_key: Optional[str] + stress_key: Optional[str] + virials_key: Optional[str] + dipole_key: Optional[str] + charges_key: Optional[str] + keep_isolated_atoms: Optional[bool] + atomic_numbers: Optional[Dict[str, int]] + mean: Optional[float] + std: Optional[float] + avg_num_neighbors: Optional[float] + compute_avg_num_neighbors: Optional[bool] + + +def dict_head_to_dataclass(head: Dict[str, Any]) -> HeadConfig: + return HeadConfig( + train_file=head.get("train_file"), + valid_file=head.get("valid_file"), + test_file=head.get("test_file"), + E0s=head.get("E0s"), + statistics_file=head.get("statistics_file"), + valid_fraction=head.get("valid_fraction"), + config_type_weights=head.get("config_type_weights"), + energy_key=head.get("energy_key"), + forces_key=head.get("forces_key"), + stress_key=head.get("stress_key"), + virials_key=head.get("virials_key"), + dipole_key=head.get("dipole_key"), + charges_key=head.get("charges_key"), + keep_isolated_atoms=head.get("keep_isolated_atoms"), + ) + + +def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: + return { + "Default": { + "train_file": args.train_file, + "valid_file": args.valid_file, + "test_file": args.test_file, + "E0s": args.E0s, + "statistics_file": args.statistics_file, + "valid_fraction": args.valid_fraction, + "config_type_weights": args.config_type_weights, + "energy_key": args.energy_key, + "forces_key": args.forces_key, + "stress_key": args.stress_key, + "virials_key": args.virials_key, + "dipole_key": args.dipole_key, + "charges_key": args.charges_key, + "keep_isolated_atoms": args.keep_isolated_atoms, + } + } diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 6053f2ea..649f5f9f 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -15,9 +15,9 @@ import numpy as np import torch import torch.distributed +from torch.optim.swa_utils import SWALR, AveragedModel from e3nn import o3 from prettytable import PrettyTable -from torch.optim.swa_utils import SWALR, AveragedModel from mace import data, modules from mace.tools import evaluate @@ -47,7 +47,7 @@ def get_dataset_from_xyz( charges_key: str = "charges", ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" - atomic_energies_dict, all_train_configs, heads = data.load_from_xyz( + atomic_energies_dict, all_train_configs = data.load_from_xyz( file_path=train_path, config_type_weights=config_type_weights, energy_key=energy_key, @@ -63,7 +63,7 @@ def get_dataset_from_xyz( f"Loaded {len(all_train_configs)} training configurations from '{train_path}'" ) if valid_path is not None: - _, valid_configs, _ = data.load_from_xyz( + _, valid_configs = data.load_from_xyz( file_path=valid_path, config_type_weights=config_type_weights, energy_key=energy_key, @@ -82,25 +82,17 @@ def get_dataset_from_xyz( logging.info( "Using random %s%% of training set for validation", 
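# Why the arg_parser hunks above swap type=bool for str2bool: argparse passes the
# raw command-line string to the type callable, and bool("False") is True. A small
# sketch using the same helper as in the diff; the flag shown is one of the real
# options, the example value is made up.
import argparse

def str2bool(value):
    # Accept real booleans (e.g. defaults), otherwise parse common string spellings.
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--compute_stress", type=str2bool, default=False)
print(parser.parse_args(["--compute_stress", "False"]).compute_stress)  # False, not True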
100 * valid_fraction ) - train_configs, valid_configs = [], [] - for head in heads: - all_train_configs_head = [ - config for config in all_train_configs if config.head == head - ] - train_configs_head, valid_configs_head = data.random_train_valid_split( - all_train_configs_head, valid_fraction, seed - ) - train_configs.extend(train_configs_head) - valid_configs.extend(valid_configs_head) + train_configs, valid_configs = data.random_train_valid_split( + all_train_configs, valid_fraction, seed + ) test_configs = [] if test_path is not None: - _, all_test_configs, _ = data.load_from_xyz( + _, all_test_configs = data.load_from_xyz( file_path=test_path, config_type_weights=config_type_weights, energy_key=energy_key, forces_key=forces_key, - stress_key=stress_key, dipole_key=dipole_key, stress_key=stress_key, virials_key=virials_key, @@ -115,7 +107,6 @@ def get_dataset_from_xyz( return ( SubsetCollection(train=train_configs, valid=valid_configs, tests=test_configs), atomic_energies_dict, - heads, ) @@ -169,6 +160,8 @@ def radial_to_transform(radial): return "Soft" return radial.distance_transform.__class__.__name__ + scale = model.scale_shift.scale + shift = model.scale_shift.shift config = { "r_max": model.r_max.item(), "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), @@ -204,8 +197,8 @@ def radial_to_transform(radial): "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], "pair_repulsion": hasattr(model, "pair_repulsion_fn"), "distance_transform": radial_to_transform(model.radial_embedding), - "atomic_inter_scale": model.scale_shift.scale.item(), - "atomic_inter_shift": model.scale_shift.shift.item(), + "atomic_inter_scale": scale.cpu().numpy(), + "atomic_inter_shift": shift.cpu().numpy(), } return config @@ -490,16 +483,16 @@ def custom_key(key): return (2, key) -def dict_to_array(input_data, heads): - if not all(isinstance(value, dict) for value in input_data.values()): - return np.array(list(input_data.values())) +def dict_to_array(data, heads): + if not all(isinstance(value, dict) for value in data.values()): + return np.array(list(data.values())) unique_keys = set() - for inner_dict in input_data.values(): + for inner_dict in data.values(): unique_keys.update(inner_dict.keys()) unique_keys = list(unique_keys) sorted_keys = sorted([int(key) for key in unique_keys]) - result_array = np.zeros((len(input_data), len(sorted_keys))) - for _, (head_name, inner_dict) in enumerate(input_data.items()): + result_array = np.zeros((len(data), len(sorted_keys))) + for _, (head_name, inner_dict) in enumerate(data.items()): for key, value in inner_dict.items(): key_index = sorted_keys.index(int(key)) head_index = heads.index(head_name) diff --git a/tests/test_calculator.py b/tests/test_calculator.py index bc8f5862..47ed95df 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -10,6 +10,7 @@ from ase.atoms import Atoms from ase.calculators.test import gradient_test from ase.constraints import ExpCellFilter +import torch from mace.calculators import mace_mp, mace_off from mace.calculators.mace import MACECalculator @@ -376,12 +377,12 @@ def test_calculator_node_energy(fitting_configs, trained_model): trained_model.calculate(at) node_energies = trained_model.results["node_energy"] batch = trained_model._atoms_to_batch(at) # pylint: disable=protected-access + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) node_e0 = ( - trained_model.models[0] - .atomic_energies_fn(batch["node_attrs"]) - .detach() - 
.numpy() + trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach() ) + node_e0 = node_e0[num_atoms_arange, node_heads].cpu().numpy() energy_via_nodes = np.sum(node_energies + node_e0) energy = trained_model.results["energy"] np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index bb1bab4e..9908a323 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -291,6 +291,10 @@ def test_run_train_foundation(tmp_path, fitting_configs): mace_params["default_dtype"] = "float32" mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["multiheads_finetuning"] = False + print("mace_params", mace_params) + # mace_params["num_samples_pt"] = 50 + # mace_params["subselect_pt"] = "random" # make sure run_train.py is using the mace that is currently being tested run_env = os.environ.copy() sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -359,13 +363,13 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): else: c.info["head"] = "MP2" fitting_configs_.append(c) - ase.io.write(tmp_path / "fit.xyz", fitting_configs_) + ase.io.write(tmp_path / "fit_multihead.xyz", fitting_configs_) mace_params = _mace_params.copy() mace_params["valid_fraction"] = 0.1 mace_params["checkpoints_dir"] = str(tmp_path) mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["train_file"] = tmp_path / "fit_multihead.xyz" mace_params["loss"] = "weighted" mace_params["foundation_model"] = "small" mace_params["hidden_irreps"] = "128x0e" @@ -375,6 +379,8 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" mace_params["heads"] = "['MP2','DFT']" mace_params["batch_size"] = 2 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" # make sure run_train.py is using the mace that is currently being tested run_env = os.environ.copy() sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -407,29 +413,29 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): Es.append(at.get_potential_energy()) print("Es", Es) - # from a run on 28/03/2023 on repulsion a63434aaab70c84ee016e13e4aca8d57297a0f26 + # from a run on 20/08/2024 on commit ref_Es = [ - 1.1737573146820068, - 0.37266889214515686, - 0.3591262996196747, - 0.1222146600484848, - 0.21925662457942963, - 0.30689263343811035, - 0.23039104044437408, - 0.11772646009922028, - 0.2409999519586563, - 0.04042769968509674, - 0.6277227997779846, - 0.13879507780075073, - 0.18997330963611603, - 0.30589431524276733, - 0.34129756689071655, - -0.0034095346927642822, - 0.5614650249481201, - 0.29983872175216675, - 0.3369189500808716, - -0.20579558610916138, - 0.1669044941663742, - 0.119053915143013, + 1.4186015129089355, + 0.6012811660766602, + 1.4759466648101807, + 1.1662801504135132, + 1.117658019065857, + 1.4062559604644775, + 1.4638032913208008, + 0.9065879583358765, + 1.3814517259597778, + 1.2735612392425537, + 1.2472984790802002, + 1.1374807357788086, + 1.4028346538543701, + 1.0139431953430176, + 1.3830922842025757, + 1.0170294046401978, + 1.6741619110107422, + 1.2575324773788452, + 1.2426478862762451, + 1.0206304788589478, + 1.2309682369232178, + 1.135024070739746, ] - assert np.allclose(Es, ref_Es, tol=1e-2) + assert np.allclose(Es, ref_Es, atol=1e-1) From 928064c0a05eaf481aa04bc4c23a2462a157b4cb Mon Sep 17 00:00:00 2001 From: 
Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 22 Aug 2024 14:54:28 +0100 Subject: [PATCH 073/101] fix all tests and parsing yaml --- mace/cli/fine_tuning_select.py | 9 +- mace/cli/preprocess_data.py | 4 +- mace/cli/run_train.py | 435 +++++++++++++++++---------------- mace/data/atomic_data.py | 2 + mace/data/utils.py | 5 + mace/tools/arg_parser.py | 11 +- mace/tools/multihead_tools.py | 185 +++++++++++--- mace/tools/scripts_utils.py | 26 +- tests/test_calculator.py | 2 +- tests/test_run_train.py | 134 +++++++++- 10 files changed, 534 insertions(+), 279 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 373095c9..288e820a 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -28,7 +28,7 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( "--configs_ft", - help="path to XYZ configurations for the finetuning", + help="path or list of paths to XYZ configurations for the finetuning", required=True, ) parser.add_argument( @@ -212,7 +212,12 @@ def select_samples( calc = MACECalculator( model_paths=args.model, device=args.device, default_dtype=args.default_dtype ) - atoms_list_ft = ase.io.read(args.configs_ft, index=":") + if isinstance(args.configs_ft, str): + atoms_list_ft = ase.io.read(args.configs_ft, index=":") + else: + atoms_list_ft = [] + for path in args.configs_ft: + atoms_list_ft += ase.io.read(path, index=":") if args.filtering_type is not None: all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 8b2f00f1..5c198ec4 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -154,7 +154,7 @@ def run(args: argparse.Namespace): os.makedirs(args.h5_prefix + sub_dir) # Data preparation - collections, atomic_energies_dict, _ = get_dataset_from_xyz( + collections, atomic_energies_dict = get_dataset_from_xyz( train_path=args.train_file, valid_path=args.valid_file, valid_fraction=args.valid_fraction, @@ -207,7 +207,7 @@ def run(args: argparse.Namespace): if args.compute_statistics: logging.info("Computing statistics") if len(atomic_energies_dict) == 0: - atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table, ["Default"]) + atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) atomic_energies: np.ndarray = np.array( [atomic_energies_dict[z] for z in z_table.zs] ) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 146374d3..5da11b6d 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -10,10 +10,9 @@ import json import logging import os -import urllib.request from copy import deepcopy from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch.distributed @@ -21,22 +20,25 @@ from e3nn import o3 from e3nn.util import jit from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import ConcatDataset from torch_ema import ExponentialMovingAverage import mace from mace import data, modules, tools from mace.calculators.foundations_models import mace_mp, mace_off -from mace.cli.fine_tuning_select import select_samples from mace.tools import torch_geometric from mace.tools.finetuning_utils import load_foundations_elements -from mace.tools.multihead_tools import dict_head_to_dataclass, prepare_default_head +from mace.tools.multihead_tools import ( + HeadConfig, + assemble_mp_data, + dict_head_to_dataclass, + 
prepare_default_head, +) from mace.tools.scripts_utils import ( LRScheduler, - check_folder_subfolder, convert_to_json_format, create_error_table, dict_to_array, - dict_to_namespace, extract_config_mace_model, get_atomic_energies, get_config_type_weights, @@ -145,9 +147,12 @@ def run(args: argparse.Namespace) -> None: args.heads = ast.literal_eval(args.heads) else: args.heads = prepare_default_head(args) + heads = list(args.heads.keys()) + logging.info(f"Using heads: {heads}") + head_configs: List[HeadConfig] = [] for head, head_args in args.heads.items(): logging.info(f"============= Processing head {head} ===========") - head_config = dict_head_to_dataclass(head_args) + head_config = dict_head_to_dataclass(head_args, head, args) if head_config.statistics_file is not None: with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 statistics = json.load(f) @@ -188,151 +193,133 @@ def run(args: argparse.Namespace) -> None: virials_key=head_config.virials_key, dipole_key=head_config.dipole_key, charges_key=head_config.charges_key, + head_name=head_config.head_name, keep_isolated_atoms=head_config.keep_isolated_atoms, ) + head_config.collections = collections + head_config.atomic_energies_dict = atomic_energies_dict + print("ATOMIC ENERGIES DICT", atomic_energies_dict) + head_configs.append(head_config) logging.info( f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," ) - else: - atomic_energies_dict = None if args.multiheads_finetuning: logging.info( "==================Using multiheads finetuning mode==================" ) args.loss = "universal" - if heads is not None: - heads = list(dict.fromkeys(["pbe_mp"] + heads)) - args.heads = heads - else: - heads = ["pbe_mp", "Default"] - args.heads = heads - logging.info(f"Using heads: {heads}") - try: - checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" - descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy" - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = "".join( - c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" - ) - cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" - descriptors_url_name = "".join( - c for c in os.path.basename(descriptors_url) if c.isalnum() or c in "_" + if ( + args.foundation_model in ["small", "medium", "large"] + or "mp" in args.foundation_model + ): + heads = list(dict.fromkeys(["pt_head"] + heads)) + head_config_pt = HeadConfig( + head_name="pt_head", + E0s="foundation", + statistics_file=args.statistics_file, ) - cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}" - if not os.path.isfile(cached_dataset_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - logging.info("Downloading MP structures for finetuning") - _, http_msg = urllib.request.urlretrieve( - checkpoint_url, cached_dataset_path - ) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Dataset download failed, please check the URL {checkpoint_url}" - ) - logging.info(f"Materials Project dataset to {cached_dataset_path}") - if not os.path.isfile(cached_descriptors_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - logging.info("Downloading MP descriptors for finetuning") - _, http_msg = urllib.request.urlretrieve( - descriptors_url, cached_descriptors_path 
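# The block removed above (reintroduced later in this patch as assemble_mp_data in
# mace/tools/multihead_tools.py) fetches the Materials Project fine-tuning data once
# and caches it under ~/.cache/mace. A rough standalone sketch of that
# download-and-cache pattern; the content-type check is a loose variant of the one
# in the diff, and nothing is downloaded when only the function is defined.
import logging
import os
import urllib.request

def download_cached(url: str, cache_dir: str = "~/.cache/mace") -> str:
    cache_dir = os.path.expanduser(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    # Keep only safe characters from the file name, as the diff does.
    name = "".join(c for c in os.path.basename(url) if c.isalnum() or c in "_")
    path = os.path.join(cache_dir, name)
    if not os.path.isfile(path):
        logging.info("Downloading %s to %s", url, path)
        _, http_msg = urllib.request.urlretrieve(url, path)
        if http_msg.get_content_type() == "text/html":
            raise RuntimeError(f"Download failed, please check the URL {url}")
    return path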
- ) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Descriptors download failed, please check the URL {descriptors_url}" - ) - logging.info( - f"Materials Project descriptors to {cached_descriptors_path}" - ) - dataset_mp = cached_dataset_path - descriptors_mp = cached_descriptors_path - msg = f"Using Materials Project dataset with {dataset_mp}" - logging.info(msg) - msg = f"Using Materials Project descriptors with {descriptors_mp}" - logging.info(msg) - args_samples = { - "configs_pt": dataset_mp, - "configs_ft": args.train_file, - "num_samples": args.num_samples_pt, - "seed": args.seed, - "model": args.foundation_model, - "head_pt": "pbe_mp", - "head_ft": "Default", - "weight_pt": args.weight_pt_head, - "weight_ft": 1.0, - "filtering_type": "combination", - "output": f"mp_finetuning-{tag}.xyz", - "descriptors": descriptors_mp, - "subselect": args.subselect_pt, - "device": args.device, - "default_dtype": args.default_dtype, - } - select_samples(dict_to_namespace(args_samples)) - collections_mp, _, _ = get_dataset_from_xyz( - train_path=f"mp_finetuning-{tag}.xyz", - valid_path=None, + collections = assemble_mp_data(args, tag, head_configs) + head_config_pt.collections = collections + head_config_pt.train_file = f"mp_finetuning-{tag}.xyz" + head_configs.append(head_config_pt) + else: + heads = list(dict.fromkeys(["pt_head"] + heads)) + collections, atomic_energies_dict = get_dataset_from_xyz( + train_path=args.pt_train_file, + valid_path=args.pt_valid_file, valid_fraction=args.valid_fraction, - config_type_weights=config_type_weights, + config_type_weights=None, test_path=None, seed=args.seed, - energy_key="energy", - forces_key="forces", - stress_key="stress", + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, virials_key=args.virials_key, dipole_key=args.dipole_key, charges_key=args.charges_key, + head_name="pt_head", keep_isolated_atoms=args.keep_isolated_atoms, ) - collections.train += collections_mp.train - collections.valid += collections_mp.valid - except Exception as exc: - raise RuntimeError( - "Model download failed and no local model found" - ) from exc + head_config_pt = HeadConfig( + head_name="pt_head", + train_file=args.pt_train_file, + valid_file=args.pt_valid_file, + E0s="foundation", + statistics_file=args.statistics_file, + valid_fraction=args.valid_fraction, + config_type_weights=None, + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + keep_isolated_atoms=args.keep_isolated_atoms, + collections=collections, + ) + head_config_pt.collections = collections + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" + ) # Atomic number table # yapf: disable - if args.atomic_numbers is None: - assert args.train_file.endswith(".xyz"), "Must specify atomic_numbers when using .h5 train_file input" - z_table = tools.get_atomic_number_table_from_zs( - z - for configs in (collections.train, collections.valid) - for config in configs - for z in config.atomic_numbers - ) - else: - if args.statistics_file is None: - logging.info("Using atomic numbers from command line argument") + for head_config in head_configs: + if head_config.atomic_numbers is None: + assert head_config.train_file.endswith(".xyz"), "Must specify atomic_numbers when using .h5 train_file input" + z_table_head = tools.get_atomic_number_table_from_zs( + z + for configs 
in (head_config.collections.train, head_config.collections.valid) + for config in configs + for z in config.atomic_numbers + ) + head_config.atomic_numbers = z_table_head.zs + head_config.z_table = z_table_head else: - logging.info("Using atomic numbers from statistics file") - zs_list = ast.literal_eval(args.atomic_numbers) - assert isinstance(zs_list, list) - z_table = tools.get_atomic_number_table_from_zs(zs_list) - # yapf: enable + if head_config.statistics_file is None: + logging.info("Using atomic numbers from command line argument") + else: + logging.info("Using atomic numbers from statistics file") + zs_list = ast.literal_eval(head_config.atomic_numbers) + assert isinstance(zs_list, list) + z_table_head = tools.AtomicNumberTable(zs_list) + head_config.atomic_numbers = zs_list + head_config.z_table = z_table_head + # yapf: enable + all_atomic_numbers = set() + for head_config in head_configs: + all_atomic_numbers.update(head_config.atomic_numbers) + z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) logging.info(z_table) - if atomic_energies_dict is None or len(atomic_energies_dict) == 0: - if args.train_file.endswith(".xyz"): - atomic_energies_dict = get_atomic_energies( - args.E0s, collections.train, z_table, heads - ) - if args.E0s.lower() == "foundation": - assert args.foundation_model is not None - logging.info("Using atomic energies from foundation model") - z_table_foundation = AtomicNumberTable( - [int(z) for z in model_foundation.atomic_numbers] - ) - atomic_energies_dict = { - z: model_foundation.atomic_energies_fn.atomic_energies[ - z_table_foundation.z_to_index(z) - ].item() - for z in z_table.zs - } + # Atomic energies + atomic_energies_dict = {} + for head_config in head_configs: + if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: + if head_config.train_file.endswith(".xyz") and head_config.E0s.lower() != "foundation": + atomic_energies_dict[head_config.head_name] = get_atomic_energies( + head_config.E0s, head_config.collections.train, head_config.z_table + ) + elif head_config.E0s.lower() == "foundation": + assert args.foundation_model is not None + logging.info("Using atomic energies from foundation model") + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + atomic_energies_dict[head_config.head_name] = { + z: model_foundation.atomic_energies_fn.atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } + else: + atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: - atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table, heads) + atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict + + # Atomic energies for multiheads finetuning if args.multiheads_finetuning: assert ( model_foundation is not None @@ -343,7 +330,7 @@ def run(args: argparse.Namespace) -> None: z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) - atomic_energies_dict["pbe_mp"] = { + atomic_energies_dict["pt_head"] = { z: model_foundation.atomic_energies_fn.atomic_energies[ z_table_foundation.z_to_index(z) ].item() @@ -372,53 +359,52 @@ def run(args: argparse.Namespace) -> None: # atomic_energies: np.ndarray = np.array( # [atomic_energies_dict[z] for z in z_table.zs] # ) - atomic_energies = dict_to_array(atomic_energies_dict, args.heads) + atomic_energies = dict_to_array(atomic_energies_dict, heads) logging.info(f"Atomic energies: 
{atomic_energies.tolist()}") - if args.train_file.endswith(".xyz"): - train_set = [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, heads=heads - ) - for config in collections.train - ] - valid_sets = {head: [] for head in heads} - for head in heads: - valid_sets[head] = [ + valid_sets = {head: [] for head in heads} + train_sets = {head: [] for head in heads} + for head_config in head_configs: + if head_config.train_file.endswith(".xyz"): + train_sets[head_config.head_name] = [ data.AtomicData.from_config( config, z_table=z_table, cutoff=args.r_max, heads=heads ) - for config in collections.valid - if config.head == head + for config in head_config.collections.train ] + valid_sets[head_config.head_name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in head_config.collections.valid + ] - elif args.train_file.endswith(".h5"): - train_set = data.HDF5Dataset( - args.train_file, r_max=args.r_max, z_table=z_table, heads=heads - ) - valid_set = data.HDF5Dataset( - args.valid_file, r_max=args.r_max, z_table=z_table, heads=heads - ) - else: # This case would be for when the file path is to a directory of multiple .h5 files - train_set = data.dataset_from_sharded_hdf5( - args.train_file, r_max=args.r_max, z_table=z_table, heads=heads - ) - # check if the folder has subfolders for each head by opening args.valid_file folder - if check_folder_subfolder(args.valid_file): - valid_sets = {} - for head in heads: - valid_sets[head] = data.dataset_from_sharded_hdf5( - os.path.join(args.valid_file, head), - r_max=args.r_max, - z_table=z_table, - heads=heads, - ) - else: - valid_set = data.dataset_from_sharded_hdf5( - args.valid_file, r_max=args.r_max, z_table=z_table, heads=heads + elif args.train_file.endswith(".h5"): + train_sets[head_config.head_name] = data.HDF5Dataset( + head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads ) - valid_sets = {"Default": valid_set} - + valid_sets[head_config.head_name] = data.HDF5Dataset( + head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads + ) + else: # This case would be for when the file path is to a directory of multiple .h5 files + train_sets[head_config.head_name] = data.dataset_from_sharded_hdf5( + head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads + ) + valid_sets[head_config.head_name] = data.dataset_from_sharded_hdf5( + head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads + ) + train_loader_head = torch_geometric.dataloader.DataLoader( + dataset=train_sets[head_config.head_name], + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), + ) + head_config.train_loader = train_loader_head + # concatenate all the trainsets + train_set = ConcatDataset([train_sets[head] for head in heads]) train_sampler, valid_sampler = None, None if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( @@ -535,8 +521,9 @@ def run(args: argparse.Namespace) -> None: model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) + args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] args.model = "FoundationMACE" - model_config_foundation["heads"] = args.heads + model_config_foundation["heads"] = heads logging.info("Model configuration extracted from foundation model") 
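# The run_train changes above build one dataset per head and merge them with
# torch.utils.data.ConcatDataset before constructing the training loader. A toy
# sketch of that pattern with plain tensors; head names and sizes are made up,
# whereas in run_train the per-head entries are lists of AtomicData objects.
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

train_sets = {
    "DFT": TensorDataset(torch.arange(4, dtype=torch.float32)),
    "MP2": TensorDataset(torch.arange(4, 8, dtype=torch.float32)),
}
heads = list(train_sets.keys())

# Merge all heads into a single training set, then shuffle across heads.
train_set = ConcatDataset([train_sets[head] for head in heads])
loader = DataLoader(
    train_set,
    batch_size=3,
    shuffle=True,
    generator=torch.Generator().manual_seed(123),
)
for batch in loader:
    print(batch)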
logging.info("Using universal loss function for fine-tuning") else: @@ -840,60 +827,78 @@ def run(args: argparse.Namespace) -> None: logging.info("Computing metrics for training, validation, and test sets") - all_data_loaders = { - "train": train_loader, - } + all_data_loaders = {} + for head_config in head_configs: + data_loader_name = "train_" + head_config.head_name + all_data_loaders[data_loader_name] = head_config.train_loader for head, valid_loader in valid_loaders.items(): - all_data_loaders[head] = valid_loader + data_load_name = "valid_" + head + all_data_loaders[data_load_name] = valid_loader test_sets = {} - if args.train_file.endswith(".xyz"): - for name, subset in collections.tests: - test_sets[name] = [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, heads=heads - ) - for config in subset - ] - elif not args.multi_processed_test: - test_files = get_files_with_suffix(args.test_dir, "_test.h5") - for test_file in test_files: - name = os.path.splitext(os.path.basename(test_file))[0] - test_sets[name] = data.HDF5Dataset( - test_file, r_max=args.r_max, z_table=z_table, heads=heads - ) - else: - test_folders = glob(args.test_dir + "/*") - for folder in test_folders: - name = os.path.splitext(os.path.basename(test_file))[0] - test_sets[name] = data.dataset_from_sharded_hdf5( - folder, r_max=args.r_max, z_table=z_table, heads=heads - ) + stop_first_test = False + # check if all head have same test set + if all( + head_config.test_file == head_configs[0].test_file + for head_config in head_configs + ) and head_configs[0].test_file is not None: + stop_first_test = True + if all( + head_config.test_dir == head_configs[0].test_dir + for head_config in head_configs + ) and head_configs[0].test_dir is not None: + stop_first_test = True + for head_config in head_configs: + if head_config.train_file.endswith(".xyz"): + for name, subset in head_config.collections.tests: + test_sets[name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in subset + ] + if head_config.test_dir is not None: + if not args.multi_processed_test: + test_files = get_files_with_suffix(head_config.test_dir, "_test.h5") + for test_file in test_files: + name = os.path.splitext(os.path.basename(test_file))[0] + test_sets[name] = data.HDF5Dataset( + test_file, r_max=args.r_max, z_table=z_table, heads=heads + ) + else: + test_folders = glob(head_config.test_dir + "/*") + for folder in test_folders: + name = os.path.splitext(os.path.basename(test_file))[0] + test_sets[name] = data.dataset_from_sharded_hdf5( + folder, r_max=args.r_max, z_table=z_table, heads=heads + ) - for test_name, test_set in test_sets.items(): - test_sampler = None - if args.distributed: - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, - ) - try: - drop_last = test_set.drop_last - except AttributeError as e: # pylint: disable=W0612 - drop_last = False - test_loader = torch_geometric.dataloader.DataLoader( - test_set, - batch_size=args.valid_batch_size, - shuffle=(test_sampler is None), - drop_last=drop_last, - num_workers=args.num_workers, - pin_memory=args.pin_memory, - ) - all_data_loaders[test_name] = test_loader + for test_name, test_set in test_sets.items(): + test_sampler = None + if args.distributed: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + 
drop_last=True, + seed=args.seed, + ) + try: + drop_last = test_set.drop_last + except AttributeError as e: # pylint: disable=W0612 + drop_last = False + test_loader = torch_geometric.dataloader.DataLoader( + test_set, + batch_size=args.valid_batch_size, + shuffle=(test_sampler is None), + drop_last=drop_last, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + ) + all_data_loaders[test_name] = test_loader + if stop_first_test: + break for swa_eval in swas: epoch = checkpoint_handler.load_latest( diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 814a23e0..600b12d1 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -127,6 +127,8 @@ def from_config( torch.tensor(indices, dtype=torch.long).unsqueeze(-1), num_classes=len(z_table), ) + print("HEADS", heads) + print("config.head", config.head) try: head = torch.tensor(heads.index(config.head), dtype=torch.long) except ValueError: diff --git a/mace/data/utils.py b/mace/data/utils.py index b9b2c950..44481e3e 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -206,6 +206,8 @@ def load_from_xyz( virials_key: str = "virials", dipole_key: str = "dipole", charges_key: str = "charges", + head_key: str = "head", + head_name: str = "Default", extract_atomic_energies: bool = False, keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: @@ -250,11 +252,13 @@ def load_from_xyz( atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): + atoms.info[head_key] = head_name isolated_atom_config = ( len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" ) if isolated_atom_config: if energy_key in atoms.info.keys(): + print(atoms) atomic_energies_dict[atoms.get_atomic_numbers()[0]] = atoms.info[ energy_key ] @@ -281,6 +285,7 @@ def load_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + head_key=head_key, ) return atomic_energies_dict, configs diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 107ea0dc..4cdfb0c3 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -242,7 +242,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--train_file", help="Training set file, format is .xyz or .h5", type=str, - required=True, + required=False, ) parser.add_argument( "--valid_file", @@ -333,9 +333,9 @@ def build_default_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--heads", - help="List of heads in the training set", + help="Dict of heads: containing individual files and E0s", type=str, - default='["Default"]', + default=None, required=False, ) parser.add_argument( @@ -828,7 +828,6 @@ def str2bool(value): return value if value.lower() in ("yes", "true", "t", "y", "1"): return True - elif value.lower() in ("no", "false", "f", "n", "0"): + if value.lower() in ("no", "false", "f", "n", "0"): return False - else: - raise argparse.ArgumentTypeError("Boolean value expected.") + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 66681829..77833e1e 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -1,56 +1,80 @@ import argparse -from dataclasses import dataclass -from typing import Any, Dict, Optional +import logging +import os +import urllib.request +import dataclasses +from typing import Any, Dict, List, Optional, Union +import torch -@dataclass +from mace.cli.fine_tuning_select import select_samples +from mace.tools.scripts_utils import ( + 
SubsetCollection, + dict_to_namespace, + get_dataset_from_xyz, +) + + +@dataclasses.dataclass class HeadConfig: - train_file: Optional[str] - valid_file: Optional[str] - test_file: Optional[str] - E0s: Optional[Any] - statistics_file: Optional[str] - valid_fraction: Optional[float] - config_type_weights: Optional[Dict[str, float]] - energy_key: Optional[str] - forces_key: Optional[str] - stress_key: Optional[str] - virials_key: Optional[str] - dipole_key: Optional[str] - charges_key: Optional[str] - keep_isolated_atoms: Optional[bool] - atomic_numbers: Optional[Dict[str, int]] - mean: Optional[float] - std: Optional[float] - avg_num_neighbors: Optional[float] - compute_avg_num_neighbors: Optional[bool] + head_name: str + train_file: Optional[str] = None + valid_file: Optional[str] = None + test_file: Optional[str] = None + test_dir: Optional[str] = None + E0s: Optional[Any] = None + statistics_file: Optional[str] = None + valid_fraction: Optional[float] = None + config_type_weights: Optional[Dict[str, float]] = None + energy_key: Optional[str] = None + forces_key: Optional[str] = None + stress_key: Optional[str] = None + virials_key: Optional[str] = None + dipole_key: Optional[str] = None + charges_key: Optional[str] = None + keep_isolated_atoms: Optional[bool] = None + atomic_numbers: Optional[Union[List[int], List[str]]] = None + mean: Optional[float] = None + std: Optional[float] = None + avg_num_neighbors: Optional[float] = None + compute_avg_num_neighbors: Optional[bool] = None + collections: Optional[SubsetCollection] = None + train_loader: torch.utils.data.DataLoader = None + z_table: Optional[Any] = None + atomic_energies_dict: Optional[Dict[str, float]] = None + +def dict_head_to_dataclass( + head: Dict[str, Any], head_name: str, args: argparse.Namespace +) -> HeadConfig: -def dict_head_to_dataclass(head: Dict[str, Any]) -> HeadConfig: return HeadConfig( - train_file=head.get("train_file"), - valid_file=head.get("valid_file"), - test_file=head.get("test_file"), - E0s=head.get("E0s"), - statistics_file=head.get("statistics_file"), - valid_fraction=head.get("valid_fraction"), - config_type_weights=head.get("config_type_weights"), - energy_key=head.get("energy_key"), - forces_key=head.get("forces_key"), - stress_key=head.get("stress_key"), - virials_key=head.get("virials_key"), - dipole_key=head.get("dipole_key"), - charges_key=head.get("charges_key"), - keep_isolated_atoms=head.get("keep_isolated_atoms"), + head_name=head_name, + train_file=head.get("train_file", args.train_file), + valid_file=head.get("valid_file", args.valid_file), + test_file=head.get("test_file", None), + test_dir=head.get("test_dir", None), + E0s=head.get("E0s", args.E0s), + statistics_file=head.get("statistics_file", args.statistics_file), + valid_fraction=head.get("valid_fraction", args.valid_fraction), + config_type_weights=head.get("config_type_weights", args.config_type_weights), + energy_key=head.get("energy_key", args.energy_key), + forces_key=head.get("forces_key", args.forces_key), + stress_key=head.get("stress_key", args.stress_key), + virials_key=head.get("virials_key", args.virials_key), + dipole_key=head.get("dipole_key", args.dipole_key), + charges_key=head.get("charges_key", args.charges_key), + keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), ) def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: return { - "Default": { + "default": { "train_file": args.train_file, "valid_file": args.valid_file, "test_file": args.test_file, + "test_dir": 
args.test_dir, "E0s": args.E0s, "statistics_file": args.statistics_file, "valid_fraction": args.valid_fraction, @@ -64,3 +88,88 @@ def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: "keep_isolated_atoms": args.keep_isolated_atoms, } } + + +def assemble_mp_data( + args: argparse.Namespace, tag: str, head_configs: List[HeadConfig] +) -> Dict[str, Any]: + try: + checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" + descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy" + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" + descriptors_url_name = "".join( + c for c in os.path.basename(descriptors_url) if c.isalnum() or c in "_" + ) + cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}" + if not os.path.isfile(cached_dataset_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP structures for finetuning") + _, http_msg = urllib.request.urlretrieve( + checkpoint_url, cached_dataset_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Dataset download failed, please check the URL {checkpoint_url}" + ) + logging.info(f"Materials Project dataset to {cached_dataset_path}") + if not os.path.isfile(cached_descriptors_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP descriptors for finetuning") + _, http_msg = urllib.request.urlretrieve( + descriptors_url, cached_descriptors_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Descriptors download failed, please check the URL {descriptors_url}" + ) + logging.info(f"Materials Project descriptors to {cached_descriptors_path}") + dataset_mp = cached_dataset_path + descriptors_mp = cached_descriptors_path + msg = f"Using Materials Project dataset with {dataset_mp}" + logging.info(msg) + msg = f"Using Materials Project descriptors with {descriptors_mp}" + logging.info(msg) + config_pt_paths = [head.train_file for head in head_configs] + args_samples = { + "configs_pt": dataset_mp, + "configs_ft": config_pt_paths, + "num_samples": args.num_samples_pt, + "seed": args.seed, + "model": args.foundation_model, + "head_pt": "pbe_mp", + "head_ft": "Default", + "weight_pt": args.weight_pt_head, + "weight_ft": 1.0, + "filtering_type": "combination", + "output": f"mp_finetuning-{tag}.xyz", + "descriptors": descriptors_mp, + "subselect": args.subselect_pt, + "device": args.device, + "default_dtype": args.default_dtype, + } + select_samples(dict_to_namespace(args_samples)) + collections_mp, _ = get_dataset_from_xyz( + train_path=f"mp_finetuning-{tag}.xyz", + valid_path=None, + valid_fraction=args.valid_fraction, + config_type_weights=None, + test_path=None, + seed=args.seed, + energy_key="energy", + forces_key="forces", + stress_key="stress", + head_name="pt_head", + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + keep_isolated_atoms=args.keep_isolated_atoms, + ) + return collections_mp + except Exception as exc: + raise RuntimeError("Model download failed and no local model found") from exc diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 649f5f9f..866c734c 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -15,9 +15,9 
@@ import numpy as np import torch import torch.distributed -from torch.optim.swa_utils import SWALR, AveragedModel from e3nn import o3 from prettytable import PrettyTable +from torch.optim.swa_utils import SWALR, AveragedModel from mace import data, modules from mace.tools import evaluate @@ -39,12 +39,14 @@ def get_dataset_from_xyz( test_path: str = None, seed: int = 1234, keep_isolated_atoms: bool = False, + head_name: str = "Default", energy_key: str = "energy", forces_key: str = "forces", stress_key: str = "stress", virials_key: str = "virials", dipole_key: str = "dipoles", charges_key: str = "charges", + head_key: str = "head", ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" atomic_energies_dict, all_train_configs = data.load_from_xyz( @@ -56,8 +58,10 @@ def get_dataset_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + head_key=head_key, extract_atomic_energies=True, keep_isolated_atoms=keep_isolated_atoms, + head_name=head_name, ) logging.info( f"Loaded {len(all_train_configs)} training configurations from '{train_path}'" @@ -72,7 +76,9 @@ def get_dataset_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + head_key=head_key, extract_atomic_energies=False, + head_name=head_name, ) logging.info( f"Loaded {len(valid_configs)} validation configurations from '{valid_path}'" @@ -97,7 +103,9 @@ def get_dataset_from_xyz( stress_key=stress_key, virials_key=virials_key, charges_key=charges_key, + head_key=head_key, extract_atomic_energies=False, + head_name=head_name, ) # create list of tuples (config_type, list(Atoms)) test_configs = data.test_config_types(all_test_configs) @@ -289,7 +297,7 @@ def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module: return model_load_yaml.to(map_location) -def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict: +def get_atomic_energies(E0s, train_collection, z_table) -> dict: if E0s is not None: logging.info( "Atomic Energies not in training file, using command line argument E0s" @@ -302,7 +310,7 @@ def get_atomic_energies(E0s, train_collection, z_table, heads) -> dict: try: assert train_collection is not None atomic_energies_dict = data.compute_average_E0s( - train_collection, z_table, heads + train_collection, z_table ) except Exception as e: raise RuntimeError( @@ -483,16 +491,16 @@ def custom_key(key): return (2, key) -def dict_to_array(data, heads): - if not all(isinstance(value, dict) for value in data.values()): - return np.array(list(data.values())) +def dict_to_array(input_data, heads): + if not all(isinstance(value, dict) for value in input_data.values()): + return np.array(list(input_data.values())) unique_keys = set() - for inner_dict in data.values(): + for inner_dict in input_data.values(): unique_keys.update(inner_dict.keys()) unique_keys = list(unique_keys) sorted_keys = sorted([int(key) for key in unique_keys]) - result_array = np.zeros((len(data), len(sorted_keys))) - for _, (head_name, inner_dict) in enumerate(data.items()): + result_array = np.zeros((len(input_data), len(sorted_keys))) + for _, (head_name, inner_dict) in enumerate(input_data.items()): for key, value in inner_dict.items(): key_index = sorted_keys.index(int(key)) head_index = heads.index(head_name) diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 47ed95df..4a7d32fb 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -6,11 +6,11 @@ import ase.io import numpy as np import 
pytest +import torch from ase import build from ase.atoms import Atoms from ase.calculators.test import gradient_test from ase.constraints import ExpCellFilter -import torch from mace.calculators import mace_mp, mace_off from mace.calculators.mace import MACECalculator diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 9908a323..deee6806 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -277,6 +277,114 @@ def test_run_train_no_stress(tmp_path, fitting_configs): assert np.allclose(Es, ref_Es) +def test_run_train_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + fitting_configs_ccd = [] + for _, c in enumerate(fitting_configs): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + + c_ccd = c.copy() + c_ccd.info["head"] = "CCD" + fitting_configs_ccd.append(c_ccd) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + ase.io.write(tmp_path / "fit_multihead_ccd.xyz", fitting_configs_ccd) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + "CCD": {"train_file": f"{str(tmp_path)}/fit_multihead_ccd.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w") as file: + file.write(yaml_str) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["loss"] = "weighted" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float32" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["config"] = tmp_path / "config.yaml" + mace_params["batch_size"] = 2 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + tmp_path / "MACE.model", device="cpu", default_dtype="float32" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 22/08/2024 on commit + ref_Es = [ + 0.0, + 0.0, + 0.1492728888988495, + 0.12760481238365173, + 0.18094804883003235, + 0.2017526775598526, + 0.09473809599876404, + 0.20055484771728516, + 0.1673969328403473, + 0.1053609699010849, + 0.29178786277770996, + 0.06670654565095901, + 0.09736010432243347, + 0.23458734154701233, + 0.09877493232488632, + -0.022957436740398407, + 0.2738725543022156, + 0.13694337010383606, + 0.12737643718719482, + -0.07650933414697647, + -0.012938144616782665, + 0.061228662729263306, + ] + assert 
np.allclose(Es, ref_Es) + + def test_run_train_foundation(tmp_path, fitting_configs): ase.io.write(tmp_path / "fit.xyz", fitting_configs) @@ -356,20 +464,35 @@ def test_run_train_foundation(tmp_path, fitting_configs): def test_run_train_foundation_multihead(tmp_path, fitting_configs): - fitting_configs_ = [] + fitting_configs_dft = [] + fitting_configs_mp2 = [] for i, c in enumerate(fitting_configs): if i % 2 == 0: c.info["head"] = "DFT" + fitting_configs_dft.append(c) else: c.info["head"] = "MP2" - fitting_configs_.append(c) - ase.io.write(tmp_path / "fit_multihead.xyz", fitting_configs_) - + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w") as file: + file.write(yaml_str) mace_params = _mace_params.copy() mace_params["valid_fraction"] = 0.1 mace_params["checkpoints_dir"] = str(tmp_path) mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit_multihead.xyz" + mace_params["config"] = tmp_path / "config.yaml" mace_params["loss"] = "weighted" mace_params["foundation_model"] = "small" mace_params["hidden_irreps"] = "128x0e" @@ -377,7 +500,6 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): mace_params["default_dtype"] = "float32" mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["heads"] = "['MP2','DFT']" mace_params["batch_size"] = 2 mace_params["num_samples_pt"] = 50 mace_params["subselect_pt"] = "random" From a7c98adf03719a06fb1d837bb9990f3cf3828271 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:34:03 +0100 Subject: [PATCH 074/101] fix E0s parsing --- mace/calculators/mace.py | 1 - mace/cli/fine_tuning_select.py | 1 - mace/cli/run_train.py | 14 +++++++------- mace/data/atomic_data.py | 2 -- mace/data/hdf5_dataset.py | 2 ++ mace/data/utils.py | 1 - mace/modules/loss.py | 8 +++++++- mace/tools/scripts_utils.py | 2 +- 8 files changed, 17 insertions(+), 14 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 8afa034e..292b114b 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -235,7 +235,6 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if self.model_type in ["MACE", "EnergyDipoleMACE"]: batch = self._clone_batch(batch_base) - print("head", batch["head"]) node_heads = batch["head"][batch["batch"]] num_atoms_arange = torch.arange(batch["positions"].shape[0]) node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index 288e820a..f3b7462f 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -305,7 +305,6 @@ def select_samples( idx_pt = np.random.choice( list(range(len(atoms_list_pt))), args.num_samples, replace=False ) - print("idx_pt", idx_pt) atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] for atoms in atoms_list_pt: # del atoms.info["mace_descriptors"] diff --git a/mace/cli/run_train.py 
b/mace/cli/run_train.py index 5da11b6d..dab7eae0 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -317,7 +317,7 @@ def run(args: argparse.Namespace) -> None: else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: - atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict + atomic_energies_dict[head_config.head_name] = head_config.E0s # Atomic energies for multiheads finetuning if args.multiheads_finetuning: @@ -381,17 +381,17 @@ def run(args: argparse.Namespace) -> None: elif args.train_file.endswith(".h5"): train_sets[head_config.head_name] = data.HDF5Dataset( - head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads + head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) valid_sets[head_config.head_name] = data.HDF5Dataset( - head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads + head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) else: # This case would be for when the file path is to a directory of multiple .h5 files train_sets[head_config.head_name] = data.dataset_from_sharded_hdf5( - head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads + head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) valid_sets[head_config.head_name] = data.dataset_from_sharded_hdf5( - head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads + head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) train_loader_head = torch_geometric.dataloader.DataLoader( dataset=train_sets[head_config.head_name], @@ -863,14 +863,14 @@ def run(args: argparse.Namespace) -> None: for test_file in test_files: name = os.path.splitext(os.path.basename(test_file))[0] test_sets[name] = data.HDF5Dataset( - test_file, r_max=args.r_max, z_table=z_table, heads=heads + test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) else: test_folders = glob(head_config.test_dir + "/*") for folder in test_folders: name = os.path.splitext(os.path.basename(test_file))[0] test_sets[name] = data.dataset_from_sharded_hdf5( - folder, r_max=args.r_max, z_table=z_table, heads=heads + folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) for test_name, test_set in test_sets.items(): diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 600b12d1..814a23e0 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -127,8 +127,6 @@ def from_config( torch.tensor(indices, dtype=torch.long).unsqueeze(-1), num_classes=len(z_table), ) - print("HEADS", heads) - print("config.head", config.head) try: head = torch.tensor(heads.index(config.head), dtype=torch.long) except ValueError: diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index ce3a9b83..d0c1698a 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -67,6 +67,8 @@ def __getitem__(self, index): pbc=unpack_value(subgrp["pbc"][()]), cell=unpack_value(subgrp["cell"][()]), ) + if config.head is None: + config.head = self.kwargs.get("head") atomic_data = AtomicData.from_config( config, z_table=self.z_table, diff --git a/mace/data/utils.py b/mace/data/utils.py index 44481e3e..9277d700 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -258,7 +258,6 @@ def load_from_xyz( ) if isolated_atom_config: if energy_key in 
atoms.info.keys(): - print(atoms) atomic_energies_dict[atoms.get_atomic_numbers()[0]] = atoms.info[ energy_key ] diff --git a/mace/modules/loss.py b/mace/modules/loss.py index b3421ef5..aebae2b4 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -273,12 +273,18 @@ def __init__( def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: num_atoms = ref.ptr[1:] - ref.ptr[:-1] + configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] return ( self.energy_weight * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + self.forces_weight * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) - + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) + + self.stress_weight + * self.huber_loss( + configs_weight * configs_stress_weight * ref["stress"], + configs_weight * configs_stress_weight * pred["stress"], + ) ) def __repr__(self): diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 866c734c..810ef789 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -328,7 +328,7 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: isinstance(value, dict) for value in atomic_energies_eval.values() ): - atomic_energies_dict = {"Default": atomic_energies_eval} + atomic_energies_dict = atomic_energies_eval else: atomic_energies_dict = atomic_energies_eval assert isinstance(atomic_energies_dict, dict) From c50de483f1619429bde3f218cfd8fdfe6c3a1b57 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:46:12 +0100 Subject: [PATCH 075/101] fix linter --- mace/tools/multihead_tools.py | 2 +- tests/test_run_train.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 77833e1e..bc599007 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -1,8 +1,8 @@ import argparse +import dataclasses import logging import os import urllib.request -import dataclasses from typing import Any, Dict, List, Optional, Union import torch diff --git a/tests/test_run_train.py b/tests/test_run_train.py index deee6806..08ad1117 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -308,7 +308,7 @@ def test_run_train_multihead(tmp_path, fitting_configs): for sub_key, sub_value in value.items(): yaml_str += f" {sub_key}: {sub_value}\n" filename = tmp_path / "config.yaml" - with open(filename, "w") as file: + with open(filename, "w", encoding="utf-8") as file: file.write(yaml_str) mace_params = _mace_params.copy() @@ -486,7 +486,7 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): for sub_key, sub_value in value.items(): yaml_str += f" {sub_key}: {sub_value}\n" filename = tmp_path / "config.yaml" - with open(filename, "w") as file: + with open(filename, "w", encoding="utf-8") as file: file.write(yaml_str) mace_params = _mace_params.copy() mace_params["valid_fraction"] = 0.1 From a15936970dbd4ba3177839acf3e4edcf91cb970a Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 23 Aug 2024 14:51:40 +0100 Subject: [PATCH 076/101] fix E0s extraction from file --- mace/cli/run_train.py | 39 +++++++++++++++++++++-------------- mace/modules/utils.py | 1 - mace/tools/multihead_tools.py | 7 +++++++ setup.cfg | 1 + 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/mace/cli/run_train.py 
b/mace/cli/run_train.py index dab7eae0..da7b0e0b 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -165,11 +165,18 @@ def run(args: argparse.Namespace) -> None: head_config.std = statistics["std"] head_config.avg_num_neighbors = statistics["avg_num_neighbors"] head_config.compute_avg_num_neighbors = False - head_config.E0s = ( - statistics["atomic_energies"] - if not head_config.E0s.endswith(".json") - else head_config.E0s - ) + if isinstance(statistics["atomic_energies"], str) and statistics[ + "atomic_energies" + ].endswith(".json"): + with open(statistics["atomic_energies"], "r", format="utf-8") as f: + atomic_energies = json.load(f) + head_config.E0s = atomic_energies + head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) + else: + head_config.E0s = statistics["atomic_energies"] + head_config.atomic_energies_dict = ast.literal_eval( + statistics["atomic_energies"] + ) # Data preparation if head_config.train_file.endswith(".xyz"): @@ -198,12 +205,11 @@ def run(args: argparse.Namespace) -> None: ) head_config.collections = collections head_config.atomic_energies_dict = atomic_energies_dict - print("ATOMIC ENERGIES DICT", atomic_energies_dict) - head_configs.append(head_config) - logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " - f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," - ) + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " + f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," + ) + head_configs.append(head_config) if args.multiheads_finetuning: logging.info( @@ -317,7 +323,7 @@ def run(args: argparse.Namespace) -> None: else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: - atomic_energies_dict[head_config.head_name] = head_config.E0s + atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict # Atomic energies for multiheads finetuning if args.multiheads_finetuning: @@ -379,7 +385,7 @@ def run(args: argparse.Namespace) -> None: for config in head_config.collections.valid ] - elif args.train_file.endswith(".h5"): + elif head_config.train_file.endswith(".h5"): train_sets[head_config.head_name] = data.HDF5Dataset( head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) @@ -453,8 +459,8 @@ def run(args: argparse.Namespace) -> None: loss_fn = get_loss_fn(args, dipole_only, compute_dipole) logging.info(loss_fn) - - if args.compute_avg_num_neighbors: + if all([head_config.compute_avg_num_neighbors for head_config in head_configs]): + logging.info("Computing average number of neighbors") avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) if args.distributed: num_graphs = torch.tensor(len(train_loader.dataset)).to(device) @@ -466,6 +472,9 @@ def run(args: argparse.Namespace) -> None: args.avg_num_neighbors = (num_neighbors / num_graphs).item() else: args.avg_num_neighbors = avg_num_neighbors + else: + assert not any(head_config.avg_num_neighbors is None for head_config in head_configs), "Average number of neighbors must be provided in the configuration" + args.avg_num_neighbors = max([head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None]) logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") # 
Selecting outputs diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 36921b11..5f08c819 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -357,7 +357,6 @@ def _compute_mean_rms_energy_forces( def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: num_neighbors = [] - for batch in data_loader: _, receivers = batch.edge_index _, counts = torch.unique(receivers, return_counts=True) diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index bc599007..1e190da2 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -58,6 +58,13 @@ def dict_head_to_dataclass( statistics_file=head.get("statistics_file", args.statistics_file), valid_fraction=head.get("valid_fraction", args.valid_fraction), config_type_weights=head.get("config_type_weights", args.config_type_weights), + compute_avg_num_neighbors=head.get( + "compute_avg_num_neighbors", args.compute_avg_num_neighbors + ), + atomic_numbers=head.get("atomic_numbers", args.atomic_numbers), + mean=head.get("mean", args.mean), + std=head.get("std", args.std), + avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), energy_key=head.get("energy_key", args.energy_key), forces_key=head.get("forces_key", args.forces_key), stress_key=head.get("stress_key", args.stress_key), diff --git a/setup.cfg b/setup.cfg index 842739dc..13d55161 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,7 @@ install_requires = python-hostlist configargparse GitPython + pyYAML tqdm # for plotting: matplotlib From f09a08f3cce3bb922e6d5bfdbcab05c8c2408ac6 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 23 Aug 2024 15:20:44 +0100 Subject: [PATCH 077/101] fix avg_neighbors --- mace/cli/run_train.py | 6 +++++- mace/data/hdf5_dataset.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index da7b0e0b..b6c64e89 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -225,6 +225,8 @@ def run(args: argparse.Namespace) -> None: head_name="pt_head", E0s="foundation", statistics_file=args.statistics_file, + compute_avg_num_neighbors=False, + avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, ) collections = assemble_mp_data(args, tag, head_configs) head_config_pt.collections = collections @@ -264,6 +266,8 @@ def run(args: argparse.Namespace) -> None: charges_key=args.charges_key, keep_isolated_atoms=args.keep_isolated_atoms, collections=collections, + avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, + compute_avg_num_neighbors=False, ) head_config_pt.collections = collections logging.info( @@ -473,7 +477,7 @@ def run(args: argparse.Namespace) -> None: else: args.avg_num_neighbors = avg_num_neighbors else: - assert not any(head_config.avg_num_neighbors is None for head_config in head_configs), "Average number of neighbors must be provided in the configuration" + assert any(head_config.avg_num_neighbors is not None for head_config in head_configs), "Average number of neighbors must be provided in the configuration" args.avg_num_neighbors = max([head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None]) logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index d0c1698a..477ccd3f 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -58,7 +58,6 @@ def 
__getitem__(self, index): dipole=unpack_value(subgrp["dipole"][()]), charges=unpack_value(subgrp["charges"][()]), weight=unpack_value(subgrp["weight"][()]), - head=unpack_value(subgrp["head"][()]), energy_weight=unpack_value(subgrp["energy_weight"][()]), forces_weight=unpack_value(subgrp["forces_weight"][()]), stress_weight=unpack_value(subgrp["stress_weight"][()]), From a0b2af34319eceaf3e89f8179fc4c5309d485268 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 23 Aug 2024 16:37:33 +0100 Subject: [PATCH 078/101] remove tinyurls --- mace/calculators/foundations_models.py | 6 +++--- mace/cli/run_train.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index bc9ed654..6d29cb04 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -62,9 +62,9 @@ def mace_mp( try: # checkpoints release: https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0 urls = dict( - small="https://tinyurl.com/46jrkm3v", # 2023-12-10-mace-128-L0_energy_epoch-249.model - medium="https://tinyurl.com/5yyxdm76", # 2023-12-03-mace-128-L1_epoch-199.model - large="https://tinyurl.com/5f5yavf3", # MACE_MPtrj_2022.9.model + small="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", # 2023-12-10-mace-128-L0_energy_epoch-249.model + medium="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", # 2023-12-03-mace-128-L1_epoch-199.model + large="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", # MACE_MPtrj_2022.9.model ) checkpoint_url = ( urls.get(model, urls["medium"]) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index b6c64e89..c2dcfac0 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -168,7 +168,7 @@ def run(args: argparse.Namespace) -> None: if isinstance(statistics["atomic_energies"], str) and statistics[ "atomic_energies" ].endswith(".json"): - with open(statistics["atomic_energies"], "r", format="utf-8") as f: + with open(statistics["atomic_energies"], "r", encoding="utf-8") as f: atomic_energies = json.load(f) head_config.E0s = atomic_energies head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) @@ -463,7 +463,7 @@ def run(args: argparse.Namespace) -> None: loss_fn = get_loss_fn(args, dipole_only, compute_dipole) logging.info(loss_fn) - if all([head_config.compute_avg_num_neighbors for head_config in head_configs]): + if all(head_config.compute_avg_num_neighbors for head_config in head_configs): logging.info("Computing average number of neighbors") avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) if args.distributed: @@ -478,7 +478,7 @@ def run(args: argparse.Namespace) -> None: args.avg_num_neighbors = avg_num_neighbors else: assert any(head_config.avg_num_neighbors is not None for head_config in head_configs), "Average number of neighbors must be provided in the configuration" - args.avg_num_neighbors = max([head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None]) + args.avg_num_neighbors = max(head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None) logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") # Selecting outputs From 5e70bfef5b7c600e579648827291340e8f502f8f Mon Sep 17 00:00:00 2001 
From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Fri, 23 Aug 2024 20:28:34 +0100 Subject: [PATCH 079/101] revert config loss stress --- mace/modules/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/modules/loss.py b/mace/modules/loss.py index aebae2b4..2d6522d2 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -282,8 +282,8 @@ def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) + self.stress_weight * self.huber_loss( - configs_weight * configs_stress_weight * ref["stress"], - configs_weight * configs_stress_weight * pred["stress"], + ref["stress"], + pred["stress"], ) ) From 5017b84b2122ef327d94ab4193beb5e749fbd29f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 27 Aug 2024 18:16:47 +0100 Subject: [PATCH 080/101] add head selection in lammps --- mace/calculators/lammps_mace.py | 11 ++++++++++- mace/cli/create_lammps_model.py | 27 +++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py index 8ad7e984..ed9f0e2c 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -8,12 +8,20 @@ @compile_mode("script") class LAMMPS_MACE(torch.nn.Module): - def __init__(self, model): + def __init__(self, model, **kwargs): super().__init__() self.model = model self.register_buffer("atomic_numbers", model.atomic_numbers) self.register_buffer("r_max", model.r_max) self.register_buffer("num_interactions", model.num_interactions) + self.register_buffer( + "head", + torch.tensor( + self.model.heads.index(kwargs.get("head", self.model.heads[0])), + dtype=torch.long, + ), + ) + for param in self.model.parameters(): param.requires_grad = False @@ -27,6 +35,7 @@ def forward( compute_displacement = False if compute_virials: compute_displacement = True + data["head"] = self.head out = self.model( data, training=False, diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 4cae618f..858b708d 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -4,15 +4,34 @@ from e3nn.util import jit from mace.calculators import LAMMPS_MACE +import argparse -def main(): - assert len(sys.argv) == 2, f"Usage: {sys.argv[0]} model_path" +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to the model to be converted to LAMMPS", + ) + parser.add_argument( + "--head", + type=str, + nargs="?", + help="Head of the model to be converted to LAMMPS", + default="default", + ) + return parser.parse_args() + - model_path = sys.argv[1] # takes model name as command-line input +def main(): + args = parse_args() + model_path = args.model_path # takes model name as command-line input + head = args.head model = torch.load(model_path) model = model.double().to("cpu") - lammps_model = LAMMPS_MACE(model) + lammps_model = LAMMPS_MACE(model, head=head) lammps_model_compiled = jit.compile(lammps_model) lammps_model_compiled.save(model_path + "-lammps.pt") From 1191c93d3d5b949a32e807926cf894b85b1191ac Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 27 Aug 2024 18:20:14 +0100 Subject: [PATCH 081/101] fix imports --- mace/cli/create_lammps_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 858b708d..6cf2ee6f 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -1,10 +1,9 @@ -import sys +import argparse import torch from e3nn.util import jit from mace.calculators import LAMMPS_MACE -import argparse def parse_args(): From b0bbd5ba9e5c20addf25031d7c96ba138e70de88 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 27 Aug 2024 18:27:17 +0100 Subject: [PATCH 082/101] change default head to last for lammps --- mace/calculators/lammps_mace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py index ed9f0e2c..941312b0 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -17,7 +17,7 @@ def __init__(self, model, **kwargs): self.register_buffer( "head", torch.tensor( - self.model.heads.index(kwargs.get("head", self.model.heads[0])), + self.model.heads.index(kwargs.get("head", self.model.heads[-1])), dtype=torch.long, ), ) From 529b26142c4550526161ac735d841e8f16758180 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 27 Aug 2024 18:35:17 +0100 Subject: [PATCH 083/101] fix head selection --- mace/calculators/lammps_mace.py | 2 +- mace/cli/create_lammps_model.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py index 941312b0..182dc283 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -19,7 +19,7 @@ def __init__(self, model, **kwargs): torch.tensor( self.model.heads.index(kwargs.get("head", self.model.heads[-1])), dtype=torch.long, - ), + ).unsqueeze(0), ) for param in self.model.parameters(): diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 6cf2ee6f..c023f640 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -19,7 +19,7 @@ def parse_args(): type=str, nargs="?", help="Head of the model to be converted to LAMMPS", - default="default", + default=None, ) return parser.parse_args() @@ -30,7 +30,9 @@ def main(): head = args.head model = torch.load(model_path) model = model.double().to("cpu") - lammps_model = LAMMPS_MACE(model, head=head) + lammps_model = ( + LAMMPS_MACE(model, head=head) if head is not None else LAMMPS_MACE(model) + ) lammps_model_compiled = jit.compile(lammps_model) lammps_model_compiled.save(model_path + "-lammps.pt") From e25a1750cafd109139323e30b3e95125d8595caf Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 27 Aug 2024 20:05:04 +0100 Subject: [PATCH 084/101] fix the lammps model backward comp --- mace/calculators/lammps_mace.py | 2 ++ mace/cli/create_lammps_model.py | 38 ++++++++++++++++++++++++++++++--- mace/modules/blocks.py | 7 +++--- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py index 182dc283..408dfaa8 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -14,6 +14,8 @@ def __init__(self, model, **kwargs): self.register_buffer("atomic_numbers", model.atomic_numbers) self.register_buffer("r_max", model.r_max) self.register_buffer("num_interactions", model.num_interactions) + if not hasattr(model, "head"): + model.heads = [None] self.register_buffer( "head", torch.tensor( diff --git 
a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index c023f640..2d12d245 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -1,8 +1,6 @@ import argparse - import torch from e3nn.util import jit - from mace.calculators import LAMMPS_MACE @@ -24,12 +22,46 @@ def parse_args(): return parser.parse_args() +def select_head(model): + if hasattr(model, "heads"): + heads = model.heads + else: + heads = [None] + + if len(heads) == 1: + print(f"Only one head found in the model: {heads[0]}. Skipping selection.") + return heads[0] + + print("Available heads in the model:") + for i, head in enumerate(heads): + print(f"{i + 1}: {head}") + + # Ask the user to select a head + selected = input( + f"Select a head by number (default: {len(heads)}, press Enter to skip): " + ) + + if selected.isdigit() and 1 <= int(selected) <= len(heads): + return heads[int(selected) - 1] + elif selected == "": + print("No head selected. Proceeding without specifying a head.") + return None + else: + print(f"No valid selection made. Defaulting to the last head: {heads[-1]}") + return heads[-1] + + def main(): args = parse_args() model_path = args.model_path # takes model name as command-line input - head = args.head model = torch.load(model_path) model = model.double().to("cpu") + + if args.head is None: + head = select_head(model) + else: + head = args.head + lammps_model = ( LAMMPS_MACE(model, head=head) if head is not None else LAMMPS_MACE(model) ) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 48a9d22c..34539b0b 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -81,8 +81,9 @@ def forward( self, x: torch.Tensor, heads: Optional[torch.Tensor] = None ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] x = self.non_linearity(self.linear_1(x)) - if hasattr(self, "num_heads") and self.num_heads > 1 and heads is not None: - x = mask_head(x, heads, self.num_heads) + if hasattr(self, "num_heads"): + if self.num_heads > 1 and heads is not None: + x = mask_head(x, heads, self.num_heads) return self.linear_2(x) # [n_nodes, len(heads)] @@ -620,7 +621,7 @@ def _setup(self) -> None: input_dim = self.edge_feats_irreps.num_irreps self.conv_tp_weights = nn.FullyConnectedNet( [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, + torch.nn.functional.silu, # gate ) # Linear From f2b97af8775201232455d537a988e1b87c1b53e5 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 27 Aug 2024 21:06:19 +0100 Subject: [PATCH 085/101] remove model path as option arg --- mace/cli/create_lammps_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 2d12d245..13ea3230 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -7,9 +7,8 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--model_path", + "model_path", type=str, - required=True, help="Path to the model to be converted to LAMMPS", ) parser.add_argument( From ae07b3c983392e9c28e5c2f98cf91b2647708e73 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 27 Aug 2024 21:22:35 +0100 Subject: [PATCH 086/101] fix head selection --- mace/calculators/lammps_mace.py | 2 +- mace/cli/create_lammps_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py 
index 408dfaa8..4211c37f 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -14,7 +14,7 @@ def __init__(self, model, **kwargs): self.register_buffer("atomic_numbers", model.atomic_numbers) self.register_buffer("r_max", model.r_max) self.register_buffer("num_interactions", model.num_interactions) - if not hasattr(model, "head"): + if not hasattr(model, "heads"): model.heads = [None] self.register_buffer( "head", diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 13ea3230..3133bd6e 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -37,7 +37,7 @@ def select_head(model): # Ask the user to select a head selected = input( - f"Select a head by number (default: {len(heads)}, press Enter to skip): " + f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): " ) if selected.isdigit() and 1 <= int(selected) <= len(heads): From d63e90f23072a1542ab34b7480316dada218a996 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 27 Aug 2024 21:25:47 +0100 Subject: [PATCH 087/101] improve logging head lammps model --- mace/cli/create_lammps_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 3133bd6e..9341ca8b 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -60,6 +60,9 @@ def main(): head = select_head(model) else: head = args.head + print( + f"Selected head: {head} from command line in the list available heads: {model.heads}" + ) lammps_model = ( LAMMPS_MACE(model, head=head) if head is not None else LAMMPS_MACE(model) From 70a3e2965faa7c71d85d7fc8efdfc5b1f7efb434 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:49:35 +0100 Subject: [PATCH 088/101] revert run train --- mace/cli/run_train.py | 222 +++++++----------------------------------- 1 file changed, 36 insertions(+), 186 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index c9ad9ee6..c2dcfac0 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -64,7 +64,6 @@ def run(args: argparse.Namespace) -> None: """ This script runs the training/fine tuning for mace """ - args, input_log_messages = tools.check_args(args) tag = tools.get_tag(name=args.name, seed=args.seed) if args.device == "xpu": @@ -92,9 +91,6 @@ def run(args: argparse.Namespace) -> None: # Setup tools.set_seeds(args.seed) tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) - logging.info("===========VERIFYING SETTINGS===========") - for message, loglevel in input_log_messages: - logging.log(level=loglevel, msg=message) if args.distributed: torch.cuda.set_device(local_rank) @@ -105,7 +101,7 @@ def run(args: argparse.Namespace) -> None: logging.info(f"MACE version: {mace.__version__}") except AttributeError: logging.info("Cannot find MACE version, please install MACE via pip") - logging.debug(f"Configuration: {args}") + logging.info(f"Configuration: {args}") tools.set_default_dtype(args.default_dtype) device = tools.init_device(args.device) @@ -277,43 +273,6 @@ def run(args: argparse.Namespace) -> None: logging.info( f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" ) - logging.info("") - logging.info("===========LOADING INPUT DATA===========") - # Data preparation - if args.train_file.endswith(".xyz"): - if 
args.valid_file is not None: - assert args.valid_file.endswith( - ".xyz" - ), "valid_file if given must be same format as train_file" - config_type_weights = get_config_type_weights(args.config_type_weights) - collections, atomic_energies_dict = get_dataset_from_xyz( - work_dir=args.work_dir, - train_path=args.train_file, - valid_path=args.valid_file, - valid_fraction=args.valid_fraction, - config_type_weights=config_type_weights, - test_path=args.test_file, - seed=args.seed, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, - keep_isolated_atoms=args.keep_isolated_atoms, - ) - if len(collections.train) < args.batch_size: - logging.error( - f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" - ) - if len(collections.valid) < args.valid_batch_size: - logging.warning( - f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" - ) - args.valid_batch_size = len(collections.valid) - - else: - atomic_energies_dict = None # Atomic number table # yapf: disable @@ -343,7 +302,7 @@ def run(args: argparse.Namespace) -> None: for head_config in head_configs: all_atomic_numbers.update(head_config.atomic_numbers) z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) - logging.info(f"Atomic Numbers used: {z_table.zs}") + logging.info(z_table) # Atomic energies atomic_energies_dict = {} @@ -355,7 +314,8 @@ def run(args: argparse.Namespace) -> None: ) elif head_config.E0s.lower() == "foundation": assert args.foundation_model is not None - z_table_foundation = AtomicNumberTable( + logging.info("Using atomic energies from foundation model") + z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) atomic_energies_dict[head_config.head_name] = { @@ -364,10 +324,7 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } - logging.info( - f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" - ) - else: + else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict @@ -413,9 +370,7 @@ def run(args: argparse.Namespace) -> None: # [atomic_energies_dict[z] for z in z_table.zs] # ) atomic_energies = dict_to_array(atomic_energies_dict, heads) - logging.info( - f"Atomic Energies used (z: eV): {{{', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}" - ) + logging.info(f"Atomic energies: {atomic_energies.tolist()}") valid_sets = {head: [] for head in heads} train_sets = {head: [] for head in heads} @@ -505,9 +460,9 @@ def run(args: argparse.Namespace) -> None: num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), ) - logging.info("") - logging.info("===========MODEL DETAILS===========") + loss_fn = get_loss_fn(args, dipole_only, compute_dipole) + logging.info(loss_fn) if all(head_config.compute_avg_num_neighbors for head_config in head_configs): logging.info("Computing average number of neighbors") avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) @@ -524,22 +479,14 @@ def run(args: argparse.Namespace) -> None: else: assert any(head_config.avg_num_neighbors is not None for head_config in head_configs), "Average number of 
neighbors must be provided in the configuration" args.avg_num_neighbors = max(head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None) - if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100: - logging.warning( - f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}" - ) - else: - logging.info(f"Average number of neighbors: {args.avg_num_neighbors:.1f}") + logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") # Selecting outputs compute_virials = False if args.loss in ("stress", "virials", "huber", "universal"): compute_virials = True args.compute_stress = True - if "MAE" in args.error_table: - args.error_table = "PerAtomMAEstressvirials" - else: - args.error_table = "PerAtomRMSEstressvirials" + args.error_table = "PerAtomRMSEstressvirials" output_args = { "energy": compute_energy, @@ -548,10 +495,7 @@ def run(args: argparse.Namespace) -> None: "stress": args.compute_stress, "dipoles": compute_dipole, } - - logging.info( - f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" - ) + logging.info(f"Selected the following outputs: {output_args}") if args.scaling == "no_scaling": args.std = 1.0 @@ -562,15 +506,12 @@ def run(args: argparse.Namespace) -> None: ) # Build model if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Loading FOUNDATION model") + logging.info("Building model") model_config_foundation = extract_config_mace_model(model_foundation) model_config_foundation["atomic_energies"] = atomic_energies model_config_foundation["atomic_numbers"] = z_table.zs model_config_foundation["num_elements"] = len(z_table) args.max_L = model_config_foundation["hidden_irreps"].lmax - args.num_channels = list( - {irrep.mul for irrep in o3.Irreps(model_config_foundation["hidden_irreps"])} - )[0] if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) else: @@ -598,35 +539,23 @@ def run(args: argparse.Namespace) -> None: model_config_foundation["heads"] = heads logging.info("Model configuration extracted from foundation model") logging.info("Using universal loss function for fine-tuning") - logging.info( - f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({model_config_foundation['hidden_irreps']})" - ) - logging.info( - f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" - ) - logging.info( - f"Radial cutoff: {model_config_foundation['r_max']} Å (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} Å)" - ) - logging.info( - f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" - ) else: logging.info("Building model") - logging.info( - f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" - ) - logging.info( - f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" - ) - logging.info( - f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" - ) - logging.info( - f"Radial cutoff: {args.r_max} Å 
(total receptive field for each atom: {args.r_max * args.num_interactions} Å)" - ) - logging.info( - f"Distance transform for radial basis functions: {args.distance_transform}" - ) + if args.num_channels is not None and args.max_L is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + assert args.max_L >= 0, "max_L must be non-negative integer" + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + logging.info(f"Hidden irreps: {args.hidden_irreps}") + model_config = dict( r_max=args.r_max, num_bessel=args.num_radial_basis, @@ -749,20 +678,6 @@ def run(args: argparse.Namespace) -> None: ) model.to(device) - logging.debug(model) - logging.info(f"Total number of parameters: {tools.count_parameters(model)}") - logging.info("") - logging.info("===========OPTIMIZER INFORMATION===========") - logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") - logging.info(f"Batch size: {args.batch_size}") - if args.ema: - logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") - logging.info( - f"Number of gradient updates: {int(args.max_num_epochs*len(collections.train)/args.batch_size)}" - ) - logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") - logging.info(loss_fn) - # Optimizer decay_interactions = {} no_decay_interactions = {} @@ -831,50 +746,6 @@ def run(args: argparse.Namespace) -> None: swa: Optional[tools.SWAContainer] = None swas = [False] if args.swa: - assert dipole_only is False, "Stage Two for dipole fitting not implemented" - swas.append(True) - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - logging.info( - f"Stage Two will start after {args.start_swa} epochs with loss function:" - ) - if args.loss == "forces_only": - raise ValueError("Can not select Stage Two with forces only loss.") - if args.loss == "virials": - loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - virials_weight=args.swa_virials_weight, - ) - elif args.loss == "stress": - loss_fn_energy = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - stress_weight=args.swa_stress_weight, - ) - elif args.loss == "energy_forces_dipole": - loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( - args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - dipole_weight=args.swa_dipole_weight, - ) - else: - loss_fn_energy = modules.WeightedEnergyForcesLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - ) - logging.info(loss_fn_energy) - swa = tools.SWAContainer( - model=AveragedModel(model), - scheduler=SWALR( - optimizer=optimizer, - swa_lr=args.swa_lr, - anneal_epochs=1, - anneal_strategy="linear", - ), - start=args.start_swa, - loss_fn=loss_fn_energy, - ) swa, swas = get_swa(args, model, optimizer, swas, dipole_only) checkpoint_handler = tools.CheckpointHandler( @@ -908,6 +779,10 @@ def run(args: argparse.Namespace) -> None: for group in optimizer.param_groups: group["lr"] = args.lr + logging.info(model) + logging.info(f"Number of parameters: {tools.count_parameters(model)}") + logging.info(f"Optimizer: 
{optimizer}") + if args.wandb: logging.info("Using Weights and Biases for logging") import wandb @@ -962,8 +837,7 @@ def run(args: argparse.Namespace) -> None: train_sampler=train_sampler, rank=rank, ) - logging.info("") - logging.info("===========RESULTS===========") + logging.info("Computing metrics for training, validation, and test sets") all_data_loaders = {} @@ -1039,13 +913,6 @@ def run(args: argparse.Namespace) -> None: if stop_first_test: break - train_valid_data_loader = { - k: v for k, v in all_data_loaders.items() if k in ["train", "valid"] - } - test_data_loader = { - k: v for k, v in all_data_loaders.items() if k not in ["train", "valid"] - } - for swa_eval in swas: epoch = checkpoint_handler.load_latest( state=tools.CheckpointState(model, optimizer, lr_scheduler), @@ -1056,27 +923,13 @@ def run(args: argparse.Namespace) -> None: if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) model_to_evaluate = model if not args.distributed else distributed_model - if swa_eval: - logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") - else: - logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") + logging.info(f"Loaded model from epoch {epoch}") for param in model.parameters(): param.requires_grad = False - - table_train = create_error_table( - table_type=args.error_table, - all_data_loaders=train_valid_data_loader, - model=model_to_evaluate, - loss_fn=loss_fn, - output_args=output_args, - log_wandb=args.wandb, - device=device, - distributed=args.distributed, - ) - table_test = create_error_table( + table = create_error_table( table_type=args.error_table, - all_data_loaders=test_data_loader, + all_data_loaders=all_data_loaders, model=model_to_evaluate, loss_fn=loss_fn, output_args=output_args, @@ -1084,8 +937,7 @@ def run(args: argparse.Namespace) -> None: device=device, distributed=args.distributed, ) - logging.info("Error-table on TRAIN and VALID:\n" + str(table_train)) - logging.info("Error-table on TEST:\n" + str(table_test)) + logging.info("\n" + str(table)) if rank == 0: # Save entire model @@ -1105,9 +957,7 @@ def run(args: argparse.Namespace) -> None: } if swa_eval: torch.save( - model, Path(args.model_dir) / (args.name + "_stagetwo.model") - ) try: path_complied = Path(args.model_dir) / ( From e277d3279efce303c78218a0e2b34ede6bf46ec4 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 29 Aug 2024 12:28:10 +0100 Subject: [PATCH 089/101] merge logging changes with develop Co-Authored-By: vue1999 <93268063+vue1999@users.noreply.github.com> --- mace/cli/run_train.py | 146 +++++++++++++++++++++++++++--------- mace/tools/scripts_utils.py | 16 ++-- 2 files changed, 122 insertions(+), 40 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index c2dcfac0..11cfbeb0 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -65,6 +65,7 @@ def run(args: argparse.Namespace) -> None: This script runs the training/fine tuning for mace """ tag = tools.get_tag(name=args.name, seed=args.seed) + args, input_log_messages = tools.check_args(args) if args.device == "xpu": try: @@ -91,6 +92,9 @@ def run(args: argparse.Namespace) -> None: # Setup tools.set_seeds(args.seed) tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) + logging.info("===========VERIFYING SETTINGS===========") + for message, loglevel in input_log_messages: + logging.log(level=loglevel, msg=message) if args.distributed: torch.cuda.set_device(local_rank) @@ 
-101,7 +105,7 @@ def run(args: argparse.Namespace) -> None: logging.info(f"MACE version: {mace.__version__}") except AttributeError: logging.info("Cannot find MACE version, please install MACE via pip") - logging.info(f"Configuration: {args}") + logging.debug(f"Configuration: {args}") tools.set_default_dtype(args.default_dtype) device = tools.init_device(args.device) @@ -147,6 +151,8 @@ def run(args: argparse.Namespace) -> None: args.heads = ast.literal_eval(args.heads) else: args.heads = prepare_default_head(args) + + logging.info("===========LOADING INPUT DATA===========") heads = list(args.heads.keys()) logging.info(f"Using heads: {heads}") head_configs: List[HeadConfig] = [] @@ -211,6 +217,22 @@ def run(args: argparse.Namespace) -> None: ) head_configs.append(head_config) + if all(head_config.train_file.endswith(".xyz") for head_config in head_configs): + size_collections_train = sum( + [len(head_config.collections.train) for head_config in head_configs] + ) + size_collections_valid = sum( + [len(head_config.collections.valid) for head_config in head_configs] + ) + if size_collections_train < args.batch_size: + logging.error( + f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})" + ) + if size_collections_valid < args.valid_batch_size: + logging.warning( + f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})" + ) + if args.multiheads_finetuning: logging.info( "==================Using multiheads finetuning mode==================" @@ -302,7 +324,7 @@ def run(args: argparse.Namespace) -> None: for head_config in head_configs: all_atomic_numbers.update(head_config.atomic_numbers) z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) - logging.info(z_table) + logging.info(f"Atomic Numbers used: {z_table.zs}") # Atomic energies atomic_energies_dict = {} @@ -314,7 +336,6 @@ def run(args: argparse.Namespace) -> None: ) elif head_config.E0s.lower() == "foundation": assert args.foundation_model is not None - logging.info("Using atomic energies from foundation model") z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) @@ -324,6 +345,9 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } + logging.info( + f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" + ) else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: @@ -334,9 +358,6 @@ def run(args: argparse.Namespace) -> None: assert ( model_foundation is not None ), "Model foundation must be provided for multiheads finetuning" - logging.info( - "Using atomic energies from foundation model for multiheads finetuning" - ) z_table_foundation = AtomicNumberTable( [int(z) for z in model_foundation.atomic_numbers] ) @@ -346,6 +367,9 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } + logging.info( + f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" + ) if args.model == "AtomicDipolesMACE": atomic_energies = None @@ -370,7 +394,9 @@ def run(args: argparse.Namespace) -> None: # [atomic_energies_dict[z] for z in z_table.zs] # ) atomic_energies = dict_to_array(atomic_energies_dict, heads) - logging.info(f"Atomic energies: {atomic_energies.tolist()}") + logging.info( + f"Atomic Energies used (z: eV): {{{', 
'.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}" + ) valid_sets = {head: [] for head in heads} train_sets = {head: [] for head in heads} @@ -462,7 +488,7 @@ def run(args: argparse.Namespace) -> None: ) loss_fn = get_loss_fn(args, dipole_only, compute_dipole) - logging.info(loss_fn) + if all(head_config.compute_avg_num_neighbors for head_config in head_configs): logging.info("Computing average number of neighbors") avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) @@ -479,7 +505,13 @@ def run(args: argparse.Namespace) -> None: else: assert any(head_config.avg_num_neighbors is not None for head_config in head_configs), "Average number of neighbors must be provided in the configuration" args.avg_num_neighbors = max(head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None) - logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") + + if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100: + logging.warning( + f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}" + ) + else: + logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") # Selecting outputs compute_virials = False @@ -495,8 +527,10 @@ def run(args: argparse.Namespace) -> None: "stress": args.compute_stress, "dipoles": compute_dipole, } - logging.info(f"Selected the following outputs: {output_args}") - + logging.info( + f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" + ) + logging.info("===========MODEL DETAILS===========") if args.scaling == "no_scaling": args.std = 1.0 logging.info("No scaling selected") @@ -506,7 +540,7 @@ def run(args: argparse.Namespace) -> None: ) # Build model if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Building model") + logging.info("Loading FOUNDATION model") model_config_foundation = extract_config_mace_model(model_foundation) model_config_foundation["atomic_energies"] = atomic_energies model_config_foundation["atomic_numbers"] = z_table.zs @@ -539,17 +573,35 @@ def run(args: argparse.Namespace) -> None: model_config_foundation["heads"] = heads logging.info("Model configuration extracted from foundation model") logging.info("Using universal loss function for fine-tuning") + logging.info( + f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})" + ) + logging.info( + f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" + ) + logging.info( + f"Radial cutoff: {model_config_foundation['r_max']} Å (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} Å)" + ) + logging.info( + f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" + ) else: logging.info("Building model") - if args.num_channels is not None and args.max_L is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - assert args.max_L >= 0, "max_L must be non-negative integer" - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - + logging.info( + f"Message passing with {args.num_channels} channels and max_L={args.max_L} 
({args.hidden_irreps})" + ) + logging.info( + f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" + ) + logging.info( + f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" + ) + logging.info( + f"Radial cutoff: {args.r_max} Å (total receptive field for each atom: {args.r_max * args.num_interactions} Å)" + ) + logging.info( + f"Distance transform for radial basis functions: {args.distance_transform}" + ) assert ( len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" @@ -678,6 +730,20 @@ def run(args: argparse.Namespace) -> None: ) model.to(device) + logging.debug(model) + logging.info(f"Total number of parameters: {tools.count_parameters(model)}") + logging.info("") + logging.info("===========OPTIMIZER INFORMATION===========") + logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") + logging.info(f"Batch size: {args.batch_size}") + if args.ema: + logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") + logging.info( + f"Number of gradient updates: {int(args.max_num_epochs*len(collections.train)/args.batch_size)}" + ) + logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") + logging.info(loss_fn) + # Optimizer decay_interactions = {} no_decay_interactions = {} @@ -779,10 +845,6 @@ def run(args: argparse.Namespace) -> None: for group in optimizer.param_groups: group["lr"] = args.lr - logging.info(model) - logging.info(f"Number of parameters: {tools.count_parameters(model)}") - logging.info(f"Optimizer: {optimizer}") - if args.wandb: logging.info("Using Weights and Biases for logging") import wandb @@ -838,19 +900,21 @@ def run(args: argparse.Namespace) -> None: rank=rank, ) + logging.info("") + logging.info("===========RESULTS===========") logging.info("Computing metrics for training, validation, and test sets") - all_data_loaders = {} + train_valid_data_loader = {} for head_config in head_configs: data_loader_name = "train_" + head_config.head_name - all_data_loaders[data_loader_name] = head_config.train_loader + train_valid_data_loader[data_loader_name] = head_config.train_loader for head, valid_loader in valid_loaders.items(): data_load_name = "valid_" + head - all_data_loaders[data_load_name] = valid_loader + train_valid_data_loader[data_load_name] = valid_loader test_sets = {} stop_first_test = False - # check if all head have same test set + test_data_loader = {} if all( head_config.test_file == head_configs[0].test_file for head_config in head_configs @@ -909,7 +973,7 @@ def run(args: argparse.Namespace) -> None: num_workers=args.num_workers, pin_memory=args.pin_memory, ) - all_data_loaders[test_name] = test_loader + test_data_loader[test_name] = test_loader if stop_first_test: break @@ -923,13 +987,27 @@ def run(args: argparse.Namespace) -> None: if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) model_to_evaluate = model if not args.distributed else distributed_model - logging.info(f"Loaded model from epoch {epoch}") + if swa_eval: + logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") + else: + logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") for param in model.parameters(): param.requires_grad = False - table = create_error_table( + table_train_valid = 
create_error_table( + table_type=args.error_table, + all_data_loaders=train_valid_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) + table_test = create_error_table( table_type=args.error_table, - all_data_loaders=all_data_loaders, + all_data_loaders=test_data_loader, model=model_to_evaluate, loss_fn=loss_fn, output_args=output_args, @@ -937,7 +1015,7 @@ def run(args: argparse.Namespace) -> None: device=device, distributed=args.distributed, ) - logging.info("\n" + str(table)) + logging.info("Error-table on TEST:\n" + str(table_test)) if rank == 0: # Save entire model diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 9b89b30e..5a19439c 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -418,11 +418,9 @@ def get_swa( args.start_swa = max(1, args.max_num_epochs // 4 * 3) else: if args.start_swa > args.max_num_epochs: - logging.info( + logging.warning( f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" ) - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - logging.info(f"Setting start Stage Two to {args.start_swa}") if args.loss == "forces_only": raise ValueError("Can not select Stage Two with forces only loss.") if args.loss == "virials": @@ -431,12 +429,18 @@ def get_swa( forces_weight=args.swa_forces_weight, virials_weight=args.swa_virials_weight, ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, virials_weight: {args.swa_virials_weight} and learning rate : {args.swa_lr}" + ) elif args.loss == "stress": loss_fn_energy = modules.WeightedEnergyForcesStressLoss( energy_weight=args.swa_energy_weight, forces_weight=args.swa_forces_weight, stress_weight=args.swa_stress_weight, ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.stress_weight} and learning rate : {args.swa_lr}" + ) elif args.loss == "energy_forces_dipole": loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( args.swa_energy_weight, @@ -444,7 +448,7 @@ def get_swa( dipole_weight=args.swa_dipole_weight, ) logging.info( - f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" ) elif args.loss == "universal": loss_fn_energy = modules.UniversalLoss( @@ -454,7 +458,7 @@ def get_swa( huber_delta=args.huber_delta, ) logging.info( - f"Using stochastic weight averaging (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : 
{args.swa_stress_weight} and learning rate : {args.swa_lr}" ) else: loss_fn_energy = modules.WeightedEnergyForcesLoss( @@ -462,7 +466,7 @@ def get_swa( forces_weight=args.swa_forces_weight, ) logging.info( - f"Stage Two (after {args.start_swa} epochs) with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" ) swa = SWAContainer( model=AveragedModel(model), From 85c5d9ad5dcdd6ab85b37db57d05dd00d2b123b7 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 29 Aug 2024 14:26:54 +0100 Subject: [PATCH 090/101] fix linting --- mace/cli/create_lammps_model.py | 9 ++- mace/cli/run_train.py | 125 ++++++-------------------------- mace/modules/loss.py | 2 - mace/tools/__init__.py | 2 - mace/tools/multihead_tools.py | 1 + mace/tools/scripts_utils.py | 122 +++++++++++++++++++++++++++++++ mace/tools/utils.py | 20 ----- 7 files changed, 149 insertions(+), 132 deletions(-) diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 9341ca8b..3f647906 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -1,6 +1,8 @@ import argparse + import torch from e3nn.util import jit + from mace.calculators import LAMMPS_MACE @@ -42,12 +44,11 @@ def select_head(model): if selected.isdigit() and 1 <= int(selected) <= len(heads): return heads[int(selected) - 1] - elif selected == "": + if selected == "": print("No head selected. Proceeding without specifying a head.") return None - else: - print(f"No valid selection made. Defaulting to the last head: {heads[-1]}") - return heads[-1] + print(f"No valid selection made. 
Defaulting to the last head: {heads[-1]}") + return heads[-1] def main(): diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 11cfbeb0..17b39667 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -41,12 +41,16 @@ dict_to_array, extract_config_mace_model, get_atomic_energies, + get_avg_num_neighbors, get_config_type_weights, get_dataset_from_xyz, get_files_with_suffix, get_loss_fn, + get_optimizer, + get_params_options, get_swa, print_git_commit, + setup_wandb, ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable @@ -194,6 +198,7 @@ def run(args: argparse.Namespace) -> None: head_config.config_type_weights ) collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, train_path=head_config.train_file, valid_path=head_config.valid_file, valid_fraction=head_config.valid_fraction, @@ -219,10 +224,10 @@ def run(args: argparse.Namespace) -> None: if all(head_config.train_file.endswith(".xyz") for head_config in head_configs): size_collections_train = sum( - [len(head_config.collections.train) for head_config in head_configs] + len(head_config.collections.train) for head_config in head_configs ) size_collections_valid = sum( - [len(head_config.collections.valid) for head_config in head_configs] + len(head_config.collections.valid) for head_config in head_configs ) if size_collections_train < args.batch_size: logging.error( @@ -257,6 +262,7 @@ def run(args: argparse.Namespace) -> None: else: heads = list(dict.fromkeys(["pt_head"] + heads)) collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, train_path=args.pt_train_file, valid_path=args.pt_valid_file, valid_fraction=args.valid_fraction, @@ -367,8 +373,9 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } + atomic_energies_dict_pt = atomic_energies_dict["pt_head"] logging.info( - f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" + f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict_pt[z]}' for z in z_table_foundation.zs])}" ) if args.model == "AtomicDipolesMACE": @@ -394,9 +401,14 @@ def run(args: argparse.Namespace) -> None: # [atomic_energies_dict[z] for z in z_table.zs] # ) atomic_energies = dict_to_array(atomic_energies_dict, heads) - logging.info( - f"Atomic Energies used (z: eV): {{{', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}" + result = "\n".join( + [ + f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}" + for head_config in head_configs + ] ) + logging.info(result) valid_sets = {head: [] for head in heads} train_sets = {head: [] for head in heads} @@ -488,30 +500,7 @@ def run(args: argparse.Namespace) -> None: ) loss_fn = get_loss_fn(args, dipole_only, compute_dipole) - - if all(head_config.compute_avg_num_neighbors for head_config in head_configs): - logging.info("Computing average number of neighbors") - avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) - if args.distributed: - num_graphs = torch.tensor(len(train_loader.dataset)).to(device) - num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) - torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce( - num_neighbors, op=torch.distributed.ReduceOp.SUM 
- ) - args.avg_num_neighbors = (num_neighbors / num_graphs).item() - else: - args.avg_num_neighbors = avg_num_neighbors - else: - assert any(head_config.avg_num_neighbors is not None for head_config in head_configs), "Average number of neighbors must be provided in the configuration" - args.avg_num_neighbors = max(head_config.avg_num_neighbors for head_config in head_configs if head_config.avg_num_neighbors is not None) - - if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100: - logging.warning( - f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}" - ) - else: - logging.info(f"Average number of neighbors: {args.avg_num_neighbors}") + args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) # Selecting outputs compute_virials = False @@ -745,61 +734,9 @@ def run(args: argparse.Namespace) -> None: logging.info(loss_fn) # Optimizer - decay_interactions = {} - no_decay_interactions = {} - for name, param in model.interactions.named_parameters(): - if "linear.weight" in name or "skip_tp_full.weight" in name: - decay_interactions[name] = param - else: - no_decay_interactions[name] = param - - param_options = dict( - params=[ - { - "name": "embedding", - "params": model.node_embedding.parameters(), - "weight_decay": 0.0, - }, - { - "name": "interactions_decay", - "params": list(decay_interactions.values()), - "weight_decay": args.weight_decay, - }, - { - "name": "interactions_no_decay", - "params": list(no_decay_interactions.values()), - "weight_decay": 0.0, - }, - { - "name": "products", - "params": model.products.parameters(), - "weight_decay": args.weight_decay, - }, - { - "name": "readouts", - "params": model.readouts.parameters(), - "weight_decay": 0.0, - }, - ], - lr=args.lr, - amsgrad=args.amsgrad, - betas=(args.beta, 0.999), - ) - + param_options = get_params_options(args, model) optimizer: torch.optim.Optimizer - if args.optimizer == "adamw": - optimizer = torch.optim.AdamW(**param_options) - elif args.optimizer == "schedulefree": - try: - from schedulefree import adamw_schedulefree - except ImportError as exc: - raise ImportError( - "`schedulefree` is not installed. 
Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" - ) from exc - _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} - optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) - else: - optimizer = torch.optim.Adam(**param_options) + optimizer = get_optimizer(args, param_options) if args.device == "xpu": logging.info("Optimzing model and optimzier for XPU") model, optimizer = ipex.optimize(model, optimizer=optimizer) @@ -846,27 +783,7 @@ def run(args: argparse.Namespace) -> None: group["lr"] = args.lr if args.wandb: - logging.info("Using Weights and Biases for logging") - import wandb - - wandb_config = {} - args_dict = vars(args) - - for key, value in args_dict.items(): - if isinstance(value, np.ndarray): - args_dict[key] = value.tolist() - - args_dict_json = json.dumps(args_dict) - for key in args.wandb_log_hypers: - wandb_config[key] = args_dict[key] - tools.init_wandb( - project=args.wandb_project, - entity=args.wandb_entity, - name=args.wandb_name, - config=wandb_config, - directory=args.wandb_dir, - ) - wandb.run.summary["params"] = args_dict_json + setup_wandb(args) if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) diff --git a/mace/modules/loss.py b/mace/modules/loss.py index 2d6522d2..91462d2c 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -273,8 +273,6 @@ def __init__( def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: num_atoms = ref.ptr[1:] - ref.ptr[:-1] - configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ] - configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] return ( self.energy_weight * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index ce4172fc..8ad80243 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -28,7 +28,6 @@ compute_rel_rmse, compute_rmse, get_atomic_number_table_from_zs, - get_optimizer, get_tag, setup_logger, ) @@ -46,7 +45,6 @@ "setup_logger", "get_tag", "count_parameters", - "get_optimizer", "MetricsLogger", "get_atomic_number_table_from_zs", "train", diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 1e190da2..8892fa6d 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -162,6 +162,7 @@ def assemble_mp_data( } select_samples(dict_to_namespace(args_samples)) collections_mp, _ = get_dataset_from_xyz( + work_dir=args.work_dir, train_path=f"mp_finetuning-{tag}.xyz", valid_path=None, valid_fraction=args.valid_fraction, diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 5a19439c..025b3453 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -20,6 +20,7 @@ from torch.optim.swa_utils import SWALR, AveragedModel from mace import data, modules +from mace import tools from mace.tools import evaluate from mace.tools.train import SWAContainer @@ -349,6 +350,38 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: return atomic_energies_dict +def get_avg_num_neighbors(head_configs, args, train_loader, device): + if all(head_config.compute_avg_num_neighbors for head_config in head_configs): + logging.info("Computing average number of neighbors") + avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) + if args.distributed: + num_graphs = torch.tensor(len(train_loader.dataset)).to(device) + num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) + 
torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce( + num_neighbors, op=torch.distributed.ReduceOp.SUM + ) + avg_num_neighbors_out = (num_neighbors / num_graphs).item() + else: + avg_num_neighbors_out = avg_num_neighbors + else: + assert any( + head_config.avg_num_neighbors is not None for head_config in head_configs + ), "Average number of neighbors must be provided in the configuration" + avg_num_neighbors_out = max( + head_config.avg_num_neighbors + for head_config in head_configs + if head_config.avg_num_neighbors is not None + ) + if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100: + logging.warning( + f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}" + ) + else: + logging.info(f"Average number of neighbors: {avg_num_neighbors_out}") + return avg_num_neighbors_out + + def get_loss_fn( args: argparse.Namespace, dipole_only: bool, @@ -482,6 +515,95 @@ def get_swa( return swa, swas +def get_params_options( + args: argparse.Namespace, model: torch.nn.Module +) -> Dict[str, Any]: + decay_interactions = {} + no_decay_interactions = {} + for name, param in model.interactions.named_parameters(): + if "linear.weight" in name or "skip_tp_full.weight" in name: + decay_interactions[name] = param + else: + no_decay_interactions[name] = param + + param_options = dict( + params=[ + { + "name": "embedding", + "params": model.node_embedding.parameters(), + "weight_decay": 0.0, + }, + { + "name": "interactions_decay", + "params": list(decay_interactions.values()), + "weight_decay": args.weight_decay, + }, + { + "name": "interactions_no_decay", + "params": list(no_decay_interactions.values()), + "weight_decay": 0.0, + }, + { + "name": "products", + "params": model.products.parameters(), + "weight_decay": args.weight_decay, + }, + { + "name": "readouts", + "params": model.readouts.parameters(), + "weight_decay": 0.0, + }, + ], + lr=args.lr, + amsgrad=args.amsgrad, + betas=(args.beta, 0.999), + ) + return param_options + + +def get_optimizer( + args: argparse.Namespace, param_options: Dict[str, Any] +) -> torch.optim.Optimizer: + if args.optimizer == "adamw": + optimizer = torch.optim.AdamW(**param_options) + elif args.optimizer == "schedulefree": + try: + from schedulefree import adamw_schedulefree + except ImportError as exc: + raise ImportError( + "`schedulefree` is not installed. 
Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" + ) from exc + _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} + optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) + else: + optimizer = torch.optim.Adam(**param_options) + return optimizer + + +def setup_wandb(args: argparse.Namespace): + logging.info("Using Weights and Biases for logging") + import wandb + + wandb_config = {} + args_dict = vars(args) + + for key, value in args_dict.items(): + if isinstance(value, np.ndarray): + args_dict[key] = value.tolist() + + args_dict_json = json.dumps(args_dict) + for key in args.wandb_log_hypers: + wandb_config[key] = args_dict[key] + tools.init_wandb( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.wandb_name, + config=wandb_config, + directory=args.wandb_dir, + ) + wandb.run.summary["params"] = args_dict_json + + def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: return [ os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) diff --git a/mace/tools/utils.py b/mace/tools/utils.py index 762d9880..0d7aa41e 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -121,26 +121,6 @@ def atomic_numbers_to_indices( return to_index_fn(atomic_numbers) -def get_optimizer( - name: str, - amsgrad: bool, - learning_rate: float, - weight_decay: float, - parameters: Iterable[torch.Tensor], -) -> torch.optim.Optimizer: - if name == "adam": - return torch.optim.Adam( - parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay - ) - - if name == "adamw": - return torch.optim.AdamW( - parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay - ) - - raise RuntimeError(f"Unknown optimizer '{name}'") - - class UniversalEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, np.integer): From 6288b026e4335208bf35ca72707393e2936ad577 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 29 Aug 2024 14:40:28 +0100 Subject: [PATCH 091/101] fix model_config lit --- mace/cli/run_train.py | 12 +++--------- mace/tools/scripts_utils.py | 3 +-- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 17b39667..1db77a06 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -350,10 +350,7 @@ def run(args: argparse.Namespace) -> None: z_table_foundation.z_to_index(z) ].item() for z in z_table.zs - } - logging.info( - f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" - ) + } else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: @@ -373,11 +370,7 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } - atomic_energies_dict_pt = atomic_energies_dict["pt_head"] - logging.info( - f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict_pt[z]}' for z in z_table_foundation.zs])}" - ) - + if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True @@ -560,6 +553,7 @@ def run(args: argparse.Namespace) -> None: args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] args.model = "FoundationMACE" model_config_foundation["heads"] = heads + model_config = model_config_foundation logging.info("Model configuration extracted from foundation model") logging.info("Using universal loss function for 
fine-tuning") logging.info( diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 025b3453..f44390a6 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -19,8 +19,7 @@ from prettytable import PrettyTable from torch.optim.swa_utils import SWALR, AveragedModel -from mace import data, modules -from mace import tools +from mace import data, modules, tools from mace.tools import evaluate from mace.tools.train import SWAContainer From 5d1315f23ccc2aba329f662b4b562bf2879291db Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Thu, 29 Aug 2024 14:48:56 +0100 Subject: [PATCH 092/101] fix whites space formatting --- mace/cli/run_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1db77a06..59e76be3 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -350,7 +350,7 @@ def run(args: argparse.Namespace) -> None: z_table_foundation.z_to_index(z) ].item() for z in z_table.zs - } + } else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: @@ -370,7 +370,7 @@ def run(args: argparse.Namespace) -> None: ].item() for z in z_table.zs } - + if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True From 158d7b58667b034544ef3f45c7974041c5d16877 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 31 Aug 2024 16:28:42 +0100 Subject: [PATCH 093/101] fixed all tests --- mace/cli/run_train.py | 45 +++++++++++++++++---------------- mace/tools/arg_parser.py | 16 ++++++++++++ mace/tools/train.py | 2 +- tests/test_run_train.py | 54 +++++++++++++++++++++++----------------- 4 files changed, 72 insertions(+), 45 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 59e76be3..236072fa 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -246,7 +246,11 @@ def run(args: argparse.Namespace) -> None: if ( args.foundation_model in ["small", "medium", "large"] or "mp" in args.foundation_model + or args.pt_train_file is None ): + logging.info( + "Using foundation model for multiheads finetuning with Materials Project data" + ) heads = list(dict.fromkeys(["pt_head"] + heads)) head_config_pt = HeadConfig( head_name="pt_head", @@ -260,6 +264,9 @@ def run(args: argparse.Namespace) -> None: head_config_pt.train_file = f"mp_finetuning-{tag}.xyz" head_configs.append(head_config_pt) else: + logging.info( + f"Using foundation model for multiheads finetuning with {args.pt_train_file}" + ) heads = list(dict.fromkeys(["pt_head"] + heads)) collections, atomic_energies_dict = get_dataset_from_xyz( work_dir=args.work_dir, @@ -394,14 +401,9 @@ def run(args: argparse.Namespace) -> None: # [atomic_energies_dict[z] for z in z_table.zs] # ) atomic_energies = dict_to_array(atomic_energies_dict, heads) - result = "\n".join( - [ - f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + - "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}" - for head_config in head_configs - ] - ) - logging.info(result) + for head_config in head_configs: + logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}") + valid_sets = {head: [] for head in heads} train_sets = {head: [] for head in heads} @@ -563,7 
+565,7 @@ def run(args: argparse.Namespace) -> None: f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" ) logging.info( - f"Radial cutoff: {model_config_foundation['r_max']} Å (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} Å)" + f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)" ) logging.info( f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" @@ -580,7 +582,7 @@ def run(args: argparse.Namespace) -> None: f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" ) logging.info( - f"Radial cutoff: {args.r_max} Å (total receptive field for each atom: {args.r_max * args.num_interactions} Å)" + f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)" ) logging.info( f"Distance transform for radial basis functions: {args.distance_transform}" @@ -916,17 +918,18 @@ def run(args: argparse.Namespace) -> None: distributed=args.distributed, ) logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) - table_test = create_error_table( - table_type=args.error_table, - all_data_loaders=test_data_loader, - model=model_to_evaluate, - loss_fn=loss_fn, - output_args=output_args, - log_wandb=args.wandb, - device=device, - distributed=args.distributed, - ) - logging.info("Error-table on TEST:\n" + str(table_test)) + if not test_data_loader: + table_test = create_error_table( + table_type=args.error_table, + all_data_loaders=test_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + logging.info("Error-table on TEST:\n" + str(table_test)) if rank == 0: # Save entire model diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index fd438990..046f04d6 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -331,6 +331,8 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=None, required=False, ) + + # Fine-tuning parser.add_argument( "--foundation_filter_elements", help="Filter element during fine-tuning", @@ -369,12 +371,26 @@ def build_default_arg_parser() -> argparse.ArgumentParser: choices=["fps", "random"], default="random", ) + parser.add_argument( + "--pt_train_file", + help="Training set file for the pretrained head", + type=str, + default=None, + ) + parser.add_argument( + "--pt_valid_file", + help="Validation set file for the pretrained head", + type=str, + default=None, + ) parser.add_argument( "--keep_isolated_atoms", help="Keep isolated atoms in the dataset, useful for transfer learning", type=str2bool, default=False, ) + + # Keys parser.add_argument( "--energy_key", help="Key of reference energies in training xyz", diff --git a/mace/tools/train.py b/mace/tools/train.py index 0dea4c32..0e347c3e 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -70,7 +70,7 @@ def valid_err_log( error_f = eval_metrics["rmse_f"] * 1e3 error_stress = eval_metrics["rmse_stress"] * 1e3 logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} 
meV / A, RMSE_stress_per_atom={error_stress:8.1f} meV / A^3", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress={error_stress:8.1f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" diff --git a/tests/test_run_train.py b/tests/test_run_train.py index e3d42d47..9ae21e83 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -468,7 +468,15 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): fitting_configs_dft = [] fitting_configs_mp2 = [] for i, c in enumerate(fitting_configs): - if i % 2 == 0: + if i == 0 or i == 1: + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + fitting_configs_dft.append(c) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + elif i % 2 == 0: c.info["head"] = "DFT" fitting_configs_dft.append(c) else: @@ -538,27 +546,27 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): print("Es", Es) # from a run on 20/08/2024 on commit ref_Es = [ - 1.4186015129089355, - 0.6012811660766602, - 1.4759466648101807, - 1.1662801504135132, - 1.117658019065857, - 1.4062559604644775, - 1.4638032913208008, - 0.9065879583358765, - 1.3814517259597778, - 1.2735612392425537, - 1.2472984790802002, - 1.1374807357788086, - 1.4028346538543701, - 1.0139431953430176, - 1.3830922842025757, - 1.0170294046401978, - 1.6741619110107422, - 1.2575324773788452, - 1.2426478862762451, - 1.0206304788589478, - 1.2309682369232178, - 1.135024070739746, + 1.7100412845611572, + 0.4079246520996094, + 0.9405305981636047, + 0.6698582768440247, + 0.8035168647766113, + 0.9693648219108582, + 0.9276483654975891, + 0.6999133229255676, + 0.8872858285903931, + 0.7888250946998596, + 0.7887940406799316, + 0.617201566696167, + 0.919223427772522, + 0.5526964664459229, + 0.8903999328613281, + 0.5131919980049133, + 1.1887052059173584, + 0.7943968176841736, + 0.7528715133666992, + 0.6052684187889099, + 0.743116557598114, + 0.6587133407592773, ] assert np.allclose(Es, ref_Es, atol=1e-1) From cc836913a1752717e25a75d3bb0680627b510756 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 31 Aug 2024 16:33:24 +0100 Subject: [PATCH 094/101] fix linter --- mace/cli/run_train.py | 2 +- tests/test_run_train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 236072fa..62aac95b 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -918,7 +918,7 @@ def run(args: argparse.Namespace) -> None: distributed=args.distributed, ) logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) - if not test_data_loader: + if not test_data_loader: table_test = create_error_table( table_type=args.error_table, all_data_loaders=test_data_loader, diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 9ae21e83..0cdb7cfa 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -468,7 +468,7 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): fitting_configs_dft = [] fitting_configs_mp2 = [] for i, c in enumerate(fitting_configs): - if i == 0 or i == 1: + if i in (0, 1): c_dft = c.copy() c_dft.info["head"] = "DFT" fitting_configs_dft.append(c_dft) From f505e1ee4f2962ea77472722c35a9013449c49ad Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Sat, 31 Aug 2024 17:18:23 +0100 Subject: [PATCH 095/101] 
float change in multihead test --- mace/cli/run_train.py | 2 +- mace/tools/multihead_tools.py | 4 ++- tests/test_run_train.py | 47 ++++++++++++++++++----------------- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 62aac95b..3bd9a6d9 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -918,7 +918,7 @@ def run(args: argparse.Namespace) -> None: distributed=args.distributed, ) logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) - if not test_data_loader: + if test_data_loader: table_test = create_error_table( table_type=args.error_table, all_data_loaders=test_data_loader, diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 8892fa6d..ffde107f 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -180,4 +180,6 @@ def assemble_mp_data( ) return collections_mp except Exception as exc: - raise RuntimeError("Model download failed and no local model found") from exc + raise RuntimeError( + "Model or descriptors download failed and no local model found" + ) from exc diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 0cdb7cfa..d13d0987 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -506,10 +506,11 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): mace_params["foundation_model"] = "small" mace_params["hidden_irreps"] = "128x0e" mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float32" + mace_params["default_dtype"] = "float64" mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 mace_params["num_samples_pt"] = 50 mace_params["subselect_pt"] = "random" # make sure run_train.py is using the mace that is currently being tested @@ -546,27 +547,27 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): print("Es", Es) # from a run on 20/08/2024 on commit ref_Es = [ - 1.7100412845611572, - 0.4079246520996094, - 0.9405305981636047, - 0.6698582768440247, - 0.8035168647766113, - 0.9693648219108582, - 0.9276483654975891, - 0.6999133229255676, - 0.8872858285903931, - 0.7888250946998596, - 0.7887940406799316, - 0.617201566696167, - 0.919223427772522, - 0.5526964664459229, - 0.8903999328613281, - 0.5131919980049133, - 1.1887052059173584, - 0.7943968176841736, - 0.7528715133666992, - 0.6052684187889099, - 0.743116557598114, - 0.6587133407592773, + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, ] assert np.allclose(Es, ref_Es, atol=1e-1) From b6f6712fd36c5fda27d6eeed1d9c3b82d297b7de Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 2 Sep 2024 10:43:25 +0100 Subject: [PATCH 096/101] fix valid head checkpoint --- mace/tools/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index 0e347c3e..1ab86f82 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -253,7 +253,6 @@ def train( 
output_args=output_args, device=device, ) - valid_loss += valid_loss_head if rank == 0: valid_err_log( valid_loss_head, @@ -272,7 +271,9 @@ def train( ], "valid_rmse_f": eval_metrics["rmse_f"], } - + valid_loss = ( + valid_loss_head # consider only the last head for the checkpoint + ) if log_wandb: wandb.log(wandb_log_dict) if rank == 0: From d0a36a18232028d752496467188dfbb5025b1e8f Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 2 Sep 2024 14:56:41 +0100 Subject: [PATCH 097/101] simplify the run_train --- mace/cli/run_train.py | 244 ++----------------------------- mace/tools/model_script_utils.py | 229 +++++++++++++++++++++++++++++ mace/tools/utils.py | 1 - 3 files changed, 244 insertions(+), 230 deletions(-) create mode 100644 mace/tools/model_script_utils.py diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 3bd9a6d9..936e6ce2 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -14,20 +14,18 @@ from pathlib import Path from typing import List, Optional -import numpy as np import torch.distributed import torch.nn.functional -from e3nn import o3 from e3nn.util import jit from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import ConcatDataset from torch_ema import ExponentialMovingAverage import mace -from mace import data, modules, tools +from mace import data, tools from mace.calculators.foundations_models import mace_mp, mace_off from mace.tools import torch_geometric -from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.model_script_utils import configure_model from mace.tools.multihead_tools import ( HeadConfig, assemble_mp_data, @@ -114,6 +112,7 @@ def run(args: argparse.Namespace) -> None: tools.set_default_dtype(args.default_dtype) device = tools.init_device(args.device) commit = print_git_commit() + model_foundation: Optional[torch.nn.Module] = None if args.foundation_model is not None: if args.multiheads_finetuning: assert ( @@ -305,6 +304,7 @@ def run(args: argparse.Namespace) -> None: compute_avg_num_neighbors=False, ) head_config_pt.collections = collections + head_configs.append(head_config_pt) logging.info( f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" ) @@ -381,22 +381,22 @@ def run(args: argparse.Namespace) -> None: if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True - compute_dipole = True - compute_energy = False + args.compute_dipole = True + args.compute_energy = False args.compute_forces = False - compute_virials = False + args.compute_virials = False args.compute_stress = False else: dipole_only = False if args.model == "EnergyDipolesMACE": - compute_dipole = True - compute_energy = True + args.compute_dipole = True + args.compute_energy = True args.compute_forces = True - compute_virials = False + args.compute_virials = False args.compute_stress = False else: - compute_energy = True - compute_dipole = False + args.compute_energy = True + args.compute_dipole = False # atomic_energies: np.ndarray = np.array( # [atomic_energies_dict[z] for z in z_table.zs] # ) @@ -494,225 +494,11 @@ def run(args: argparse.Namespace) -> None: generator=torch.Generator().manual_seed(args.seed), ) - loss_fn = get_loss_fn(args, dipole_only, compute_dipole) + loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) - # Selecting outputs - compute_virials = False - if args.loss in 
("stress", "virials", "huber", "universal"): - compute_virials = True - args.compute_stress = True - args.error_table = "PerAtomRMSEstressvirials" - - output_args = { - "energy": compute_energy, - "forces": args.compute_forces, - "virials": compute_virials, - "stress": args.compute_stress, - "dipoles": compute_dipole, - } - logging.info( - f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" - ) - logging.info("===========MODEL DETAILS===========") - if args.scaling == "no_scaling": - args.std = 1.0 - logging.info("No scaling selected") - elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": - args.mean, args.std = modules.scaling_classes[args.scaling]( - train_loader, atomic_energies - ) - # Build model - if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Loading FOUNDATION model") - model_config_foundation = extract_config_mace_model(model_foundation) - model_config_foundation["atomic_energies"] = atomic_energies - model_config_foundation["atomic_numbers"] = z_table.zs - model_config_foundation["num_elements"] = len(z_table) - args.max_L = model_config_foundation["hidden_irreps"].lmax - if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": - model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) - else: - if isinstance(args.mean, np.ndarray): - if args.mean.size == 1: - model_config_foundation["atomic_inter_shift"] = args.mean.item() - elif args.mean.size == len(heads): - model_config_foundation["atomic_inter_shift"] = args.mean.tolist() - else: - logging.info( - "Mean not in correct format, using default value of 0.0" - ) - model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) - elif isinstance(args.mean, list) and len(args.mean) == len(heads): - model_config_foundation["atomic_inter_shift"] = args.mean - elif isinstance(args.mean, float): - model_config_foundation["atomic_inter_shift"] = [args.mean] * len(heads) - else: - logging.info("Mean not in correct format, using default value of 0.0") - model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) - - model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) - args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] - args.model = "FoundationMACE" - model_config_foundation["heads"] = heads - model_config = model_config_foundation - logging.info("Model configuration extracted from foundation model") - logging.info("Using universal loss function for fine-tuning") - logging.info( - f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})" - ) - logging.info( - f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" - ) - logging.info( - f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)" - ) - logging.info( - f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" - ) - else: - logging.info("Building model") - logging.info( - f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" - ) - logging.info( - f"{args.num_interactions} layers, each with correlation order: 
{args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" - ) - logging.info( - f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" - ) - logging.info( - f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)" - ) - logging.info( - f"Distance transform for radial basis functions: {args.distance_transform}" - ) - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - - logging.info(f"Hidden irreps: {args.hidden_irreps}") - - model_config = dict( - r_max=args.r_max, - num_bessel=args.num_radial_basis, - num_polynomial_cutoff=args.num_cutoff_basis, - max_ell=args.max_ell, - interaction_cls=modules.interaction_classes[args.interaction], - num_interactions=args.num_interactions, - num_elements=len(z_table), - hidden_irreps=o3.Irreps(args.hidden_irreps), - atomic_energies=atomic_energies, - avg_num_neighbors=args.avg_num_neighbors, - atomic_numbers=z_table.zs, - ) - - model: torch.nn.Module - - if args.model == "MACE": - model = modules.ScaleShiftMACE( - **model_config, - pair_repulsion=args.pair_repulsion, - distance_transform=args.distance_transform, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=[0.0] * len(heads), - radial_MLP=ast.literal_eval(args.radial_MLP), - radial_type=args.radial_type, - heads=heads, - ) - elif args.model == "ScaleShiftMACE": - model = modules.ScaleShiftMACE( - **model_config, - pair_repulsion=args.pair_repulsion, - distance_transform=args.distance_transform, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=args.mean, - radial_MLP=ast.literal_eval(args.radial_MLP), - radial_type=args.radial_type, - heads=heads, - ) - elif args.model == "FoundationMACE": - model = modules.ScaleShiftMACE(**model_config_foundation) - elif args.model == "ScaleShiftBOTNet": - model = modules.ScaleShiftBOTNet( - **model_config, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=args.mean, - ) - elif args.model == "BOTNet": - model = modules.BOTNet( - **model_config, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - ) - elif args.model == "AtomicDipolesMACE": - # std_df = modules.scaling_classes["rms_dipoles_scaling"](train_loader) - assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model" - assert ( - args.error_table == "DipoleRMSE" - ), "Use error_table DipoleRMSE with AtomicDipolesMACE model" - model = modules.AtomicDipolesMACE( - **model_config, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - # dipole_scale=1, - # dipole_shift=0, - ) - elif args.model == "EnergyDipolesMACE": - # 
std_df = modules.scaling_classes["rms_dipoles_scaling"](train_loader) - assert ( - args.loss == "energy_forces_dipole" - ), "Use energy_forces_dipole loss with EnergyDipolesMACE model" - assert ( - args.error_table == "EnergyDipoleRMSE" - ), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model" - model = modules.EnergyDipolesMACE( - **model_config, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - ) - else: - raise RuntimeError(f"Unknown model: '{args.model}'") - - if args.foundation_model is not None: - if args.foundation_filter_elements: - model = load_foundations_elements( - model, - model_foundation, - z_table, - load_readout=True, - max_L=args.max_L, - ) - else: - model = load_foundations_elements( - model, - model_foundation, - z_table, - load_readout=False, - max_L=args.max_L, - ) + # Model + model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table) model.to(device) logging.debug(model) diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py new file mode 100644 index 00000000..8e8c2877 --- /dev/null +++ b/mace/tools/model_script_utils.py @@ -0,0 +1,229 @@ +import ast +import logging + +import numpy as np +from e3nn import o3 + +from mace import modules +from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.scripts_utils import extract_config_mace_model + + +def configure_model( + args, train_loader, atomic_energies, model_foundation=None, heads=None, z_table=None +): + # Selecting outputs + compute_virials = args.loss in ("stress", "virials", "huber", "universal") + if compute_virials: + args.compute_stress = True + args.error_table = "PerAtomRMSEstressvirials" + + output_args = { + "energy": args.compute_energy, + "forces": args.compute_forces, + "virials": compute_virials, + "stress": args.compute_stress, + "dipoles": args.compute_dipole, + } + logging.info( + f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" + ) + logging.info("===========MODEL DETAILS===========") + + if args.scaling == "no_scaling": + args.std = 1.0 + logging.info("No scaling selected") + elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": + args.mean, args.std = modules.scaling_classes[args.scaling]( + train_loader, atomic_energies + ) + + # Build model + if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]: + logging.info("Loading FOUNDATION model") + model_config_foundation = extract_config_mace_model(model_foundation) + model_config_foundation["atomic_energies"] = atomic_energies + model_config_foundation["atomic_numbers"] = z_table.zs + model_config_foundation["num_elements"] = len(z_table) + args.max_L = model_config_foundation["hidden_irreps"].lmax + + if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": + model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) + else: + model_config_foundation["atomic_inter_shift"] = ( + _determine_atomic_inter_shift(args.mean, heads) + ) + + model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) + args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] + args.model = "FoundationMACE" + model_config_foundation["heads"] = heads + model_config = model_config_foundation + + logging.info("Model configuration 
extracted from foundation model") + logging.info("Using universal loss function for fine-tuning") + logging.info( + f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})" + ) + logging.info( + f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" + ) + logging.info( + f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)" + ) + logging.info( + f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" + ) + else: + logging.info("Building model") + logging.info( + f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" + ) + logging.info( + f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" + ) + logging.info( + f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" + ) + logging.info( + f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)" + ) + logging.info( + f"Distance transform for radial basis functions: {args.distance_transform}" + ) + + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + logging.info(f"Hidden irreps: {args.hidden_irreps}") + + model_config = dict( + r_max=args.r_max, + num_bessel=args.num_radial_basis, + num_polynomial_cutoff=args.num_cutoff_basis, + max_ell=args.max_ell, + interaction_cls=modules.interaction_classes[args.interaction], + num_interactions=args.num_interactions, + num_elements=len(z_table), + hidden_irreps=o3.Irreps(args.hidden_irreps), + atomic_energies=atomic_energies, + avg_num_neighbors=args.avg_num_neighbors, + atomic_numbers=z_table.zs, + ) + model_config_foundation = None + + model = _build_model(args, model_config, model_config_foundation, heads) + + if model_foundation is not None: + model = load_foundations_elements( + model, + model_foundation, + z_table, + load_readout=args.foundation_filter_elements, + max_L=args.max_L, + ) + + return model, output_args + + +def _determine_atomic_inter_shift(mean, heads): + if isinstance(mean, np.ndarray): + if mean.size == 1: + return mean.item() + if mean.size == len(heads): + return mean.tolist() + logging.info("Mean not in correct format, using default value of 0.0") + return [0.0] * len(heads) + if isinstance(mean, list) and len(mean) == len(heads): + return mean + if isinstance(mean, float): + return [mean] * len(heads) + logging.info("Mean not in correct format, using default value of 0.0") + return [0.0] * len(heads) + + +def _build_model( + args, model_config, model_config_foundation, heads +): # pylint: disable=too-many-return-statements + if args.model == "MACE": + return modules.ScaleShiftMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + 
atomic_inter_scale=args.std, + atomic_inter_shift=[0.0] * len(heads), + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) + if args.model == "ScaleShiftMACE": + return modules.ScaleShiftMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) + if args.model == "FoundationMACE": + return modules.ScaleShiftMACE(**model_config_foundation) + if args.model == "ScaleShiftBOTNet": + return modules.ScaleShiftBOTNet( + **model_config, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, + ) + if args.model == "BOTNet": + return modules.BOTNet( + **model_config, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + if args.model == "AtomicDipolesMACE": + assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model" + assert ( + args.error_table == "DipoleRMSE" + ), "Use error_table DipoleRMSE with AtomicDipolesMACE model" + return modules.AtomicDipolesMACE( + **model_config, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + if args.model == "EnergyDipolesMACE": + assert ( + args.loss == "energy_forces_dipole" + ), "Use energy_forces_dipole loss with EnergyDipolesMACE model" + assert ( + args.error_table == "EnergyDipoleRMSE" + ), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model" + return modules.EnergyDipolesMACE( + **model_config, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + raise RuntimeError(f"Unknown model: '{args.model}'") diff --git a/mace/tools/utils.py b/mace/tools/utils.py index 0d7aa41e..28a77efe 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -141,7 +141,6 @@ def __init__(self, directory: str, tag: str) -> None: self.path = os.path.join(self.directory, self.filename) def log(self, d: Dict[str, Any]) -> None: - logging.debug(f"Saving info: {self.path}") os.makedirs(name=self.directory, exist_ok=True) with open(self.path, mode="a", encoding="utf-8") as f: f.write(json.dumps(d, cls=UniversalEncoder)) From 4e0cc274281b30315ab6139b4c4ef816c757893b Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 2 Sep 2024 16:30:09 +0100 Subject: [PATCH 098/101] change float dtype in test_run_train_foundation --- tests/test_run_train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index d13d0987..6c41ce0f 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -397,7 +397,7 @@ def test_run_train_foundation(tmp_path, fitting_configs): mace_params["foundation_model"] = "small" mace_params["hidden_irreps"] = 
"128x0e" mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float32" + mace_params["default_dtype"] = "float64" mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" mace_params["multiheads_finetuning"] = False @@ -427,7 +427,7 @@ def test_run_train_foundation(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float32" + tmp_path / "MACE.model", device="cpu", default_dtype="float64" ) Es = [] @@ -536,7 +536,7 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float32" + tmp_path / "MACE.model", device="cpu", default_dtype="float64" ) Es = [] From 3fb8b604f288d92f99dfa5fe1d315797a35262be Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 2 Sep 2024 17:27:02 +0100 Subject: [PATCH 099/101] swap multihead test to float64 --- tests/test_run_train.py | 46 ++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 6c41ce0f..fe6c8c46 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -319,7 +319,7 @@ def test_run_train_multihead(tmp_path, fitting_configs): mace_params["loss"] = "weighted" mace_params["hidden_irreps"] = "128x0e" mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float32" + mace_params["default_dtype"] = "float64" mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" mace_params["config"] = tmp_path / "config.yaml" @@ -349,7 +349,7 @@ def test_run_train_multihead(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float32" + tmp_path / "MACE.model", device="cpu", default_dtype="float64" ) Es = [] @@ -358,30 +358,30 @@ def test_run_train_multihead(tmp_path, fitting_configs): Es.append(at.get_potential_energy()) print("Es", Es) - # from a run on 22/08/2024 on commit + # from a run on 02/09/2024 on develop branch ref_Es = [ 0.0, 0.0, - 0.1492728888988495, - 0.12760481238365173, - 0.18094804883003235, - 0.2017526775598526, - 0.09473809599876404, - 0.20055484771728516, - 0.1673969328403473, - 0.1053609699010849, - 0.29178786277770996, - 0.06670654565095901, - 0.09736010432243347, - 0.23458734154701233, - 0.09877493232488632, - -0.022957436740398407, - 0.2738725543022156, - 0.13694337010383606, - 0.12737643718719482, - -0.07650933414697647, - -0.012938144616782665, - 0.061228662729263306, + 0.10637113905361611, + -0.012499594026624754, + 0.08983077108171753, + 0.21071322543112597, + -0.028921849222784398, + -0.02423359575741567, + 0.022923252188079057, + -0.02048334610058991, + 0.4349711162741364, + -0.04455577015569887, + -0.09765806785570091, + 0.16013134616829822, + 0.0758442928017698, + -0.05931856557011721, + 0.33964473532953265, + 0.134338442158641, + 0.18024119757783053, + -0.18914740992058765, + -0.06503477155294624, + 0.03436649147415213, ] assert np.allclose(Es, ref_Es) From 58f1f3a127756869332038be76f78915763da478 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 3 Sep 2024 12:23:26 +0100 Subject: [PATCH 100/101] fix printing of test --- mace/cli/run_train.py | 53 +++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 25 
deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 936e6ce2..7acafaa6 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -626,7 +626,9 @@ def run(args: argparse.Namespace) -> None: stop_first_test = True for head_config in head_configs: if head_config.train_file.endswith(".xyz"): + print(head_config.test_file) for name, subset in head_config.collections.tests: + print(name) test_sets[name] = [ data.AtomicData.from_config( config, z_table=z_table, cutoff=args.r_max, heads=heads @@ -648,33 +650,33 @@ def run(args: argparse.Namespace) -> None: test_sets[name] = data.dataset_from_sharded_hdf5( folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name ) - - for test_name, test_set in test_sets.items(): - test_sampler = None - if args.distributed: - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, - ) - try: - drop_last = test_set.drop_last - except AttributeError as e: # pylint: disable=W0612 - drop_last = False - test_loader = torch_geometric.dataloader.DataLoader( + for test_name, test_set in test_sets.items(): + print(test_name) + test_sampler = None + if args.distributed: + test_sampler = torch.utils.data.distributed.DistributedSampler( test_set, - batch_size=args.valid_batch_size, - shuffle=(test_sampler is None), - drop_last=drop_last, - num_workers=args.num_workers, - pin_memory=args.pin_memory, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, ) - test_data_loader[test_name] = test_loader - if stop_first_test: - break + try: + drop_last = test_set.drop_last + except AttributeError as e: # pylint: disable=W0612 + drop_last = False + test_loader = torch_geometric.dataloader.DataLoader( + test_set, + batch_size=args.valid_batch_size, + shuffle=(test_sampler is None), + drop_last=drop_last, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + ) + test_data_loader[test_name] = test_loader + if stop_first_test: + break for swa_eval in swas: epoch = checkpoint_handler.load_latest( @@ -704,6 +706,7 @@ def run(args: argparse.Namespace) -> None: distributed=args.distributed, ) logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) + if test_data_loader: table_test = create_error_table( table_type=args.error_table, From 6d768c2703dc6c033eccab460ae5ddc04f13bdf9 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:28:15 +0100 Subject: [PATCH 101/101] fixed mixed pbc non pbc training --- mace/cli/fine_tuning_select.py | 30 +++++++++++++++--------- mace/data/atomic_data.py | 6 ++--- mace/data/neighborhood.py | 12 +++++----- mace/modules/loss.py | 42 ++++++++++++++++++++++------------ tests/test_data.py | 8 +++---- 5 files changed, 59 insertions(+), 39 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index f3b7462f..2fa5f644 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -225,17 +225,25 @@ def select_samples( "Filtering configurations based on the finetuning set, " f"filtering type: combinations, elements: {all_species_ft}" ) - if args.descriptors is not None: - logging.info("Loading descriptors") - descriptors = np.load(args.descriptors, allow_pickle=True) - atoms_list_pt = ase.io.read(args.configs_pt, index=":") - for i, atoms in enumerate(atoms_list_pt): - atoms.info["mace_descriptors"] = descriptors[i] - 
atoms_list_pt_filtered = [ - x - for x in atoms_list_pt - if filter_atoms(x, all_species_ft, "combinations") - ] + if args.subselect != "random": + if args.descriptors is not None: + logging.info("Loading descriptors") + descriptors = np.load(args.descriptors, allow_pickle=True) + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + for i, atoms in enumerate(atoms_list_pt): + atoms.info["mace_descriptors"] = descriptors[i] + atoms_list_pt_filtered = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] + else: + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + atoms_list_pt_filtered = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] else: atoms_list_pt = ase.io.read(args.configs_pt, index=":") atoms_list_pt_filtered = [ diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 814a23e0..cb4edd94 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -119,7 +119,7 @@ def from_config( ) -> "AtomicData": if heads is None: heads = ["default"] - edge_index, shifts, unit_shifts = get_neighborhood( + edge_index, shifts, unit_shifts, cell = get_neighborhood( positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell ) indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) @@ -133,8 +133,8 @@ def from_config( head = torch.tensor(len(heads) - 1, dtype=torch.long) cell = ( - torch.tensor(config.cell, dtype=torch.get_default_dtype()) - if config.cell is not None + torch.tensor(cell, dtype=torch.get_default_dtype()) + if cell is not None else torch.tensor( 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() ).view(3, 3) diff --git a/mace/data/neighborhood.py b/mace/data/neighborhood.py index 293576af..21296fa6 100644 --- a/mace/data/neighborhood.py +++ b/mace/data/neighborhood.py @@ -27,18 +27,18 @@ def get_neighborhood( max_positions = np.max(np.absolute(positions)) + 1 # Extend cell in non-periodic directions # For models with more than 5 layers, the multiplicative constant needs to be increased. 
- temp_cell = np.copy(cell) + # temp_cell = np.copy(cell) if not pbc_x: - temp_cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + cell[0, :] = max_positions * 5 * cutoff * identity[0, :] if not pbc_y: - temp_cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + cell[1, :] = max_positions * 5 * cutoff * identity[1, :] if not pbc_z: - temp_cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + cell[2, :] = max_positions * 5 * cutoff * identity[2, :] sender, receiver, unit_shifts = neighbour_list( quantities="ijS", pbc=pbc, - cell=temp_cell, + cell=cell, positions=positions, cutoff=cutoff, # self_interaction=True, # we want edges from atom to itself in different periodic images @@ -63,4 +63,4 @@ def get_neighborhood( # D = positions[j]-positions[i]+S.dot(cell) shifts = np.dot(unit_shifts, cell) # [n_edges, 3] - return edge_index, shifts, unit_shifts + return edge_index, shifts, unit_shifts, cell diff --git a/mace/modules/loss.py b/mace/modules/loss.py index 91462d2c..a7e28c55 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -114,34 +114,34 @@ def conditional_mse_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: def conditional_huber_forces( - ref: Batch, pred: TensorDict, huber_delta: float + ref_forces: Batch, pred_forces: TensorDict, huber_delta: float ) -> torch.Tensor: # Define the multiplication factors for each condition factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1]) # Apply multiplication factors based on conditions - c1 = torch.norm(ref["forces"], dim=-1) < 100 - c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( - torch.norm(ref["forces"], dim=-1) < 200 + c1 = torch.norm(ref_forces, dim=-1) < 100 + c2 = (torch.norm(ref_forces, dim=-1) >= 100) & ( + torch.norm(ref_forces, dim=-1) < 200 ) - c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( - torch.norm(ref["forces"], dim=-1) < 300 + c3 = (torch.norm(ref_forces, dim=-1) >= 200) & ( + torch.norm(ref_forces, dim=-1) < 300 ) c4 = ~(c1 | c2 | c3) - se = torch.zeros_like(pred["forces"]) + se = torch.zeros_like(pred_forces) se[c1] = torch.nn.functional.huber_loss( - ref["forces"][c1], pred["forces"][c1], reduction="none", delta=factors[0] + ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0] ) se[c2] = torch.nn.functional.huber_loss( - ref["forces"][c2], pred["forces"][c2], reduction="none", delta=factors[1] + ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1] ) se[c3] = torch.nn.functional.huber_loss( - ref["forces"][c3], pred["forces"][c3], reduction="none", delta=factors[2] + ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2] ) se[c4] = torch.nn.functional.huber_loss( - ref["forces"][c4], pred["forces"][c4], reduction="none", delta=factors[3] + ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3] ) return torch.mean(se) @@ -273,15 +273,27 @@ def __init__( def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: num_atoms = ref.ptr[1:] - ref.ptr[:-1] + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] + configs_energy_weight = ref.energy_weight # [n_graphs, ] + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) return ( self.energy_weight - * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + * self.huber_loss( + configs_energy_weight * ref["energy"] / num_atoms, + configs_energy_weight * pred["energy"] / num_atoms, + ) + self.forces_weight - * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) + * 
conditional_huber_forces( + configs_forces_weight * ref["forces"], + configs_forces_weight * pred["forces"], + huber_delta=self.huber_delta, + ) + self.stress_weight * self.huber_loss( - ref["stress"], - pred["stress"], + configs_stress_weight * ref["stress"], + configs_stress_weight * pred["stress"], ) ) diff --git a/tests/test_data.py b/tests/test_data.py index e893f03c..9e0c49e6 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -144,7 +144,7 @@ def test_basic(self): ] ) - indices, shifts, unit_shifts = get_neighborhood(positions, cutoff=1.5) + indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5) assert indices.shape == (2, 4) assert shifts.shape == (4, 3) assert unit_shifts.shape == (4, 3) @@ -158,7 +158,7 @@ def test_signs(self): ) cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - edge_index, shifts, unit_shifts = get_neighborhood( + edge_index, shifts, unit_shifts, _ = get_neighborhood( positions, cutoff=3.5, pbc=(True, False, False), cell=cell ) num_edges = 10 @@ -172,7 +172,7 @@ def test_periodic_edge(): atoms = ase.build.bulk("Cu", "fcc") dist = np.linalg.norm(atoms.cell[0]).item() config = config_from_atoms(atoms) - edge_index, shifts, _ = get_neighborhood( + edge_index, shifts, _, _ = get_neighborhood( config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell ) sender, receiver = edge_index @@ -190,7 +190,7 @@ def test_half_periodic(): atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0) assert all(atoms.pbc == (True, True, False)) config = config_from_atoms(atoms) # first shell dist is 2.864A - edge_index, shifts, _ = get_neighborhood( + edge_index, shifts, _, _ = get_neighborhood( config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell ) sender, receiver = edge_index
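A minimal usage sketch of the updated neighborhood API exercised by the test changes above: after this patch, get_neighborhood also returns the cell it actually used, with the non-periodic directions enlarged. The import path mace.data.neighborhood follows the file path in the diff header; the coordinates, cutoff and mixed-PBC cell below are illustrative values chosen for this sketch, not taken from the patch series.

    import numpy as np
    from mace.data.neighborhood import get_neighborhood

    # Two atoms 0.8 apart, periodic along x only (the mixed pbc / non-pbc case this patch targets).
    positions = np.array([[0.0, 0.0, 0.0], [0.8, 0.0, 0.0]])
    cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

    # Four return values after this patch: the returned cell has each non-periodic
    # direction replaced by roughly 5 * cutoff * (max |position| + 1) along the diagonal.
    edge_index, shifts, unit_shifts, cell_out = get_neighborhood(
        positions, cutoff=1.5, pbc=(True, False, False), cell=cell
    )
    print(edge_index.shape, shifts.shape, unit_shifts.shape)
    print(cell_out)

The returned cell is what AtomicData.from_config now stores (see the atomic_data.py hunk above), so the configuration's cell tensor and the neighbor list are built from the same box even when some directions are non-periodic, which appears to be the point of the mixed pbc / non-pbc fix.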