diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..998ae250 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,46 @@ +name: Linting and code formatting + +on: [] + # Trigger the workflow on push or pull request, + # but only for the main branch + # push: + # branches: [] + # pull_request: + # branches: [] + + +jobs: + build-linux: + runs-on: ubuntu-latest + + steps: + # Setup + - name: Checkout + uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.8.10 + - name: Get cache + uses: actions/cache@v2 + with: + path: /opt/hostedtoolcache/Python/3.8.10/x64/lib/python3.8/site-packages + # Look to see if there is a cache hit for the corresponding requirements file + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + + # Install packages + - name: Install packages required for installation + run: python -m pip install --upgrade pip setuptools wheel + - name: Install dependencies + run: pip install -r requirements.txt + + # Check code + - name: Check formatting with yapf + run: python -m yapf --style=.style.yapf --diff --recursive . +# - name: Lint with flake8 +# run: flake8 --config=.flake8 . +# - name: Check type annotations with mypy +# run: mypy --config-file=.mypy.ini . + + - name: Test with pytest + run: python -m pytest tests diff --git a/.gitignore b/.gitignore index 8b5f3f64..3817d9f3 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,9 @@ dist/ # DS_Store .DS_Store +*.models +*.pt /wandb +*.xyz /checkpoints *.model diff --git a/mace/__version__.py b/mace/__version__.py index d7b30e12..47e8e016 100644 --- a/mace/__version__.py +++ b/mace/__version__.py @@ -1 +1,3 @@ -__version__ = "0.3.6" +__version__ = "0.3.7" + +__all__ = ["__version__"] diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index bc9ed654..6d29cb04 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -62,9 +62,9 @@ def mace_mp( try: # checkpoints release: https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0 urls = dict( - small="https://tinyurl.com/46jrkm3v", # 2023-12-10-mace-128-L0_energy_epoch-249.model - medium="https://tinyurl.com/5yyxdm76", # 2023-12-03-mace-128-L1_epoch-199.model - large="https://tinyurl.com/5f5yavf3", # MACE_MPtrj_2022.9.model + small="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", # 2023-12-10-mace-128-L0_energy_epoch-249.model + medium="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", # 2023-12-03-mace-128-L1_epoch-199.model + large="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", # MACE_MPtrj_2022.9.model ) checkpoint_url = ( urls.get(model, urls["medium"]) diff --git a/mace/calculators/foundations_models/mp_vasp_e0.json b/mace/calculators/foundations_models/mp_vasp_e0.json new file mode 100644 index 00000000..01771879 --- /dev/null +++ b/mace/calculators/foundations_models/mp_vasp_e0.json @@ -0,0 +1,91 @@ +{ + "pbe": { + "1": -1.11734008, + "2": 0.00096759, + "3": -0.29754725, + "4": -0.01781697, + "5": -0.26885011, + "6": -1.26173507, + "7": -3.12438806, + "8": -1.54838784, + "9": -0.51882044, + "10": -0.01241601, + "11": -0.22883163, + "12": -0.00951015, + "13": -0.21630193, + "14": -0.8263903, + "15": -1.88816619, + "16": -0.89160769, + "17": -0.25828273, + "18": -0.04925973, + "19": -0.22697913, + "20": -0.0927795, + "21": -2.11396364, + "22": -2.50054871, + "23": -3.70477179, + "24": -5.60261985, + "25": -5.32541181, + "26": -3.52004933, + "27": -1.93555024, + "28": -0.9351969, + "29": -0.60025846, + "30": -0.1651332, + "31": -0.32990651, + "32": -0.77971828, + "33": -1.68367812, + "34": -0.76941032, + "35": -0.22213843, + "36": -0.0335879, + "37": -0.1881724, + "38": -0.06826294, + "39": -2.17084228, + "40": -2.28579303, + "41": -3.13429753, + "42": -4.60211419, + "43": -3.45201492, + "44": -2.38073513, + "45": -1.46855515, + "46": -1.4773126, + "47": -0.33954585, + "48": -0.16843877, + "49": -0.35470981, + "50": -0.83642657, + "51": -1.41101987, + "52": -0.65740879, + "53": -0.18964571, + "54": -0.00857582, + "55": -0.13771876, + "56": -0.03457659, + "57": -0.45580806, + "58": -1.3309175, + "59": -0.29671824, + "60": -0.30391193, + "61": -0.30898427, + "62": -0.25470891, + "63": -8.38001538, + "64": -10.38896525, + "65": -0.3059505, + "66": -0.30676216, + "67": -0.30874667, + "69": -0.25190039, + "70": -0.06431414, + "71": -0.31997586, + "72": -3.52770927, + "73": -3.54492209, + "75": -4.70108713, + "76": -2.88257209, + "77": -1.46779304, + "78": -0.50269936, + "79": -0.28801193, + "80": -0.12454674, + "81": -0.31737194, + "82": -0.77644932, + "83": -1.32627283, + "89": -0.26827152, + "90": -0.90817426, + "91": -2.47653193, + "92": -4.90438537, + "93": -7.63378961, + "94": -10.77237713 + } +} \ No newline at end of file diff --git a/mace/calculators/lammps_mace.py b/mace/calculators/lammps_mace.py index 8ad7e984..4211c37f 100644 --- a/mace/calculators/lammps_mace.py +++ b/mace/calculators/lammps_mace.py @@ -8,12 +8,22 @@ @compile_mode("script") class LAMMPS_MACE(torch.nn.Module): - def __init__(self, model): + def __init__(self, model, **kwargs): super().__init__() self.model = model self.register_buffer("atomic_numbers", model.atomic_numbers) self.register_buffer("r_max", model.r_max) self.register_buffer("num_interactions", model.num_interactions) + if not hasattr(model, "heads"): + model.heads = [None] + self.register_buffer( + "head", + torch.tensor( + self.model.heads.index(kwargs.get("head", self.model.heads[-1])), + dtype=torch.long, + ).unsqueeze(0), + ) + for param in self.model.parameters(): param.requires_grad = False @@ -27,6 +37,7 @@ def forward( compute_displacement = False if compute_virials: compute_displacement = True + data["head"] = self.head out = self.model( data, training=False, diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 05333200..292b114b 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -145,6 +145,10 @@ def __init__( [int(z) for z in self.models[0].atomic_numbers] ) self.charges_key = charges_key + try: + self.heads = self.models[0].heads + except AttributeError: + self.heads = ["Default"] model_dtype = get_model_dtype(self.models[0]) if default_dtype == "": print( @@ -198,7 +202,7 @@ def _atoms_to_batch(self, atoms): data_loader = torch_geometric.dataloader.DataLoader( dataset=[ data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max + config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads ) ], batch_size=1, @@ -231,7 +235,11 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): if self.model_type in ["MACE", "EnergyDipoleMACE"]: batch = self._clone_batch(batch_base) - node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"]) + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + num_atoms_arange, node_heads + ] compute_stress = not self.use_compile else: compute_stress = False diff --git a/mace/cli/active_learning_md.py b/mace/cli/active_learning_md.py index 52e3879e..648a30b2 100644 --- a/mace/cli/active_learning_md.py +++ b/mace/cli/active_learning_md.py @@ -144,7 +144,6 @@ def main() -> None: def run(args: argparse.Namespace) -> None: - mace_fname = args.model atoms_fname = args.config atoms_index = args.config_index diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 4cae618f..3f647906 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -1,4 +1,4 @@ -import sys +import argparse import torch from e3nn.util import jit @@ -6,13 +6,68 @@ from mace.calculators import LAMMPS_MACE -def main(): - assert len(sys.argv) == 2, f"Usage: {sys.argv[0]} model_path" +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "model_path", + type=str, + help="Path to the model to be converted to LAMMPS", + ) + parser.add_argument( + "--head", + type=str, + nargs="?", + help="Head of the model to be converted to LAMMPS", + default=None, + ) + return parser.parse_args() + + +def select_head(model): + if hasattr(model, "heads"): + heads = model.heads + else: + heads = [None] + + if len(heads) == 1: + print(f"Only one head found in the model: {heads[0]}. Skipping selection.") + return heads[0] + + print("Available heads in the model:") + for i, head in enumerate(heads): + print(f"{i + 1}: {head}") + + # Ask the user to select a head + selected = input( + f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): " + ) - model_path = sys.argv[1] # takes model name as command-line input + if selected.isdigit() and 1 <= int(selected) <= len(heads): + return heads[int(selected) - 1] + if selected == "": + print("No head selected. Proceeding without specifying a head.") + return None + print(f"No valid selection made. Defaulting to the last head: {heads[-1]}") + return heads[-1] + + +def main(): + args = parse_args() + model_path = args.model_path # takes model name as command-line input model = torch.load(model_path) model = model.double().to("cpu") - lammps_model = LAMMPS_MACE(model) + + if args.head is None: + head = select_head(model) + else: + head = args.head + print( + f"Selected head: {head} from command line in the list available heads: {model.heads}" + ) + + lammps_model = ( + LAMMPS_MACE(model, head=head) if head is not None else LAMMPS_MACE(model) + ) lammps_model_compiled = jit.compile(lammps_model) lammps_model_compiled.save(model_path + "-lammps.pt") diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index bf53ef88..b5700bc4 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -62,7 +62,6 @@ def main() -> None: def run(args: argparse.Namespace) -> None: - torch_tools.set_default_dtype(args.default_dtype) device = torch_tools.init_device(args.device) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py new file mode 100644 index 00000000..2fa5f644 --- /dev/null +++ b/mace/cli/fine_tuning_select.py @@ -0,0 +1,346 @@ +########################################################################################### +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import logging +from typing import List + +import ase.data +import ase.io +import numpy as np +import torch + +from mace.calculators import MACECalculator, mace_mp + +try: + import fpsample +except ImportError: + pass + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--configs_pt", + help="path to XYZ configurations for the pretraining", + required=True, + ) + parser.add_argument( + "--configs_ft", + help="path or list of paths to XYZ configurations for the finetuning", + required=True, + ) + parser.add_argument( + "--num_samples", + help="number of samples to select for the pretraining", + type=int, + required=False, + default=None, + ) + parser.add_argument( + "--subselect", + help="method to subselect the configurations of the pretraining set", + type=str, + choices=["fps", "random"], + default="fps", + ) + parser.add_argument( + "--model", help="path to model", default="small", required=False + ) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--descriptors", help="path to descriptors", required=False, default=None + ) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--head_pt", + help="level of head for the pretraining set", + type=str, + default=None, + ) + parser.add_argument( + "--head_ft", + help="level of head for the finetuning set", + type=str, + default=None, + ) + parser.add_argument( + "--filtering_type", + help="filtering type", + type=str, + choices=[None, "combinations", "exclusive", "inclusive"], + default=None, + ) + parser.add_argument( + "--weight_ft", + help="weight for the finetuning set", + type=float, + default=1.0, + ) + parser.add_argument( + "--weight_pt", + help="weight for the pretraining set", + type=float, + default=1.0, + ) + parser.add_argument("--seed", help="random seed", type=int, default=42) + return parser.parse_args() + + +def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: + logging.info("Calculating descriptors") + for mol in atoms: + descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) + # average descriptors over atoms for each element + descriptors_dict = { + element: np.mean(descriptors[mol.symbols == element], axis=0) + for element in np.unique(mol.symbols) + } + mol.info["mace_descriptors"] = descriptors_dict + + +def filter_atoms( + atoms: ase.Atoms, element_subset: List[str], filtering_type: str +) -> bool: + """ + Filters atoms based on the provided filtering type and element subset. + + Parameters: + atoms (ase.Atoms): The atoms object to filter. + element_subset (list): The list of elements to consider during filtering. + filtering_type (str): The type of filtering to apply. Can be 'none', 'exclusive', or 'inclusive'. + 'none' - No filtering is applied. + 'combinations' - Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. + 'exclusive' - Return true if `atoms` contains *only* elements in the subset, false otherwise. + 'inclusive' - Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. + + Returns: + bool: True if the atoms pass the filter, False otherwise. + """ + if filtering_type == "none": + return True + if filtering_type == "combinations": + atom_symbols = np.unique(atoms.symbols) + return all( + x in element_subset for x in atom_symbols + ) # atoms must *only* contain elements in the subset + if filtering_type == "exclusive": + atom_symbols = set(list(atoms.symbols)) + return atom_symbols == set(element_subset) + if filtering_type == "inclusive": + atom_symbols = np.unique(atoms.symbols) + return all( + x in atom_symbols for x in element_subset + ) # atoms must *at least* contain elements in the subset + raise ValueError( + f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'exclusive', or 'inclusive'." + ) + + +class FPS: + def __init__(self, atoms_list: List[ase.Atoms], n_samples: int): + self.n_samples = n_samples + self.atoms_list = atoms_list + self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) + self.species_dict = {x: i for i, x in enumerate(self.species)} + # start from a random configuration + self.list_index = [np.random.randint(0, len(atoms_list))] + self.assemble_descriptors() + + def run( + self, + ) -> List[int]: + """ + Run the farthest point sampling algorithm. + """ + descriptor_dataset_reshaped = ( + self.descriptors_dataset.reshape( # pylint: disable=E1121 + (len(self.atoms_list), -1) + ) + ) + logging.info(f"{descriptor_dataset_reshaped.shape}") + logging.info(f"n_samples: {self.n_samples}") + self.list_index = fpsample.fps_npdu_kdtree_sampling( + descriptor_dataset_reshaped, + self.n_samples, + ) + return self.list_index + + def assemble_descriptors(self) -> np.ndarray: + """ + Assemble the descriptors for all the configurations. + """ + self.descriptors_dataset: np.ndarray = 10e10 * np.ones( + ( + len(self.atoms_list), + len(self.species), + len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), + ), + dtype=np.float32, + ).astype(np.float32) + + for i, atoms in enumerate(self.atoms_list): + descriptors = atoms.info["mace_descriptors"] + for z in descriptors: + self.descriptors_dataset[i, self.species_dict[z]] = np.array( + descriptors[z] + ).astype(np.float32) + + +def select_samples( + args: argparse.Namespace, +) -> None: + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.model in ["small", "medium", "large"]: + calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype) + else: + calc = MACECalculator( + model_paths=args.model, device=args.device, default_dtype=args.default_dtype + ) + if isinstance(args.configs_ft, str): + atoms_list_ft = ase.io.read(args.configs_ft, index=":") + else: + atoms_list_ft = [] + for path in args.configs_ft: + atoms_list_ft += ase.io.read(path, index=":") + + if args.filtering_type is not None: + all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) + logging.info( + "Filtering configurations based on the finetuning set, " + f"filtering type: combinations, elements: {all_species_ft}" + ) + if args.subselect != "random": + if args.descriptors is not None: + logging.info("Loading descriptors") + descriptors = np.load(args.descriptors, allow_pickle=True) + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + for i, atoms in enumerate(atoms_list_pt): + atoms.info["mace_descriptors"] = descriptors[i] + atoms_list_pt_filtered = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] + else: + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + atoms_list_pt_filtered = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] + else: + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + atoms_list_pt_filtered = [ + x + for x in atoms_list_pt + if filter_atoms(x, all_species_ft, "combinations") + ] + if len(atoms_list_pt_filtered) <= args.num_samples: + logging.info( + f"Number of configurations after filtering {len(atoms_list_pt_filtered)} " + f"is less than the number of samples {args.num_samples}, " + "selecting random configurations for the rest." + ) + atoms_list_pt_minus_filtered = [ + x for x in atoms_list_pt if x not in atoms_list_pt_filtered + ] + atoms_list_pt_random_inds = np.random.choice( + list(range(len(atoms_list_pt_minus_filtered))), + args.num_samples - len(atoms_list_pt_filtered), + replace=False, + ) + atoms_list_pt = atoms_list_pt_filtered + [ + atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds + ] + else: + atoms_list_pt = atoms_list_pt_filtered + + else: + atoms_list_pt = ase.io.read(args.configs_pt, index=":") + if args.descriptors is not None: + logging.info( + f"Loading descriptors for the pretraining set from {args.descriptors}" + ) + descriptors = np.load(args.descriptors, allow_pickle=True) + for i, atoms in enumerate(atoms_list_pt): + atoms.info["mace_descriptors"] = descriptors[i] + + if args.num_samples is not None and args.num_samples < len(atoms_list_pt): + if args.subselect == "fps": + if args.descriptors is None: + logging.info("Calculating descriptors for the pretraining set") + calculate_descriptors(atoms_list_pt, calc) + descriptors_list = [ + atoms.info["mace_descriptors"] for atoms in atoms_list_pt + ] + logging.info( + f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}" + ) + np.save( + args.output.replace(".xyz", "_descriptors.npy"), descriptors_list + ) + logging.info("Selecting configurations using Farthest Point Sampling") + try: + fps_pt = FPS(atoms_list_pt, args.num_samples) + idx_pt = fps_pt.run() + logging.info(f"Selected {len(idx_pt)} configurations") + except Exception as e: # pylint: disable=W0703 + logging.error( + f"FPS failed, selecting random configurations instead: {e}" + ) + idx_pt = np.random.choice( + list(range(len(atoms_list_pt))), args.num_samples, replace=False + ) + atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + else: + logging.info("Selecting random configurations") + idx_pt = np.random.choice( + list(range(len(atoms_list_pt))), args.num_samples, replace=False + ) + atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + for atoms in atoms_list_pt: + # del atoms.info["mace_descriptors"] + atoms.info["pretrained"] = True + atoms.info["config_weight"] = args.weight_pt + atoms.info["mace_descriptors"] = None + if args.head_pt is not None: + atoms.info["head"] = args.head_pt + + logging.info("Saving the selected configurations") + ase.io.write(args.output, atoms_list_pt, format="extxyz") + logging.info("Saving a combined XYZ file") + for atoms in atoms_list_ft: + atoms.info["pretrained"] = False + atoms.info["config_weight"] = args.weight_ft + atoms.info["mace_descriptors"] = None + if args.head_ft is not None: + atoms.info["head"] = args.head_ft + atoms_fps_pt_ft = atoms_list_pt + atoms_list_ft + ase.io.write( + args.output.replace(".xyz", "_combined.xyz"), atoms_fps_pt_ft, format="extxyz" + ) + + +def main(): + args = parse_args() + select_samples(args) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index f98c7a04..7acafaa6 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -12,32 +12,43 @@ import os from copy import deepcopy from pathlib import Path -from typing import Optional +from typing import List, Optional -import numpy as np import torch.distributed import torch.nn.functional -from e3nn import o3 from e3nn.util import jit from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.swa_utils import SWALR, AveragedModel +from torch.utils.data import ConcatDataset from torch_ema import ExponentialMovingAverage import mace -from mace import data, modules, tools +from mace import data, tools from mace.calculators.foundations_models import mace_mp, mace_off from mace.tools import torch_geometric -from mace.tools.finetuning_utils import load_foundations +from mace.tools.model_script_utils import configure_model +from mace.tools.multihead_tools import ( + HeadConfig, + assemble_mp_data, + dict_head_to_dataclass, + prepare_default_head, +) from mace.tools.scripts_utils import ( LRScheduler, convert_to_json_format, create_error_table, + dict_to_array, extract_config_mace_model, get_atomic_energies, + get_avg_num_neighbors, get_config_type_weights, get_dataset_from_xyz, get_files_with_suffix, + get_loss_fn, + get_optimizer, + get_params_options, + get_swa, print_git_commit, + setup_wandb, ) from mace.tools.slurm_distributed import DistributedEnvironment from mace.tools.utils import AtomicNumberTable @@ -55,8 +66,16 @@ def run(args: argparse.Namespace) -> None: """ This script runs the training/fine tuning for mace """ - args, input_log_messages = tools.check_args(args) tag = tools.get_tag(name=args.name, seed=args.seed) + args, input_log_messages = tools.check_args(args) + + if args.device == "xpu": + try: + import intel_extension_for_pytorch as ipex + except ImportError as e: + raise ImportError( + "Error: Intel extension for PyTorch not found, but XPU device was specified" + ) from e if args.distributed: try: distr_env = DistributedEnvironment() @@ -93,7 +112,12 @@ def run(args: argparse.Namespace) -> None: tools.set_default_dtype(args.default_dtype) device = tools.init_device(args.device) commit = print_git_commit() + model_foundation: Optional[torch.nn.Module] = None if args.foundation_model is not None: + if args.multiheads_finetuning: + assert ( + args.E0s != "average" + ), "average atomic energies cannot be used for multiheads finetuning" if args.foundation_model in ["small", "medium", "large"]: logging.info( f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." @@ -116,154 +140,314 @@ def run(args: argparse.Namespace) -> None: ) model_foundation = calc.models[0] else: - model_foundation = torch.load(args.foundation_model, map_location=device) + model_foundation = torch.load( + args.foundation_model, map_location=args.device + ) logging.info( f"Using foundation model {args.foundation_model} as initial checkpoint." ) args.r_max = model_foundation.r_max.item() + else: + args.multiheads_finetuning = False - if args.statistics_file is not None: - with open(args.statistics_file, "r") as f: # pylint: disable=W1514 - statistics = json.load(f) - logging.info("Using statistics json file") - args.r_max = ( - statistics["r_max"] if args.foundation_model is None else args.r_max - ) - args.atomic_numbers = statistics["atomic_numbers"] - args.mean = statistics["mean"] - args.std = statistics["std"] - args.avg_num_neighbors = statistics["avg_num_neighbors"] - args.compute_avg_num_neighbors = False - args.E0s = statistics["atomic_energies"] + if args.heads is not None: + args.heads = ast.literal_eval(args.heads) + else: + args.heads = prepare_default_head(args) - logging.info("") logging.info("===========LOADING INPUT DATA===========") - # Data preparation - if args.train_file.endswith(".xyz"): - if args.valid_file is not None: - assert args.valid_file.endswith( - ".xyz" - ), "valid_file if given must be same format as train_file" - config_type_weights = get_config_type_weights(args.config_type_weights) - collections, atomic_energies_dict = get_dataset_from_xyz( - work_dir=args.work_dir, - train_path=args.train_file, - valid_path=args.valid_file, - valid_fraction=args.valid_fraction, - config_type_weights=config_type_weights, - test_path=args.test_file, - seed=args.seed, - energy_key=args.energy_key, - forces_key=args.forces_key, - stress_key=args.stress_key, - virials_key=args.virials_key, - dipole_key=args.dipole_key, - charges_key=args.charges_key, - keep_isolated_atoms=args.keep_isolated_atoms, + heads = list(args.heads.keys()) + logging.info(f"Using heads: {heads}") + head_configs: List[HeadConfig] = [] + for head, head_args in args.heads.items(): + logging.info(f"============= Processing head {head} ===========") + head_config = dict_head_to_dataclass(head_args, head, args) + if head_config.statistics_file is not None: + with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 + statistics = json.load(f) + logging.info("Using statistics json file") + head_config.r_max = ( + statistics["r_max"] if args.foundation_model is None else args.r_max + ) + head_config.atomic_numbers = statistics["atomic_numbers"] + head_config.mean = statistics["mean"] + head_config.std = statistics["std"] + head_config.avg_num_neighbors = statistics["avg_num_neighbors"] + head_config.compute_avg_num_neighbors = False + if isinstance(statistics["atomic_energies"], str) and statistics[ + "atomic_energies" + ].endswith(".json"): + with open(statistics["atomic_energies"], "r", encoding="utf-8") as f: + atomic_energies = json.load(f) + head_config.E0s = atomic_energies + head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) + else: + head_config.E0s = statistics["atomic_energies"] + head_config.atomic_energies_dict = ast.literal_eval( + statistics["atomic_energies"] + ) + + # Data preparation + if head_config.train_file.endswith(".xyz"): + if head_config.valid_file is not None: + assert head_config.valid_file.endswith( + ".xyz" + ), "valid_file if given must be same format as train_file" + config_type_weights = get_config_type_weights( + head_config.config_type_weights + ) + collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=head_config.train_file, + valid_path=head_config.valid_file, + valid_fraction=head_config.valid_fraction, + config_type_weights=config_type_weights, + test_path=head_config.test_file, + seed=args.seed, + energy_key=head_config.energy_key, + forces_key=head_config.forces_key, + stress_key=head_config.stress_key, + virials_key=head_config.virials_key, + dipole_key=head_config.dipole_key, + charges_key=head_config.charges_key, + head_name=head_config.head_name, + keep_isolated_atoms=head_config.keep_isolated_atoms, + ) + head_config.collections = collections + head_config.atomic_energies_dict = atomic_energies_dict + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " + f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," + ) + head_configs.append(head_config) + + if all(head_config.train_file.endswith(".xyz") for head_config in head_configs): + size_collections_train = sum( + len(head_config.collections.train) for head_config in head_configs + ) + size_collections_valid = sum( + len(head_config.collections.valid) for head_config in head_configs ) - if len(collections.train) < args.batch_size: + if size_collections_train < args.batch_size: logging.error( - f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})" + f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})" ) - if len(collections.valid) < args.valid_batch_size: + if size_collections_valid < args.valid_batch_size: logging.warning( - f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})" + f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})" ) - args.valid_batch_size = len(collections.valid) - else: - atomic_energies_dict = None + if args.multiheads_finetuning: + logging.info( + "==================Using multiheads finetuning mode==================" + ) + args.loss = "universal" + if ( + args.foundation_model in ["small", "medium", "large"] + or "mp" in args.foundation_model + or args.pt_train_file is None + ): + logging.info( + "Using foundation model for multiheads finetuning with Materials Project data" + ) + heads = list(dict.fromkeys(["pt_head"] + heads)) + head_config_pt = HeadConfig( + head_name="pt_head", + E0s="foundation", + statistics_file=args.statistics_file, + compute_avg_num_neighbors=False, + avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, + ) + collections = assemble_mp_data(args, tag, head_configs) + head_config_pt.collections = collections + head_config_pt.train_file = f"mp_finetuning-{tag}.xyz" + head_configs.append(head_config_pt) + else: + logging.info( + f"Using foundation model for multiheads finetuning with {args.pt_train_file}" + ) + heads = list(dict.fromkeys(["pt_head"] + heads)) + collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=args.pt_train_file, + valid_path=args.pt_valid_file, + valid_fraction=args.valid_fraction, + config_type_weights=None, + test_path=None, + seed=args.seed, + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + head_name="pt_head", + keep_isolated_atoms=args.keep_isolated_atoms, + ) + head_config_pt = HeadConfig( + head_name="pt_head", + train_file=args.pt_train_file, + valid_file=args.pt_valid_file, + E0s="foundation", + statistics_file=args.statistics_file, + valid_fraction=args.valid_fraction, + config_type_weights=None, + energy_key=args.energy_key, + forces_key=args.forces_key, + stress_key=args.stress_key, + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + keep_isolated_atoms=args.keep_isolated_atoms, + collections=collections, + avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, + compute_avg_num_neighbors=False, + ) + head_config_pt.collections = collections + head_configs.append(head_config_pt) + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" + ) # Atomic number table # yapf: disable - if args.atomic_numbers is None: - assert args.train_file.endswith(".xyz"), "Must specify atomic_numbers when using .h5 train_file input" - z_table = tools.get_atomic_number_table_from_zs( - z - for configs in (collections.train, collections.valid) - for config in configs - for z in config.atomic_numbers - ) - else: - if args.statistics_file is None: - logging.info("Using atomic numbers from command line argument") + for head_config in head_configs: + if head_config.atomic_numbers is None: + assert head_config.train_file.endswith(".xyz"), "Must specify atomic_numbers when using .h5 train_file input" + z_table_head = tools.get_atomic_number_table_from_zs( + z + for configs in (head_config.collections.train, head_config.collections.valid) + for config in configs + for z in config.atomic_numbers + ) + head_config.atomic_numbers = z_table_head.zs + head_config.z_table = z_table_head else: - logging.info("Using atomic numbers from statistics file") - zs_list = ast.literal_eval(args.atomic_numbers) - assert isinstance(zs_list, list) - z_table = tools.get_atomic_number_table_from_zs(zs_list) - # yapf: enable + if head_config.statistics_file is None: + logging.info("Using atomic numbers from command line argument") + else: + logging.info("Using atomic numbers from statistics file") + zs_list = ast.literal_eval(head_config.atomic_numbers) + assert isinstance(zs_list, list) + z_table_head = tools.AtomicNumberTable(zs_list) + head_config.atomic_numbers = zs_list + head_config.z_table = z_table_head + # yapf: enable + all_atomic_numbers = set() + for head_config in head_configs: + all_atomic_numbers.update(head_config.atomic_numbers) + z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) logging.info(f"Atomic Numbers used: {z_table.zs}") - if atomic_energies_dict is None or len(atomic_energies_dict) == 0: - if args.E0s.lower() == "foundation": - assert args.foundation_model is not None - z_table_foundation = AtomicNumberTable( - [int(z) for z in model_foundation.atomic_numbers] - ) - atomic_energies_dict = { - z: model_foundation.atomic_energies_fn.atomic_energies[ - z_table_foundation.z_to_index(z) - ].item() - for z in z_table.zs - } - logging.info( - f"Using Atomic Energies from foundation model [z, eV]: {', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table_foundation.zs])}" - ) - else: - if args.train_file.endswith(".xyz"): - atomic_energies_dict = get_atomic_energies( - args.E0s, collections.train, z_table + # Atomic energies + atomic_energies_dict = {} + for head_config in head_configs: + if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: + if head_config.train_file.endswith(".xyz") and head_config.E0s.lower() != "foundation": + atomic_energies_dict[head_config.head_name] = get_atomic_energies( + head_config.E0s, head_config.collections.train, head_config.z_table ) + elif head_config.E0s.lower() == "foundation": + assert args.foundation_model is not None + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + atomic_energies_dict[head_config.head_name] = { + z: model_foundation.atomic_energies_fn.atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } else: - atomic_energies_dict = get_atomic_energies(args.E0s, None, z_table) + atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) + else: + atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict + + # Atomic energies for multiheads finetuning + if args.multiheads_finetuning: + assert ( + model_foundation is not None + ), "Model foundation must be provided for multiheads finetuning" + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + atomic_energies_dict["pt_head"] = { + z: model_foundation.atomic_energies_fn.atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } if args.model == "AtomicDipolesMACE": atomic_energies = None dipole_only = True - compute_dipole = True - compute_energy = False + args.compute_dipole = True + args.compute_energy = False args.compute_forces = False - compute_virials = False + args.compute_virials = False args.compute_stress = False else: dipole_only = False if args.model == "EnergyDipolesMACE": - compute_dipole = True - compute_energy = True + args.compute_dipole = True + args.compute_energy = True args.compute_forces = True - compute_virials = False + args.compute_virials = False args.compute_stress = False else: - compute_energy = True - compute_dipole = False - atomic_energies: np.ndarray = np.array( - [atomic_energies_dict[z] for z in z_table.zs] - ) - logging.info( - f"Atomic Energies used (z: eV): {{{', '.join([f'{z}: {atomic_energies_dict[z]}' for z in z_table.zs])}}}" - ) + args.compute_energy = True + args.compute_dipole = False + # atomic_energies: np.ndarray = np.array( + # [atomic_energies_dict[z] for z in z_table.zs] + # ) + atomic_energies = dict_to_array(atomic_energies_dict, heads) + for head_config in head_configs: + logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}") + + + valid_sets = {head: [] for head in heads} + train_sets = {head: [] for head in heads} + for head_config in head_configs: + if head_config.train_file.endswith(".xyz"): + train_sets[head_config.head_name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in head_config.collections.train + ] + valid_sets[head_config.head_name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in head_config.collections.valid + ] - if args.train_file.endswith(".xyz"): - train_set = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) - for config in collections.train - ] - valid_set = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) - for config in collections.valid - ] - elif args.train_file.endswith(".h5"): - train_set = data.HDF5Dataset(args.train_file, r_max=args.r_max, z_table=z_table) - valid_set = data.HDF5Dataset(args.valid_file, r_max=args.r_max, z_table=z_table) - else: # This case would be for when the file path is to a directory of multiple .h5 files - train_set = data.dataset_from_sharded_hdf5( - args.train_file, r_max=args.r_max, z_table=z_table - ) - valid_set = data.dataset_from_sharded_hdf5( - args.valid_file, r_max=args.r_max, z_table=z_table + elif head_config.train_file.endswith(".h5"): + train_sets[head_config.head_name] = data.HDF5Dataset( + head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + valid_sets[head_config.head_name] = data.HDF5Dataset( + head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + else: # This case would be for when the file path is to a directory of multiple .h5 files + train_sets[head_config.head_name] = data.dataset_from_sharded_hdf5( + head_config.train_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + valid_sets[head_config.head_name] = data.dataset_from_sharded_hdf5( + head_config.valid_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + train_loader_head = torch_geometric.dataloader.DataLoader( + dataset=train_sets[head_config.head_name], + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), ) - + head_config.train_loader = train_loader_head + # concatenate all the trainsets + train_set = ConcatDataset([train_sets[head] for head in heads]) train_sampler, valid_sampler = None, None if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( @@ -274,14 +458,17 @@ def run(args: argparse.Namespace) -> None: drop_last=True, seed=args.seed, ) - valid_sampler = torch.utils.data.distributed.DistributedSampler( - valid_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, - ) + valid_samplers = {} + for head, valid_set in valid_sets.items(): + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + valid_samplers[head] = valid_sampler train_loader = torch_geometric.dataloader.DataLoader( dataset=train_set, batch_size=args.batch_size, @@ -292,273 +479,26 @@ def run(args: argparse.Namespace) -> None: num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), ) - valid_loader = torch_geometric.dataloader.DataLoader( - dataset=valid_set, - batch_size=args.valid_batch_size, - sampler=valid_sampler, - shuffle=False, - drop_last=False, - pin_memory=args.pin_memory, - num_workers=args.num_workers, - generator=torch.Generator().manual_seed(args.seed), - ) - logging.info("") - logging.info("===========MODEL DETAILS===========") - if args.loss == "weighted": - loss_fn = modules.WeightedEnergyForcesLoss( - energy_weight=args.energy_weight, forces_weight=args.forces_weight - ) - elif args.loss == "forces_only": - loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) - elif args.loss == "virials": - loss_fn = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - virials_weight=args.virials_weight, - ) - elif args.loss == "stress": - loss_fn = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - ) - elif args.loss == "huber": - loss_fn = modules.WeightedHuberEnergyForcesStressLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - huber_delta=args.huber_delta, - ) - elif args.loss == "universal": - loss_fn = modules.UniversalLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - huber_delta=args.huber_delta, - ) - elif args.loss == "dipole": - assert ( - dipole_only is True - ), "dipole loss can only be used with AtomicDipolesMACE model" - loss_fn = modules.DipoleSingleLoss( - dipole_weight=args.dipole_weight, - ) - elif args.loss == "energy_forces_dipole": - assert dipole_only is False and compute_dipole is True - loss_fn = modules.WeightedEnergyForcesDipoleLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - dipole_weight=args.dipole_weight, - ) - else: - # Unweighted Energy and Forces loss by default - loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) - - if args.compute_avg_num_neighbors: - avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) - if args.distributed: - num_graphs = torch.tensor(len(train_loader.dataset)).to(device) - num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) - torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce( - num_neighbors, op=torch.distributed.ReduceOp.SUM - ) - args.avg_num_neighbors = (num_neighbors / num_graphs).item() - else: - args.avg_num_neighbors = avg_num_neighbors - if args.avg_num_neighbors < 2 or args.avg_num_neighbors > 100: - logging.warning( - f"Unusual average number of neighbors: {args.avg_num_neighbors:.1f}" - ) - else: - logging.info(f"Average number of neighbors: {args.avg_num_neighbors:.1f}") - - # Selecting outputs - compute_virials = False - if args.loss in ("stress", "virials", "huber", "universal"): - compute_virials = True - args.compute_stress = True - if "MAE" in args.error_table: - args.error_table = "PerAtomMAEstressvirials" - else: - args.error_table = "PerAtomRMSEstressvirials" - - output_args = { - "energy": compute_energy, - "forces": args.compute_forces, - "virials": compute_virials, - "stress": args.compute_stress, - "dipoles": compute_dipole, - } - - logging.info( - f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" - ) - - if args.scaling == "no_scaling": - args.std = 1.0 - logging.info("No scaling selected") - elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": - args.mean, args.std = modules.scaling_classes[args.scaling]( - train_loader, atomic_energies - ) - # Build model - if args.foundation_model is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Loading FOUNDATION model") - model_config_foundation = extract_config_mace_model(model_foundation) - model_config_foundation["atomic_numbers"] = z_table.zs - model_config_foundation["num_elements"] = len(z_table) - args.max_L = model_config_foundation["hidden_irreps"].lmax - args.num_channels = list( - {irrep.mul for irrep in o3.Irreps(model_config_foundation["hidden_irreps"])} - )[0] - model_config_foundation["atomic_inter_shift"] = ( - model_foundation.scale_shift.shift.item() - ) - model_config_foundation["atomic_inter_scale"] = ( - model_foundation.scale_shift.scale.item() - ) - model_config_foundation["atomic_energies"] = atomic_energies - args.model = "FoundationMACE" - model_config = model_config_foundation # pylint - logging.info( - f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({model_config_foundation['hidden_irreps']})" - ) - logging.info( - f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" - ) - logging.info( - f"Radial cutoff: {model_config_foundation['r_max']} Å (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} Å)" - ) - logging.info( - f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" - ) - else: - logging.info("Building model") - logging.info( - f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" - ) - logging.info( - f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" - ) - logging.info( - f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" - ) - logging.info( - f"Radial cutoff: {args.r_max} Å (total receptive field for each atom: {args.r_max * args.num_interactions} Å)" - ) - logging.info( - f"Distance transform for radial basis functions: {args.distance_transform}" - ) - model_config = dict( - r_max=args.r_max, - num_bessel=args.num_radial_basis, - num_polynomial_cutoff=args.num_cutoff_basis, - max_ell=args.max_ell, - interaction_cls=modules.interaction_classes[args.interaction], - num_interactions=args.num_interactions, - num_elements=len(z_table), - hidden_irreps=o3.Irreps(args.hidden_irreps), - atomic_energies=atomic_energies, - avg_num_neighbors=args.avg_num_neighbors, - atomic_numbers=z_table.zs, + valid_loaders = {heads[i]: None for i in range(len(heads))} + if not isinstance(valid_sets, dict): + valid_sets = {"Default": valid_sets} + for head, valid_set in valid_sets.items(): + valid_loaders[head] = torch_geometric.dataloader.DataLoader( + dataset=valid_set, + batch_size=args.valid_batch_size, + sampler=valid_samplers[head] if args.distributed else None, + shuffle=False, + drop_last=False, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), ) - model: torch.nn.Module - - if args.model == "MACE": - model = modules.ScaleShiftMACE( - **model_config, - pair_repulsion=args.pair_repulsion, - distance_transform=args.distance_transform, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=0.0, - radial_MLP=ast.literal_eval(args.radial_MLP), - radial_type=args.radial_type, - ) - elif args.model == "ScaleShiftMACE": - model = modules.ScaleShiftMACE( - **model_config, - pair_repulsion=args.pair_repulsion, - distance_transform=args.distance_transform, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=args.mean, - radial_MLP=ast.literal_eval(args.radial_MLP), - radial_type=args.radial_type, - ) - elif args.model == "FoundationMACE": - model = modules.ScaleShiftMACE(**model_config_foundation) - elif args.model == "ScaleShiftBOTNet": - model = modules.ScaleShiftBOTNet( - **model_config, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=args.mean, - ) - elif args.model == "BOTNet": - model = modules.BOTNet( - **model_config, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - ) - elif args.model == "AtomicDipolesMACE": - # std_df = modules.scaling_classes["rms_dipoles_scaling"](train_loader) - assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model" - assert ( - args.error_table == "DipoleRMSE" - ), "Use error_table DipoleRMSE with AtomicDipolesMACE model" - model = modules.AtomicDipolesMACE( - **model_config, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - # dipole_scale=1, - # dipole_shift=0, - ) - elif args.model == "EnergyDipolesMACE": - # std_df = modules.scaling_classes["rms_dipoles_scaling"](train_loader) - assert ( - args.loss == "energy_forces_dipole" - ), "Use energy_forces_dipole loss with EnergyDipolesMACE model" - assert ( - args.error_table == "EnergyDipoleRMSE" - ), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model" - model = modules.EnergyDipolesMACE( - **model_config, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - ) - else: - raise RuntimeError(f"Unknown model: '{args.model}'") + loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) + args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) - if args.foundation_model is not None: - model = load_foundations( - model, - model_foundation, - z_table, - load_readout=True, - max_L=args.max_L, - ) + # Model + model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table) model.to(device) logging.debug(model) @@ -576,62 +516,12 @@ def run(args: argparse.Namespace) -> None: logging.info(loss_fn) # Optimizer - decay_interactions = {} - no_decay_interactions = {} - for name, param in model.interactions.named_parameters(): - if "linear.weight" in name or "skip_tp_full.weight" in name: - decay_interactions[name] = param - else: - no_decay_interactions[name] = param - - param_options = dict( - params=[ - { - "name": "embedding", - "params": model.node_embedding.parameters(), - "weight_decay": 0.0, - }, - { - "name": "interactions_decay", - "params": list(decay_interactions.values()), - "weight_decay": args.weight_decay, - }, - { - "name": "interactions_no_decay", - "params": list(no_decay_interactions.values()), - "weight_decay": 0.0, - }, - { - "name": "products", - "params": model.products.parameters(), - "weight_decay": args.weight_decay, - }, - { - "name": "readouts", - "params": model.readouts.parameters(), - "weight_decay": 0.0, - }, - ], - lr=args.lr, - amsgrad=args.amsgrad, - betas=(args.beta, 0.999), - ) - + param_options = get_params_options(args, model) optimizer: torch.optim.Optimizer - if args.optimizer == "adamw": - optimizer = torch.optim.AdamW(**param_options) - elif args.optimizer == "schedulefree": - try: - from schedulefree import adamw_schedulefree - except ImportError as exc: - raise ImportError( - "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" - ) from exc - _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} - optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) - else: - optimizer = torch.optim.Adam(**param_options) - + optimizer = get_optimizer(args, param_options) + if args.device == "xpu": + logging.info("Optimzing model and optimzier for XPU") + model, optimizer = ipex.optimize(model, optimizer=optimizer) logger = tools.MetricsLogger( directory=args.results_dir, tag=tag + "_train" ) # pylint: disable=E1123 @@ -641,50 +531,7 @@ def run(args: argparse.Namespace) -> None: swa: Optional[tools.SWAContainer] = None swas = [False] if args.swa: - assert dipole_only is False, "Stage Two for dipole fitting not implemented" - swas.append(True) - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - logging.info( - f"Stage Two will start after {args.start_swa} epochs with loss function:" - ) - if args.loss == "forces_only": - raise ValueError("Can not select Stage Two with forces only loss.") - if args.loss == "virials": - loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - virials_weight=args.swa_virials_weight, - ) - elif args.loss == "stress": - loss_fn_energy = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - stress_weight=args.swa_stress_weight, - ) - elif args.loss == "energy_forces_dipole": - loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( - args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - dipole_weight=args.swa_dipole_weight, - ) - else: - loss_fn_energy = modules.WeightedEnergyForcesLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - ) - logging.info(loss_fn_energy) - swa = tools.SWAContainer( - model=AveragedModel(model), - scheduler=SWALR( - optimizer=optimizer, - swa_lr=args.swa_lr, - anneal_epochs=1, - anneal_strategy="linear", - ), - start=args.start_swa, - loss_fn=loss_fn_energy, - ) + swa, swas = get_swa(args, model, optimizer, swas, dipole_only) checkpoint_handler = tools.CheckpointHandler( directory=args.checkpoints_dir, @@ -718,22 +565,7 @@ def run(args: argparse.Namespace) -> None: group["lr"] = args.lr if args.wandb: - logging.info("Using Weights and Biases for logging") - import wandb - - wandb_config = {} - args_dict = vars(args) - args_dict_json = json.dumps(args_dict) - for key in args.wandb_log_hypers: - wandb_config[key] = args_dict[key] - tools.init_wandb( - project=args.wandb_project, - entity=args.wandb_entity, - name=args.wandb_name, - config=wandb_config, - directory=args.wandb_dir, - ) - wandb.run.summary["params"] = args_dict_json + setup_wandb(args) if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) @@ -744,7 +576,7 @@ def run(args: argparse.Namespace) -> None: model=model, loss_fn=loss_fn, train_loader=train_loader, - valid_loader=valid_loader, + valid_loaders=valid_loaders, optimizer=optimizer, lr_scheduler=lr_scheduler, checkpoint_handler=checkpoint_handler, @@ -766,68 +598,85 @@ def run(args: argparse.Namespace) -> None: train_sampler=train_sampler, rank=rank, ) + logging.info("") logging.info("===========RESULTS===========") logging.info("Computing metrics for training, validation, and test sets") - all_data_loaders = { - "train": train_loader, - "valid": valid_loader, - } + train_valid_data_loader = {} + for head_config in head_configs: + data_loader_name = "train_" + head_config.head_name + train_valid_data_loader[data_loader_name] = head_config.train_loader + for head, valid_loader in valid_loaders.items(): + data_load_name = "valid_" + head + train_valid_data_loader[data_load_name] = valid_loader test_sets = {} - if args.train_file.endswith(".xyz"): - for name, subset in collections.tests: - test_sets[name] = [ - data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max) - for config in subset - ] - elif not args.multi_processed_test: - test_files = get_files_with_suffix(args.test_dir, "_test.h5") - for test_file in test_files: - name = os.path.splitext(os.path.basename(test_file))[0] - test_sets[name] = data.HDF5Dataset( - test_file, r_max=args.r_max, z_table=z_table - ) - else: - test_folders = glob(args.test_dir + "/*") - for folder in test_folders: - name = os.path.splitext(os.path.basename(test_file))[0] - test_sets[name] = data.dataset_from_sharded_hdf5( - folder, r_max=args.r_max, z_table=z_table - ) - - for test_name, test_set in test_sets.items(): - test_sampler = None - if args.distributed: - test_sampler = torch.utils.data.distributed.DistributedSampler( + stop_first_test = False + test_data_loader = {} + if all( + head_config.test_file == head_configs[0].test_file + for head_config in head_configs + ) and head_configs[0].test_file is not None: + stop_first_test = True + if all( + head_config.test_dir == head_configs[0].test_dir + for head_config in head_configs + ) and head_configs[0].test_dir is not None: + stop_first_test = True + for head_config in head_configs: + if head_config.train_file.endswith(".xyz"): + print(head_config.test_file) + for name, subset in head_config.collections.tests: + print(name) + test_sets[name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in subset + ] + if head_config.test_dir is not None: + if not args.multi_processed_test: + test_files = get_files_with_suffix(head_config.test_dir, "_test.h5") + for test_file in test_files: + name = os.path.splitext(os.path.basename(test_file))[0] + test_sets[name] = data.HDF5Dataset( + test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + else: + test_folders = glob(head_config.test_dir + "/*") + for folder in test_folders: + name = os.path.splitext(os.path.basename(test_file))[0] + test_sets[name] = data.dataset_from_sharded_hdf5( + folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + for test_name, test_set in test_sets.items(): + print(test_name) + test_sampler = None + if args.distributed: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + try: + drop_last = test_set.drop_last + except AttributeError as e: # pylint: disable=W0612 + drop_last = False + test_loader = torch_geometric.dataloader.DataLoader( test_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, + batch_size=args.valid_batch_size, + shuffle=(test_sampler is None), + drop_last=drop_last, + num_workers=args.num_workers, + pin_memory=args.pin_memory, ) - try: - drop_last = test_set.drop_last - except AttributeError as e: # pylint: disable=W0612 - drop_last = False - test_loader = torch_geometric.dataloader.DataLoader( - test_set, - batch_size=args.valid_batch_size, - shuffle=(test_sampler is None), - drop_last=drop_last, - num_workers=args.num_workers, - pin_memory=args.pin_memory, - ) - all_data_loaders[test_name] = test_loader - - train_valid_data_loader = { - k: v for k, v in all_data_loaders.items() if k in ["train", "valid"] - } - test_data_loader = { - k: v for k, v in all_data_loaders.items() if k not in ["train", "valid"] - } + test_data_loader[test_name] = test_loader + if stop_first_test: + break for swa_eval in swas: epoch = checkpoint_handler.load_latest( @@ -846,8 +695,7 @@ def run(args: argparse.Namespace) -> None: for param in model.parameters(): param.requires_grad = False - - table_train = create_error_table( + table_train_valid = create_error_table( table_type=args.error_table, all_data_loaders=train_valid_data_loader, model=model_to_evaluate, @@ -857,18 +705,20 @@ def run(args: argparse.Namespace) -> None: device=device, distributed=args.distributed, ) - table_test = create_error_table( - table_type=args.error_table, - all_data_loaders=test_data_loader, - model=model_to_evaluate, - loss_fn=loss_fn, - output_args=output_args, - log_wandb=args.wandb, - device=device, - distributed=args.distributed, - ) - logging.info("Error-table on TRAIN and VALID:\n" + str(table_train)) - logging.info("Error-table on TEST:\n" + str(table_test)) + logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) + + if test_data_loader: + table_test = create_error_table( + table_type=args.error_table, + all_data_loaders=test_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + logging.info("Error-table on TEST:\n" + str(table_test)) if rank == 0: # Save entire model diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index edb91b14..cb4edd94 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -52,6 +52,7 @@ def __init__( unit_shifts: torch.Tensor, # [n_edges, 3] cell: Optional[torch.Tensor], # [3,3] weight: Optional[torch.Tensor], # [,] + head: Optional[torch.Tensor], # [,] energy_weight: Optional[torch.Tensor], # [,] forces_weight: Optional[torch.Tensor], # [,] stress_weight: Optional[torch.Tensor], # [,] @@ -72,6 +73,7 @@ def __init__( assert unit_shifts.shape[1] == 3 assert len(node_attrs.shape) == 2 assert weight is None or len(weight.shape) == 0 + assert head is None or len(head.shape) == 0 assert energy_weight is None or len(energy_weight.shape) == 0 assert forces_weight is None or len(forces_weight.shape) == 0 assert stress_weight is None or len(stress_weight.shape) == 0 @@ -93,6 +95,7 @@ def __init__( "cell": cell, "node_attrs": node_attrs, "weight": weight, + "head": head, "energy_weight": energy_weight, "forces_weight": forces_weight, "stress_weight": stress_weight, @@ -108,9 +111,15 @@ def __init__( @classmethod def from_config( - cls, config: Configuration, z_table: AtomicNumberTable, cutoff: float + cls, + config: Configuration, + z_table: AtomicNumberTable, + cutoff: float, + heads: Optional[list] = None, ) -> "AtomicData": - edge_index, shifts, unit_shifts = get_neighborhood( + if heads is None: + heads = ["default"] + edge_index, shifts, unit_shifts, cell = get_neighborhood( positions=config.positions, cutoff=cutoff, pbc=config.pbc, cell=config.cell ) indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) @@ -118,10 +127,14 @@ def from_config( torch.tensor(indices, dtype=torch.long).unsqueeze(-1), num_classes=len(z_table), ) + try: + head = torch.tensor(heads.index(config.head), dtype=torch.long) + except ValueError: + head = torch.tensor(len(heads) - 1, dtype=torch.long) cell = ( - torch.tensor(config.cell, dtype=torch.get_default_dtype()) - if config.cell is not None + torch.tensor(cell, dtype=torch.get_default_dtype()) + if cell is not None else torch.tensor( 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() ).view(3, 3) @@ -200,6 +213,7 @@ def from_config( cell=cell, node_attrs=one_hot, weight=weight, + head=head, energy_weight=energy_weight, forces_weight=forces_weight, stress_weight=stress_weight, diff --git a/mace/data/hdf5_dataset.py b/mace/data/hdf5_dataset.py index 5057fd7f..477ccd3f 100644 --- a/mace/data/hdf5_dataset.py +++ b/mace/data/hdf5_dataset.py @@ -66,17 +66,24 @@ def __getitem__(self, index): pbc=unpack_value(subgrp["pbc"][()]), cell=unpack_value(subgrp["cell"][()]), ) + if config.head is None: + config.head = self.kwargs.get("head") atomic_data = AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.kwargs.get("heads", ["Default"]), ) return atomic_data -def dataset_from_sharded_hdf5(files: List, z_table: AtomicNumberTable, r_max: float): +def dataset_from_sharded_hdf5( + files: List, z_table: AtomicNumberTable, r_max: float, **kwargs +): files = glob(files + "/*") datasets = [] for file in files: - datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max)) + datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max, **kwargs)) full_dataset = ConcatDataset(datasets) return full_dataset diff --git a/mace/data/neighborhood.py b/mace/data/neighborhood.py index 293576af..21296fa6 100644 --- a/mace/data/neighborhood.py +++ b/mace/data/neighborhood.py @@ -27,18 +27,18 @@ def get_neighborhood( max_positions = np.max(np.absolute(positions)) + 1 # Extend cell in non-periodic directions # For models with more than 5 layers, the multiplicative constant needs to be increased. - temp_cell = np.copy(cell) + # temp_cell = np.copy(cell) if not pbc_x: - temp_cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + cell[0, :] = max_positions * 5 * cutoff * identity[0, :] if not pbc_y: - temp_cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + cell[1, :] = max_positions * 5 * cutoff * identity[1, :] if not pbc_z: - temp_cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + cell[2, :] = max_positions * 5 * cutoff * identity[2, :] sender, receiver, unit_shifts = neighbour_list( quantities="ijS", pbc=pbc, - cell=temp_cell, + cell=cell, positions=positions, cutoff=cutoff, # self_interaction=True, # we want edges from atom to itself in different periodic images @@ -63,4 +63,4 @@ def get_neighborhood( # D = positions[j]-positions[i]+S.dot(cell) shifts = np.dot(unit_shifts, cell) # [n_edges, 3] - return edge_index, shifts, unit_shifts + return edge_index, shifts, unit_shifts, cell diff --git a/mace/data/utils.py b/mace/data/utils.py index 78e3e76f..bb8e5448 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -47,6 +47,7 @@ class Configuration: stress_weight: float = 1.0 # weight of config stress in loss virials_weight: float = 1.0 # weight of config virial in loss config_type: Optional[str] = DEFAULT_CONFIG_TYPE # config_type of config + head: Optional[str] = "Default" # head used to compute the config Configurations = List[Configuration] @@ -91,7 +92,8 @@ def config_from_atoms_list( virials_key="REF_virials", dipole_key="REF_dipole", charges_key="REF_charges", - config_type_weights: Dict[str, float] = None, + head_key="head", + config_type_weights: Optional[Dict[str, float]] = None, ) -> Configurations: """Convert list of ase.Atoms into Configurations""" if config_type_weights is None: @@ -108,6 +110,7 @@ def config_from_atoms_list( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + head_key=head_key, config_type_weights=config_type_weights, ) ) @@ -122,7 +125,8 @@ def config_from_atoms( virials_key="REF_virials", dipole_key="REF_dipole", charges_key="REF_charges", - config_type_weights: Dict[str, float] = None, + head_key="head", + config_type_weights: Optional[Dict[str, float]] = None, ) -> Configuration: """Convert ase.Atoms to Configuration""" if config_type_weights is None: @@ -149,6 +153,8 @@ def config_from_atoms( stress_weight = atoms.info.get("config_stress_weight", 1.0) virials_weight = atoms.info.get("config_virials_weight", 1.0) + head = atoms.info.get(head_key, "Default") + # fill in missing quantities but set their weight to 0.0 if energy is None: energy = 0.0 @@ -176,6 +182,7 @@ def config_from_atoms( dipole=dipole, charges=charges, weight=weight, + head=head, energy_weight=energy_weight, forces_weight=forces_weight, stress_weight=stress_weight, @@ -193,11 +200,12 @@ def test_config_types( test_by_ct = [] all_cts = [] for conf in test_configs: - if conf.config_type not in all_cts: - all_cts.append(conf.config_type) - test_by_ct.append((conf.config_type, [conf])) + config_type_name = conf.config_type + "_" + conf.head + if config_type_name not in all_cts: + all_cts.append(config_type_name) + test_by_ct.append((config_type_name, [conf])) else: - ind = all_cts.index(conf.config_type) + ind = all_cts.index(config_type_name) test_by_ct[ind][1].append(conf) return test_by_ct @@ -211,6 +219,8 @@ def load_from_xyz( virials_key: str = "REF_virials", dipole_key: str = "REF_dipole", charges_key: str = "REF_charges", + head_key: str = "head", + head_name: str = "Default", extract_atomic_energies: bool = False, keep_isolated_atoms: bool = False, ) -> Tuple[Dict[int, float], Configurations]: @@ -255,6 +265,7 @@ def load_from_xyz( atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): + atoms.info[head_key] = head_name isolated_atom_config = ( len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" ) @@ -286,6 +297,7 @@ def load_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + head_key=head_key, ) return atomic_energies_dict, configs @@ -342,6 +354,7 @@ def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: grp["virials"] = data.virials grp["dipole"] = data.dipole grp["charges"] = data.charges + grp["head"] = data.head def save_AtomicData_to_HDF5(data, i, h5_file) -> None: @@ -364,6 +377,7 @@ def save_AtomicData_to_HDF5(data, i, h5_file) -> None: grp["virials"] = data.virials grp["dipole"] = data.dipole grp["charges"] = data.charges + grp["head"] = data.head def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: @@ -377,6 +391,7 @@ def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> N subgroup["forces"] = write_value(config.forces) subgroup["stress"] = write_value(config.stress) subgroup["virials"] = write_value(config.virials) + subgroup["head"] = write_value(config.head) subgroup["dipole"] = write_value(config.dipole) subgroup["charges"] = write_value(config.charges) subgroup["cell"] = write_value(config.cell) diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index e8645a8e..34539b0b 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -17,6 +17,7 @@ from .irreps_tools import ( linear_out_irreps, + mask_head, reshape_irreps, tp_out_irreps_with_instructions, ) @@ -46,11 +47,15 @@ def forward( @compile_mode("script") class LinearReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps): + def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")): super().__init__() - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=o3.Irreps("0e")) + self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out) - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + def forward( + self, + x: torch.Tensor, + heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] return self.linear(x) # [n_nodes, 1] @@ -58,19 +63,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [... @compile_mode("script") class NonLinearReadoutBlock(torch.nn.Module): def __init__( - self, irreps_in: o3.Irreps, MLP_irreps: o3.Irreps, gate: Optional[Callable] + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Optional[Callable], + irrep_out: o3.Irreps = o3.Irreps("0e"), + num_heads: int = 1, ): super().__init__() self.hidden_irreps = MLP_irreps + self.num_heads = num_heads self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) - self.linear_2 = o3.Linear( - irreps_in=self.hidden_irreps, irreps_out=o3.Irreps("0e") - ) + self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out) - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + def forward( + self, x: torch.Tensor, heads: Optional[torch.Tensor] = None + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] x = self.non_linearity(self.linear_1(x)) - return self.linear_2(x) # [n_nodes, 1] + if hasattr(self, "num_heads"): + if self.num_heads > 1 and heads is not None: + x = mask_head(x, heads, self.num_heads) + return self.linear_2(x) # [n_nodes, len(heads)] @compile_mode("script") @@ -133,20 +147,25 @@ class AtomicEnergiesBlock(torch.nn.Module): def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): super().__init__() - assert len(atomic_energies.shape) == 1 + # assert len(atomic_energies.shape) == 1 self.register_buffer( "atomic_energies", torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), - ) # [n_elements, ] + ) # [n_elements, n_heads] def forward( self, x: torch.Tensor # one-hot of elements [..., n_elements] ) -> torch.Tensor: # [..., ] - return torch.matmul(x, self.atomic_energies) + return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) def __repr__(self): - formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) + formatted_energies = ", ".join( + [ + "[" + ", ".join([f"{x:.4f}" for x in group]) + "]" + for group in torch.atleast_2d(self.atomic_energies) + ] + ) return f"{self.__class__.__name__}(energies=[{formatted_energies}])" @@ -602,7 +621,7 @@ def _setup(self) -> None: input_dim = self.edge_feats_irreps.num_irreps self.conv_tp_weights = nn.FullyConnectedNet( [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, + torch.nn.functional.silu, # gate ) # Linear @@ -743,16 +762,28 @@ class ScaleShiftBlock(torch.nn.Module): def __init__(self, scale: float, shift: float): super().__init__() self.register_buffer( - "scale", torch.tensor(scale, dtype=torch.get_default_dtype()) + "scale", + torch.tensor(scale, dtype=torch.get_default_dtype()), ) self.register_buffer( - "shift", torch.tensor(shift, dtype=torch.get_default_dtype()) + "shift", + torch.tensor(shift, dtype=torch.get_default_dtype()), ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.scale * x + self.shift + def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor: + return ( + torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head] + ) def __repr__(self): - return ( - f"{self.__class__.__name__}(scale={self.scale:.6f}, shift={self.shift:.6f})" + formatted_scale = ( + ", ".join([f"{x:.4f}" for x in self.scale]) + if self.scale.numel() > 1 + else f"{self.scale.item():.4f}" + ) + formatted_shift = ( + ", ".join([f"{x:.4f}" for x in self.shift]) + if self.shift.numel() > 1 + else f"{self.shift.item():.4f}" ) + return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})" diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index 642f3fa8..b0960193 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -84,3 +84,11 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: field = field.reshape(batch, mul, d) out.append(field) return torch.cat(out, dim=-1) + + +def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: + mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) + idx = torch.arange(mask.shape[0], device=x.device) + mask[idx, :, head] = 1 + mask = mask.permute(0, 2, 1).reshape(x.shape) + return x * mask diff --git a/mace/modules/loss.py b/mace/modules/loss.py index b3421ef5..a7e28c55 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -114,34 +114,34 @@ def conditional_mse_forces(ref: Batch, pred: TensorDict) -> torch.Tensor: def conditional_huber_forces( - ref: Batch, pred: TensorDict, huber_delta: float + ref_forces: Batch, pred_forces: TensorDict, huber_delta: float ) -> torch.Tensor: # Define the multiplication factors for each condition factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1]) # Apply multiplication factors based on conditions - c1 = torch.norm(ref["forces"], dim=-1) < 100 - c2 = (torch.norm(ref["forces"], dim=-1) >= 100) & ( - torch.norm(ref["forces"], dim=-1) < 200 + c1 = torch.norm(ref_forces, dim=-1) < 100 + c2 = (torch.norm(ref_forces, dim=-1) >= 100) & ( + torch.norm(ref_forces, dim=-1) < 200 ) - c3 = (torch.norm(ref["forces"], dim=-1) >= 200) & ( - torch.norm(ref["forces"], dim=-1) < 300 + c3 = (torch.norm(ref_forces, dim=-1) >= 200) & ( + torch.norm(ref_forces, dim=-1) < 300 ) c4 = ~(c1 | c2 | c3) - se = torch.zeros_like(pred["forces"]) + se = torch.zeros_like(pred_forces) se[c1] = torch.nn.functional.huber_loss( - ref["forces"][c1], pred["forces"][c1], reduction="none", delta=factors[0] + ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0] ) se[c2] = torch.nn.functional.huber_loss( - ref["forces"][c2], pred["forces"][c2], reduction="none", delta=factors[1] + ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1] ) se[c3] = torch.nn.functional.huber_loss( - ref["forces"][c3], pred["forces"][c3], reduction="none", delta=factors[2] + ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2] ) se[c4] = torch.nn.functional.huber_loss( - ref["forces"][c4], pred["forces"][c4], reduction="none", delta=factors[3] + ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3] ) return torch.mean(se) @@ -273,12 +273,28 @@ def __init__( def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: num_atoms = ref.ptr[1:] - ref.ptr[:-1] + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ] + configs_energy_weight = ref.energy_weight # [n_graphs, ] + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) return ( self.energy_weight - * self.huber_loss(ref["energy"] / num_atoms, pred["energy"] / num_atoms) + * self.huber_loss( + configs_energy_weight * ref["energy"] / num_atoms, + configs_energy_weight * pred["energy"] / num_atoms, + ) + self.forces_weight - * conditional_huber_forces(ref, pred, huber_delta=self.huber_delta) - + self.stress_weight * self.huber_loss(ref["stress"], pred["stress"]) + * conditional_huber_forces( + configs_forces_weight * ref["forces"], + configs_forces_weight * pred["forces"], + huber_delta=self.huber_delta, + ) + + self.stress_weight + * self.huber_loss( + configs_stress_weight * ref["stress"], + configs_stress_weight * pred["stress"], + ) ) def __repr__(self): diff --git a/mace/modules/models.py b/mace/modules/models.py index 3e5cb662..c0d8ab43 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -61,6 +61,7 @@ def __init__( distance_transform: str = "None", radial_MLP: Optional[List[int]] = None, radial_type: Optional[str] = "bessel", + heads: Optional[List[str]] = None, ): super().__init__() self.register_buffer( @@ -72,6 +73,9 @@ def __init__( self.register_buffer( "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) ) + if heads is None: + heads = ["default"] + self.heads = heads if isinstance(correlation, int): correlation = [correlation] * num_interactions # Embedding @@ -131,7 +135,9 @@ def __init__( self.products = torch.nn.ModuleList([prod]) self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearReadoutBlock(hidden_irreps)) + self.readouts.append( + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) + ) for i in range(num_interactions - 1): if i == num_interactions - 2: @@ -161,10 +167,18 @@ def __init__( self.products.append(prod) if i == num_interactions - 2: self.readouts.append( - NonLinearReadoutBlock(hidden_irreps_out, MLP_irreps, gate) + NonLinearReadoutBlock( + hidden_irreps_out, + (len(heads) * MLP_irreps).simplify(), + gate, + o3.Irreps(f"{len(heads)}x0e"), + len(heads), + ) ) else: - self.readouts.append(LinearReadoutBlock(hidden_irreps)) + self.readouts.append( + LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) + ) def forward( self, @@ -179,7 +193,13 @@ def forward( # Setup data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) + num_atoms_arange = torch.arange(data["positions"].shape[0]) num_graphs = data["ptr"].numel() - 1 + node_heads = ( + data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -200,10 +220,12 @@ def forward( ) # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, node_heads + ] e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, n_heads] # Embeddings node_feats = self.node_embedding(data["node_attrs"]) vectors, lengths = get_edge_vectors_and_lengths( @@ -246,13 +268,17 @@ def forward( node_attrs=data["node_attrs"], ) node_feats_list.append(node_feats) - node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + node_energies = readout(node_feats, node_heads)[ + num_atoms_arange, node_heads + ] # [n_nodes, len(heads)] energy = scatter_sum( - src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs + src=node_energies, + index=data["batch"], + dim=0, + dim_size=num_graphs, ) # [n_graphs,] energies.append(energy) node_energies_list.append(node_energies) - # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) @@ -315,6 +341,12 @@ def forward( data["positions"].requires_grad_(True) data["node_attrs"].requires_grad_(True) num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) + node_heads = ( + data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -335,10 +367,12 @@ def forward( ) # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, node_heads + ] e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, num_heads] # Embeddings node_feats = self.node_embedding(data["node_attrs"]) @@ -374,15 +408,17 @@ def forward( node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] ) node_feats_list.append(node_feats) - node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } + node_es_list.append( + readout(node_feats, node_heads)[num_atoms_arange, node_heads] + ) # {[n_nodes, ], } + # Concatenate node features node_feats_out = torch.cat(node_feats_list, dim=-1) - # print("node_es_list", node_es_list) # Sum over interactions node_inter_es = torch.sum( torch.stack(node_es_list, dim=0), dim=0 ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) + node_inter_es = self.scale_shift(node_inter_es, node_heads) # Sum over nodes in graph inter_e = scatter_sum( @@ -494,12 +530,15 @@ def __init__( def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: # Setup data.positions.requires_grad = True + num_atoms_arange = torch.arange(data.positions.shape[0]) # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["head"][data["batch"]] + ] e0 = scatter_sum( src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs - ) # [n_graphs,] + ) # [n_graphs, n_heads] # Embeddings node_feats = self.node_embedding(data.node_attrs) @@ -557,9 +596,11 @@ def __init__( def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: # Setup data.positions.requires_grad = True - + num_atoms_arange = torch.arange(data.positions.shape[0]) # Atomic energies - node_e0 = self.atomic_energies_fn(data.node_attrs) + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["head"][data["batch"]] + ] e0 = scatter_sum( src=node_e0, index=data.batch, dim=-1, dim_size=data.num_graphs ) # [n_graphs,] @@ -591,7 +632,7 @@ def forward(self, data: AtomicData, training=False) -> Dict[str, Any]: node_inter_es = torch.sum( torch.stack(node_es_list, dim=0), dim=0 ) # [n_nodes, ] - node_inter_es = self.scale_shift(node_inter_es) + node_inter_es = self.scale_shift(node_inter_es, data["head"][data["batch"]]) # Sum over nodes in graph inter_e = scatter_sum( @@ -951,6 +992,7 @@ def forward( data["node_attrs"].requires_grad_(True) data["positions"].requires_grad_(True) num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) displacement = torch.zeros( (num_graphs, 3, 3), dtype=data["positions"].dtype, @@ -971,7 +1013,9 @@ def forward( ) # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"]) + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["head"][data["batch"]] + ] e0 = scatter_sum( src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs ) # [n_graphs,] diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 37fef1bb..5f08c819 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -14,7 +14,7 @@ from scipy.constants import c, e from mace.tools import to_numpy -from mace.tools.scatter import scatter_sum +from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum from mace.tools.torch_geometric.batch import Batch from .blocks import AtomicEnergiesBlock @@ -144,7 +144,6 @@ def compute_hessians_loop( forces: torch.Tensor, positions: torch.Tensor, ) -> torch.Tensor: - hessian = [] for grad_elem in forces.view(-1): hess_row = torch.autograd.grad( @@ -181,7 +180,6 @@ def get_outputs( Optional[torch.Tensor], ]: if (compute_virials or compute_stress) and displacement is not None: - # forces come for free forces, virials, stress = compute_forces_virials( energy=energy, positions=positions, @@ -229,11 +227,11 @@ def get_edge_vectors_and_lengths( def _check_non_zero(std): - if std == 0.0: + if np.any(std == 0): logging.warning( "Standard deviation of the scaling is zero, Changing to no scaling" ) - std = 1.0 + std[std == 0] = 1 return std @@ -260,20 +258,25 @@ def compute_mean_std_atomic_inter_energy( atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) avg_atom_inter_es_list = [] + head_list = [] for batch in data_loader: node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), batch.head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] avg_atom_inter_es_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } + head_list.append(batch.head) avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] - mean = to_numpy(torch.mean(avg_atom_inter_es)).item() - std = to_numpy(torch.std(avg_atom_inter_es)).item() + head = torch.cat(head_list, dim=0) # [total_n_graphs] + # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() + # std = to_numpy(torch.std(avg_atom_inter_es)).item() + mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) + std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) std = _check_non_zero(std) return mean, std @@ -283,10 +286,11 @@ def _compute_mean_std_atomic_inter_energy( batch: Batch, atomic_energies_fn: AtomicEnergiesBlock, ) -> Tuple[torch.Tensor, torch.Tensor]: + head = batch.head node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energies = (batch.energy - graph_e0s) / graph_sizes return atom_energies @@ -300,23 +304,36 @@ def compute_mean_rms_energy_forces( atom_energy_list = [] forces_list = [] + head_list = [] + head_batch = [] for batch in data_loader: + head = batch.head node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energy_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + head_list.append(head) + head_batch.append(head[batch.batch]) atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - - mean = to_numpy(torch.mean(atom_energies)).item() - rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + head = torch.cat(head_list, dim=0) # [total_n_graphs] + head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + rms = to_numpy( + torch.sqrt( + scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) + ) + ) rms = _check_non_zero(rms) return mean, rms @@ -326,10 +343,11 @@ def _compute_mean_rms_energy_forces( batch: Batch, atomic_energies_fn: AtomicEnergiesBlock, ) -> Tuple[torch.Tensor, torch.Tensor]: + head = batch.head node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } forces = batch.forces # {[n_graphs*n_atoms,3], } @@ -339,7 +357,6 @@ def _compute_mean_rms_energy_forces( def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: num_neighbors = [] - for batch in data_loader: _, receivers = batch.edge_index _, counts = torch.unique(receivers, return_counts=True) @@ -360,17 +377,20 @@ def compute_statistics( atom_energy_list = [] forces_list = [] num_neighbors = [] + head_list = [] for batch in data_loader: + head = batch.head node_e0 = atomic_energies_fn(batch.node_attrs) graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=-1, dim_size=batch.num_graphs - ) + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] graph_sizes = batch.ptr[1:] - batch.ptr[:-1] atom_energy_list.append( (batch.energy - graph_e0s) / graph_sizes ) # {[n_graphs], } forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + head_list.append(head) # {[n_graphs], } _, receivers = batch.edge_index _, counts = torch.unique(receivers, return_counts=True) @@ -378,9 +398,15 @@ def compute_statistics( atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - - mean = to_numpy(torch.mean(atom_energies)).item() - rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + head = torch.cat(head_list, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + # do the mean for each head + # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + rms = to_numpy( + torch.sqrt(scatter_mean(src=torch.square(forces), index=head, dim=0)) + ) avg_num_neighbors = torch.mean( torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py index 54c59455..8ad80243 100644 --- a/mace/tools/__init__.py +++ b/mace/tools/__init__.py @@ -2,7 +2,7 @@ from .arg_parser_tools import check_args from .cg import U_matrix_real from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState -from .finetuning_utils import load_foundations +from .finetuning_utils import load_foundations, load_foundations_elements from .torch_tools import ( TensorDict, cartesian_to_spherical, @@ -28,7 +28,6 @@ compute_rel_rmse, compute_rmse, get_atomic_number_table_from_zs, - get_optimizer, get_tag, setup_logger, ) @@ -46,7 +45,6 @@ "setup_logger", "get_tag", "count_parameters", - "get_optimizer", "MetricsLogger", "get_atomic_number_table_from_zs", "train", @@ -68,5 +66,6 @@ "voigt_to_matrix", "init_wandb", "load_foundations", + "load_foundations_elements", "build_preprocess_arg_parser", ] diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 2b0e2b56..046f04d6 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -60,7 +60,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--device", help="select device", type=str, - choices=["cpu", "cuda", "mps"], + choices=["cpu", "cuda", "mps", "xpu"], default="cpu", ) parser.add_argument( @@ -228,19 +228,19 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--compute_avg_num_neighbors", help="normalization factor for the message", - type=bool, + type=str2bool, default=True, ) parser.add_argument( "--compute_stress", help="Select True to compute stress", - type=bool, + type=str2bool, default=False, ) parser.add_argument( "--compute_forces", help="Select True to compute forces", - type=bool, + type=str2bool, default=True, ) @@ -249,7 +249,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--train_file", help="Training set file, format is .xyz or .h5", type=str, - required=True, + required=False, ) parser.add_argument( "--valid_file", @@ -280,7 +280,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--multi_processed_test", help="Boolean value for whether the test data was multiprocessed", - type=bool, + type=str2bool, default=False, required=False, ) @@ -294,7 +294,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--pin_memory", help="Pin memory for data loading", default=True, - type=bool, + type=str2bool, ) parser.add_argument( "--atomic_numbers", @@ -331,12 +331,66 @@ def build_default_arg_parser() -> argparse.ArgumentParser: default=None, required=False, ) + + # Fine-tuning + parser.add_argument( + "--foundation_filter_elements", + help="Filter element during fine-tuning", + type=str2bool, + default=True, + required=False, + ) + parser.add_argument( + "--heads", + help="Dict of heads: containing individual files and E0s", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--multiheads_finetuning", + help="Boolean value for whether the model is multiheaded", + type=str2bool, + default=True, + ) + parser.add_argument( + "--weight_pt_head", + help="Weight of the pretrained head in the loss function", + type=float, + default=1.0, + ) + parser.add_argument( + "--num_samples_pt", + help="Number of samples in the pretrained head", + type=int, + default=1000, + ) + parser.add_argument( + "--subselect_pt", + help="Method to subselect the configurations of the pretraining set", + choices=["fps", "random"], + default="random", + ) + parser.add_argument( + "--pt_train_file", + help="Training set file for the pretrained head", + type=str, + default=None, + ) + parser.add_argument( + "--pt_valid_file", + help="Validation set file for the pretrained head", + type=str, + default=None, + ) parser.add_argument( "--keep_isolated_atoms", help="Keep isolated atoms in the dataset, useful for transfer learning", - type=bool, + type=str2bool, default=False, ) + + # Keys parser.add_argument( "--energy_key", help="Key of reference energies in training xyz", @@ -769,7 +823,7 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--shuffle", help="Shuffle the training dataset", - type=bool, + type=str2bool, default=True, ) parser.add_argument( @@ -790,3 +844,13 @@ def check_float_or_none(value: str) -> Optional[float]: f"{value} is an invalid value (float or None)" ) from None return None + + +def str2bool(value): + if isinstance(value, bool): + return value + if value.lower() in ("yes", "true", "t", "y", "1"): + return True + if value.lower() in ("no", "false", "f", "n", "0"): + return False + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py index 0aad091b..0d4e2f52 100644 --- a/mace/tools/finetuning_utils.py +++ b/mace/tools/finetuning_utils.py @@ -3,12 +3,12 @@ from mace.tools.utils import AtomicNumberTable -def load_foundations( +def load_foundations_elements( model: torch.nn.Module, model_foundations: torch.nn.Module, table: AtomicNumberTable, load_readout=False, - use_shift=False, + use_shift=True, use_scale=True, max_L=2, ): @@ -17,6 +17,7 @@ def load_foundations( """ assert model_foundations.r_max == model.r_max z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) + model_heads = model.heads new_z_table = table num_species_foundations = len(z_table.zs) num_channels_foundation = ( @@ -39,7 +40,6 @@ def load_foundations( model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() ) - for i in range(int(model.num_interactions)): model.interactions[i].linear_up.weight = torch.nn.Parameter( model_foundations.interactions[i].linear_up.weight.clone() @@ -101,6 +101,7 @@ def load_foundations( .clone() / (num_species_foundations / num_species) ** 0.5 ) + # Transferring products for i in range(2): # Assuming 2 products modules max_range = max_L + 1 if i == 0 else 1 @@ -130,20 +131,62 @@ def load_foundations( if load_readout: # Transferring readouts + model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight = ( + model_foundations.readouts[0] + .linear.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_heads)) + .flatten() + .clone() + ) model.readouts[0].linear.weight = torch.nn.Parameter( - model_foundations.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight ) + shape_input_1 = ( + model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + ) + shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight = ( + model_foundations.readouts[1] + .linear_1.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_heads)) + .flatten() + .clone() + ) model.readouts[1].linear_1.weight = torch.nn.Parameter( - model_foundations.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight + ) + model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight = model_foundations.readouts[ + 1 + ].linear_2.weight.view(shape_input_1, -1).repeat( + len(model_heads), len(model_heads) + ).flatten().clone() / ( + ((shape_input_1) / (shape_output_1)) ** 0.5 ) - model.readouts[1].linear_2.weight = torch.nn.Parameter( - model_foundations.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight ) if model_foundations.scale_shift is not None: if use_scale: - model.scale_shift.scale = model_foundations.scale_shift.scale.clone() + model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( + len(model_heads) + ).clone() if use_shift: - model.scale_shift.shift = model_foundations.scale_shift.shift.clone() + model.scale_shift.shift = model_foundations.scale_shift.shift.repeat( + len(model_heads) + ).clone() + return model + + +def load_foundations( + model, + model_foundations, +): + for name, param in model_foundations.named_parameters(): + if name in model.state_dict().keys(): + if "readouts" not in name: + model.state_dict()[name].copy_(param) return model diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py new file mode 100644 index 00000000..8e8c2877 --- /dev/null +++ b/mace/tools/model_script_utils.py @@ -0,0 +1,229 @@ +import ast +import logging + +import numpy as np +from e3nn import o3 + +from mace import modules +from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.scripts_utils import extract_config_mace_model + + +def configure_model( + args, train_loader, atomic_energies, model_foundation=None, heads=None, z_table=None +): + # Selecting outputs + compute_virials = args.loss in ("stress", "virials", "huber", "universal") + if compute_virials: + args.compute_stress = True + args.error_table = "PerAtomRMSEstressvirials" + + output_args = { + "energy": args.compute_energy, + "forces": args.compute_forces, + "virials": compute_virials, + "stress": args.compute_stress, + "dipoles": args.compute_dipole, + } + logging.info( + f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" + ) + logging.info("===========MODEL DETAILS===========") + + if args.scaling == "no_scaling": + args.std = 1.0 + logging.info("No scaling selected") + elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": + args.mean, args.std = modules.scaling_classes[args.scaling]( + train_loader, atomic_energies + ) + + # Build model + if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]: + logging.info("Loading FOUNDATION model") + model_config_foundation = extract_config_mace_model(model_foundation) + model_config_foundation["atomic_energies"] = atomic_energies + model_config_foundation["atomic_numbers"] = z_table.zs + model_config_foundation["num_elements"] = len(z_table) + args.max_L = model_config_foundation["hidden_irreps"].lmax + + if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": + model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) + else: + model_config_foundation["atomic_inter_shift"] = ( + _determine_atomic_inter_shift(args.mean, heads) + ) + + model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) + args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] + args.model = "FoundationMACE" + model_config_foundation["heads"] = heads + model_config = model_config_foundation + + logging.info("Model configuration extracted from foundation model") + logging.info("Using universal loss function for fine-tuning") + logging.info( + f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})" + ) + logging.info( + f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" + ) + logging.info( + f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)" + ) + logging.info( + f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" + ) + else: + logging.info("Building model") + logging.info( + f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" + ) + logging.info( + f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" + ) + logging.info( + f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" + ) + logging.info( + f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)" + ) + logging.info( + f"Distance transform for radial basis functions: {args.distance_transform}" + ) + + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + logging.info(f"Hidden irreps: {args.hidden_irreps}") + + model_config = dict( + r_max=args.r_max, + num_bessel=args.num_radial_basis, + num_polynomial_cutoff=args.num_cutoff_basis, + max_ell=args.max_ell, + interaction_cls=modules.interaction_classes[args.interaction], + num_interactions=args.num_interactions, + num_elements=len(z_table), + hidden_irreps=o3.Irreps(args.hidden_irreps), + atomic_energies=atomic_energies, + avg_num_neighbors=args.avg_num_neighbors, + atomic_numbers=z_table.zs, + ) + model_config_foundation = None + + model = _build_model(args, model_config, model_config_foundation, heads) + + if model_foundation is not None: + model = load_foundations_elements( + model, + model_foundation, + z_table, + load_readout=args.foundation_filter_elements, + max_L=args.max_L, + ) + + return model, output_args + + +def _determine_atomic_inter_shift(mean, heads): + if isinstance(mean, np.ndarray): + if mean.size == 1: + return mean.item() + if mean.size == len(heads): + return mean.tolist() + logging.info("Mean not in correct format, using default value of 0.0") + return [0.0] * len(heads) + if isinstance(mean, list) and len(mean) == len(heads): + return mean + if isinstance(mean, float): + return [mean] * len(heads) + logging.info("Mean not in correct format, using default value of 0.0") + return [0.0] * len(heads) + + +def _build_model( + args, model_config, model_config_foundation, heads +): # pylint: disable=too-many-return-statements + if args.model == "MACE": + return modules.ScaleShiftMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=[0.0] * len(heads), + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) + if args.model == "ScaleShiftMACE": + return modules.ScaleShiftMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) + if args.model == "FoundationMACE": + return modules.ScaleShiftMACE(**model_config_foundation) + if args.model == "ScaleShiftBOTNet": + return modules.ScaleShiftBOTNet( + **model_config, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, + ) + if args.model == "BOTNet": + return modules.BOTNet( + **model_config, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + if args.model == "AtomicDipolesMACE": + assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model" + assert ( + args.error_table == "DipoleRMSE" + ), "Use error_table DipoleRMSE with AtomicDipolesMACE model" + return modules.AtomicDipolesMACE( + **model_config, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + if args.model == "EnergyDipolesMACE": + assert ( + args.loss == "energy_forces_dipole" + ), "Use energy_forces_dipole loss with EnergyDipolesMACE model" + assert ( + args.error_table == "EnergyDipoleRMSE" + ), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model" + return modules.EnergyDipolesMACE( + **model_config, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + raise RuntimeError(f"Unknown model: '{args.model}'") diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py new file mode 100644 index 00000000..ffde107f --- /dev/null +++ b/mace/tools/multihead_tools.py @@ -0,0 +1,185 @@ +import argparse +import dataclasses +import logging +import os +import urllib.request +from typing import Any, Dict, List, Optional, Union + +import torch + +from mace.cli.fine_tuning_select import select_samples +from mace.tools.scripts_utils import ( + SubsetCollection, + dict_to_namespace, + get_dataset_from_xyz, +) + + +@dataclasses.dataclass +class HeadConfig: + head_name: str + train_file: Optional[str] = None + valid_file: Optional[str] = None + test_file: Optional[str] = None + test_dir: Optional[str] = None + E0s: Optional[Any] = None + statistics_file: Optional[str] = None + valid_fraction: Optional[float] = None + config_type_weights: Optional[Dict[str, float]] = None + energy_key: Optional[str] = None + forces_key: Optional[str] = None + stress_key: Optional[str] = None + virials_key: Optional[str] = None + dipole_key: Optional[str] = None + charges_key: Optional[str] = None + keep_isolated_atoms: Optional[bool] = None + atomic_numbers: Optional[Union[List[int], List[str]]] = None + mean: Optional[float] = None + std: Optional[float] = None + avg_num_neighbors: Optional[float] = None + compute_avg_num_neighbors: Optional[bool] = None + collections: Optional[SubsetCollection] = None + train_loader: torch.utils.data.DataLoader = None + z_table: Optional[Any] = None + atomic_energies_dict: Optional[Dict[str, float]] = None + + +def dict_head_to_dataclass( + head: Dict[str, Any], head_name: str, args: argparse.Namespace +) -> HeadConfig: + + return HeadConfig( + head_name=head_name, + train_file=head.get("train_file", args.train_file), + valid_file=head.get("valid_file", args.valid_file), + test_file=head.get("test_file", None), + test_dir=head.get("test_dir", None), + E0s=head.get("E0s", args.E0s), + statistics_file=head.get("statistics_file", args.statistics_file), + valid_fraction=head.get("valid_fraction", args.valid_fraction), + config_type_weights=head.get("config_type_weights", args.config_type_weights), + compute_avg_num_neighbors=head.get( + "compute_avg_num_neighbors", args.compute_avg_num_neighbors + ), + atomic_numbers=head.get("atomic_numbers", args.atomic_numbers), + mean=head.get("mean", args.mean), + std=head.get("std", args.std), + avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), + energy_key=head.get("energy_key", args.energy_key), + forces_key=head.get("forces_key", args.forces_key), + stress_key=head.get("stress_key", args.stress_key), + virials_key=head.get("virials_key", args.virials_key), + dipole_key=head.get("dipole_key", args.dipole_key), + charges_key=head.get("charges_key", args.charges_key), + keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), + ) + + +def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: + return { + "default": { + "train_file": args.train_file, + "valid_file": args.valid_file, + "test_file": args.test_file, + "test_dir": args.test_dir, + "E0s": args.E0s, + "statistics_file": args.statistics_file, + "valid_fraction": args.valid_fraction, + "config_type_weights": args.config_type_weights, + "energy_key": args.energy_key, + "forces_key": args.forces_key, + "stress_key": args.stress_key, + "virials_key": args.virials_key, + "dipole_key": args.dipole_key, + "charges_key": args.charges_key, + "keep_isolated_atoms": args.keep_isolated_atoms, + } + } + + +def assemble_mp_data( + args: argparse.Namespace, tag: str, head_configs: List[HeadConfig] +) -> Dict[str, Any]: + try: + checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" + descriptors_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/descriptors.npy" + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" + descriptors_url_name = "".join( + c for c in os.path.basename(descriptors_url) if c.isalnum() or c in "_" + ) + cached_descriptors_path = f"{cache_dir}/{descriptors_url_name}" + if not os.path.isfile(cached_dataset_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP structures for finetuning") + _, http_msg = urllib.request.urlretrieve( + checkpoint_url, cached_dataset_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Dataset download failed, please check the URL {checkpoint_url}" + ) + logging.info(f"Materials Project dataset to {cached_dataset_path}") + if not os.path.isfile(cached_descriptors_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP descriptors for finetuning") + _, http_msg = urllib.request.urlretrieve( + descriptors_url, cached_descriptors_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Descriptors download failed, please check the URL {descriptors_url}" + ) + logging.info(f"Materials Project descriptors to {cached_descriptors_path}") + dataset_mp = cached_dataset_path + descriptors_mp = cached_descriptors_path + msg = f"Using Materials Project dataset with {dataset_mp}" + logging.info(msg) + msg = f"Using Materials Project descriptors with {descriptors_mp}" + logging.info(msg) + config_pt_paths = [head.train_file for head in head_configs] + args_samples = { + "configs_pt": dataset_mp, + "configs_ft": config_pt_paths, + "num_samples": args.num_samples_pt, + "seed": args.seed, + "model": args.foundation_model, + "head_pt": "pbe_mp", + "head_ft": "Default", + "weight_pt": args.weight_pt_head, + "weight_ft": 1.0, + "filtering_type": "combination", + "output": f"mp_finetuning-{tag}.xyz", + "descriptors": descriptors_mp, + "subselect": args.subselect_pt, + "device": args.device, + "default_dtype": args.default_dtype, + } + select_samples(dict_to_namespace(args_samples)) + collections_mp, _ = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=f"mp_finetuning-{tag}.xyz", + valid_path=None, + valid_fraction=args.valid_fraction, + config_type_weights=None, + test_path=None, + seed=args.seed, + energy_key="energy", + forces_key="forces", + stress_key="stress", + head_name="pt_head", + virials_key=args.virials_key, + dipole_key=args.dipole_key, + charges_key=args.charges_key, + keep_isolated_atoms=args.keep_isolated_atoms, + ) + return collections_mp + except Exception as exc: + raise RuntimeError( + "Model or descriptors download failed and no local model found" + ) from exc diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 27455944..f44390a6 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -4,6 +4,7 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### +import argparse import ast import dataclasses import json @@ -16,9 +17,11 @@ import torch.distributed from e3nn import o3 from prettytable import PrettyTable +from torch.optim.swa_utils import SWALR, AveragedModel -from mace import data, modules +from mace import data, modules, tools from mace.tools import evaluate +from mace.tools.train import SWAContainer @dataclasses.dataclass @@ -31,18 +34,20 @@ class SubsetCollection: def get_dataset_from_xyz( work_dir: str, train_path: str, - valid_path: str, + valid_path: Optional[str], valid_fraction: float, config_type_weights: Dict, test_path: str = None, seed: int = 1234, keep_isolated_atoms: bool = False, + head_name: str = "Default", energy_key: str = "REF_energy", forces_key: str = "REF_forces", stress_key: str = "REF_stress", virials_key: str = "virials", dipole_key: str = "dipoles", charges_key: str = "charges", + head_key: str = "head", ) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" atomic_energies_dict, all_train_configs = data.load_from_xyz( @@ -54,8 +59,10 @@ def get_dataset_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + head_key=head_key, extract_atomic_energies=True, keep_isolated_atoms=keep_isolated_atoms, + head_name=head_name, ) logging.info( f"Training set [{len(all_train_configs)} configs, {np.sum([1 if config.energy else 0 for config in all_train_configs])} energy, {np.sum([config.forces.size for config in all_train_configs])} forces] loaded from '{train_path}'" @@ -70,7 +77,9 @@ def get_dataset_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + head_key=head_key, extract_atomic_energies=False, + head_name=head_name, ) logging.info( f"Validation set [{len(valid_configs)} configs, {np.sum([1 if config.energy else 0 for config in valid_configs])} energy, {np.sum([config.forces.size for config in valid_configs])} forces] loaded from '{valid_path}'" @@ -95,7 +104,9 @@ def get_dataset_from_xyz( stress_key=stress_key, virials_key=virials_key, charges_key=charges_key, + head_key=head_key, extract_atomic_energies=False, + head_name=head_name, ) # create list of tuples (config_type, list(Atoms)) test_configs = data.test_config_types(all_test_configs) @@ -163,6 +174,8 @@ def radial_to_transform(radial): return "Soft" return radial.distance_transform.__class__.__name__ + scale = model.scale_shift.scale + shift = model.scale_shift.shift config = { "r_max": model.r_max.item(), "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), @@ -198,8 +211,8 @@ def radial_to_transform(radial): "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], "pair_repulsion": hasattr(model, "pair_repulsion_fn"), "distance_transform": radial_to_transform(model.radial_embedding), - "atomic_inter_scale": model.scale_shift.scale.item(), - "atomic_inter_shift": model.scale_shift.shift.item(), + "atomic_inter_scale": scale.cpu().numpy(), + "atomic_inter_shift": shift.cpu().numpy(), } return config @@ -316,7 +329,14 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: atomic_energies_dict = json.load(f) else: try: - atomic_energies_dict = ast.literal_eval(E0s) + atomic_energies_eval = ast.literal_eval(E0s) + if not all( + isinstance(value, dict) + for value in atomic_energies_eval.values() + ): + atomic_energies_dict = atomic_energies_eval + else: + atomic_energies_dict = atomic_energies_eval assert isinstance(atomic_energies_dict, dict) except Exception as e: raise RuntimeError( @@ -329,55 +349,260 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: return atomic_energies_dict +def get_avg_num_neighbors(head_configs, args, train_loader, device): + if all(head_config.compute_avg_num_neighbors for head_config in head_configs): + logging.info("Computing average number of neighbors") + avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) + if args.distributed: + num_graphs = torch.tensor(len(train_loader.dataset)).to(device) + num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) + torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce( + num_neighbors, op=torch.distributed.ReduceOp.SUM + ) + avg_num_neighbors_out = (num_neighbors / num_graphs).item() + else: + avg_num_neighbors_out = avg_num_neighbors + else: + assert any( + head_config.avg_num_neighbors is not None for head_config in head_configs + ), "Average number of neighbors must be provided in the configuration" + avg_num_neighbors_out = max( + head_config.avg_num_neighbors + for head_config in head_configs + if head_config.avg_num_neighbors is not None + ) + if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100: + logging.warning( + f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}" + ) + else: + logging.info(f"Average number of neighbors: {avg_num_neighbors_out}") + return avg_num_neighbors_out + + def get_loss_fn( - loss: str, - energy_weight: float, - forces_weight: float, - stress_weight: float, - virials_weight: float, - dipole_weight: float, + args: argparse.Namespace, dipole_only: bool, compute_dipole: bool, ) -> torch.nn.Module: - if loss == "weighted": + if args.loss == "weighted": loss_fn = modules.WeightedEnergyForcesLoss( - energy_weight=energy_weight, forces_weight=forces_weight + energy_weight=args.energy_weight, forces_weight=args.forces_weight ) - elif loss == "forces_only": - loss_fn = modules.WeightedForcesLoss(forces_weight=forces_weight) - elif loss == "virials": + elif args.loss == "forces_only": + loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) + elif args.loss == "virials": loss_fn = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - virials_weight=virials_weight, + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + virials_weight=args.virials_weight, ) - elif loss == "stress": + elif args.loss == "stress": loss_fn = modules.WeightedEnergyForcesStressLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, ) - elif loss == "dipole": + elif args.loss == "huber": + loss_fn = modules.WeightedHuberEnergyForcesStressLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "universal": + loss_fn = modules.UniversalLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "dipole": assert ( dipole_only is True ), "dipole loss can only be used with AtomicDipolesMACE model" loss_fn = modules.DipoleSingleLoss( - dipole_weight=dipole_weight, + dipole_weight=args.dipole_weight, ) - elif loss == "energy_forces_dipole": + elif args.loss == "energy_forces_dipole": assert dipole_only is False and compute_dipole is True loss_fn = modules.WeightedEnergyForcesDipoleLoss( - energy_weight=energy_weight, - forces_weight=forces_weight, - dipole_weight=dipole_weight, + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + dipole_weight=args.dipole_weight, ) else: - loss_fn = modules.EnergyForcesLoss( - energy_weight=energy_weight, forces_weight=forces_weight - ) + loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) return loss_fn +def get_swa( + args: argparse.Namespace, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + swas: List[bool], + dipole_only: bool = False, +): + assert dipole_only is False, "Stage Two for dipole fitting not implemented" + swas.append(True) + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + else: + if args.start_swa > args.max_num_epochs: + logging.warning( + f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" + ) + if args.loss == "forces_only": + raise ValueError("Can not select Stage Two with forces only loss.") + if args.loss == "virials": + loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + virials_weight=args.swa_virials_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, virials_weight: {args.swa_virials_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "stress": + loss_fn_energy = modules.WeightedEnergyForcesStressLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.stress_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "energy_forces_dipole": + loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( + args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + dipole_weight=args.swa_dipole_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "universal": + loss_fn_energy = modules.UniversalLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + huber_delta=args.huber_delta, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + ) + else: + loss_fn_energy = modules.WeightedEnergyForcesLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" + ) + swa = SWAContainer( + model=AveragedModel(model), + scheduler=SWALR( + optimizer=optimizer, + swa_lr=args.swa_lr, + anneal_epochs=1, + anneal_strategy="linear", + ), + start=args.start_swa, + loss_fn=loss_fn_energy, + ) + return swa, swas + + +def get_params_options( + args: argparse.Namespace, model: torch.nn.Module +) -> Dict[str, Any]: + decay_interactions = {} + no_decay_interactions = {} + for name, param in model.interactions.named_parameters(): + if "linear.weight" in name or "skip_tp_full.weight" in name: + decay_interactions[name] = param + else: + no_decay_interactions[name] = param + + param_options = dict( + params=[ + { + "name": "embedding", + "params": model.node_embedding.parameters(), + "weight_decay": 0.0, + }, + { + "name": "interactions_decay", + "params": list(decay_interactions.values()), + "weight_decay": args.weight_decay, + }, + { + "name": "interactions_no_decay", + "params": list(no_decay_interactions.values()), + "weight_decay": 0.0, + }, + { + "name": "products", + "params": model.products.parameters(), + "weight_decay": args.weight_decay, + }, + { + "name": "readouts", + "params": model.readouts.parameters(), + "weight_decay": 0.0, + }, + ], + lr=args.lr, + amsgrad=args.amsgrad, + betas=(args.beta, 0.999), + ) + return param_options + + +def get_optimizer( + args: argparse.Namespace, param_options: Dict[str, Any] +) -> torch.optim.Optimizer: + if args.optimizer == "adamw": + optimizer = torch.optim.AdamW(**param_options) + elif args.optimizer == "schedulefree": + try: + from schedulefree import adamw_schedulefree + except ImportError as exc: + raise ImportError( + "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" + ) from exc + _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} + optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) + else: + optimizer = torch.optim.Adam(**param_options) + return optimizer + + +def setup_wandb(args: argparse.Namespace): + logging.info("Using Weights and Biases for logging") + import wandb + + wandb_config = {} + args_dict = vars(args) + + for key, value in args_dict.items(): + if isinstance(value, np.ndarray): + args_dict[key] = value.tolist() + + args_dict_json = json.dumps(args_dict) + for key in args.wandb_log_hypers: + wandb_config[key] = args_dict[key] + tools.init_wandb( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.wandb_name, + config=wandb_config, + directory=args.wandb_dir, + ) + wandb.run.summary["params"] = args_dict_json + + def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: return [ os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) @@ -397,6 +622,23 @@ def custom_key(key): return (2, key) +def dict_to_array(input_data, heads): + if not all(isinstance(value, dict) for value in input_data.values()): + return np.array(list(input_data.values())) + unique_keys = set() + for inner_dict in input_data.values(): + unique_keys.update(inner_dict.keys()) + unique_keys = list(unique_keys) + sorted_keys = sorted([int(key) for key in unique_keys]) + result_array = np.zeros((len(input_data), len(sorted_keys))) + for _, (head_name, inner_dict) in enumerate(input_data.items()): + for key, value in inner_dict.items(): + key_index = sorted_keys.index(int(key)) + head_index = heads.index(head_name) + result_array[head_index][key_index] = value + return np.squeeze(result_array) + + class LRScheduler: def __init__(self, optimizer, args) -> None: self.scheduler = args.scheduler @@ -651,3 +893,20 @@ def create_error_table( ] ) return table + + +def check_folder_subfolder(folder_path): + entries = os.listdir(folder_path) + for entry in entries: + full_path = os.path.join(folder_path, entry) + if os.path.isdir(full_path): + return True + return False + + +def dict_to_namespace(dictionary): + # Convert the dictionary into an argparse.Namespace + namespace = argparse.Namespace() + for key, value in dictionary.items(): + setattr(namespace, key, value) + return namespace diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py index 1ec3ecde..e42a74f8 100644 --- a/mace/tools/torch_tools.py +++ b/mace/tools/torch_tools.py @@ -64,6 +64,9 @@ def init_device(device_str: str) -> torch.device: assert torch.backends.mps.is_available(), "No MPS backend is available!" logging.info("Using MPS GPU acceleration") return torch.device("mps") + if device_str == "xpu": + torch.xpu.is_available() + return torch.device("xpu") logging.info("Using CPU") return torch.device("cpu") diff --git a/mace/tools/train.py b/mace/tools/train.py index b38bce16..1ab86f82 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -41,7 +41,14 @@ class SWAContainer: loss_fn: torch.nn.Module -def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): +def valid_err_log( + valid_loss, + eval_metrics, + logger, + log_errors, + epoch=None, + valid_loader_name="Default", +): eval_metrics["mode"] = "eval" eval_metrics["epoch"] = epoch logger.log(eval_metrics) @@ -53,17 +60,17 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" ) elif ( log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_stress_per_atom"] is not None + and eval_metrics["rmse_stress"] is not None ): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 - error_stress = eval_metrics["rmse_stress_per_atom"] * 1e3 + error_stress = eval_metrics["rmse_stress"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress_per_atom={error_stress:8.1f} meV / A^3", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_stress={error_stress:8.1f} meV / A^3", ) elif ( log_errors == "PerAtomRMSEstressvirials" @@ -73,7 +80,7 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_f = eval_metrics["rmse_f"] * 1e3 error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_virials_per_atom={error_virials:8.1f} meV", ) elif ( log_errors == "PerAtomMAEstressvirials" @@ -99,31 +106,31 @@ def valid_err_log(valid_loss, eval_metrics, logger, log_errors, epoch=None): error_e = eval_metrics["rmse_e"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A", ) elif log_errors == "PerAtomMAE": error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E_per_atom={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", ) elif log_errors == "TotalMAE": error_e = eval_metrics["mae_e"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, MAE_E={error_e:8.1f} meV, MAE_F={error_f:8.1f} meV / A", ) elif log_errors == "DipoleRMSE": error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", ) elif log_errors == "EnergyDipoleRMSE": error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 logging.info( - f"{inintial_phrase}: loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", ) @@ -131,7 +138,7 @@ def train( model: torch.nn.Module, loss_fn: torch.nn.Module, train_loader: DataLoader, - valid_loader: Dict[str, DataLoader], + valid_loaders: Dict[str, DataLoader], optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, start_epoch: int, @@ -170,17 +177,20 @@ def train( logging.info("Loss metrics on validation set") epoch = start_epoch - # # log validation loss before _any_ training - param_context = ema.average_parameters() if ema is not None else nullcontext() - with param_context: - valid_loss, eval_metrics = evaluate( + # log validation loss before _any_ training + valid_loss = 0.0 + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( model=model, loss_fn=loss_fn, data_loader=valid_loader, output_args=output_args, device=device, ) - valid_err_log(valid_loss, eval_metrics, logger, log_errors, None) + valid_err_log( + valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name + ) + valid_loss = valid_loss_head # consider only the last head for the checkpoint while epoch < max_num_epochs: # LR scheduler and SWA update @@ -233,30 +243,40 @@ def train( if "ScheduleFree" in type(optimizer).__name__: optimizer.eval() with param_context: - valid_loss, eval_metrics = evaluate( - model=model_to_evaluate, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, + valid_loss = 0.0 + wandb_log_dict = {} + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( + model=model_to_evaluate, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + if rank == 0: + valid_err_log( + valid_loss_head, + eval_metrics, + logger, + log_errors, + epoch, + valid_loader_name, + ) + if log_wandb: + wandb_log_dict[valid_loader_name] = { + "epoch": epoch, + "valid_loss": valid_loss_head, + "valid_rmse_e_per_atom": eval_metrics[ + "rmse_e_per_atom" + ], + "valid_rmse_f": eval_metrics["rmse_f"], + } + valid_loss = ( + valid_loss_head # consider only the last head for the checkpoint ) + if log_wandb: + wandb.log(wandb_log_dict) if rank == 0: - valid_err_log( - valid_loss, - eval_metrics, - logger, - log_errors, - epoch, - ) - if log_wandb: - wandb_log_dict = { - "epoch": epoch, - "valid_loss": valid_loss, - "valid_rmse_e_per_atom": eval_metrics["rmse_e_per_atom"], - "valid_rmse_f": eval_metrics["rmse_f"], - } - wandb.log(wandb_log_dict) - if valid_loss >= lowest_loss: patience_counter += 1 if patience_counter >= patience and epoch < swa.start: @@ -422,7 +442,6 @@ def __init__(self, loss_fn: torch.nn.Module): "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" ) self.add_state("delta_stress", default=[], dist_reduce_fx="cat") - self.add_state("delta_stress_per_atom", default=[], dist_reduce_fx="cat") self.add_state( "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" ) @@ -451,10 +470,6 @@ def update(self, batch, output): # pylint: disable=arguments-differ if output.get("stress") is not None and batch.stress is not None: self.stress_computed += 1.0 self.delta_stress.append(batch.stress - output["stress"]) - self.delta_stress_per_atom.append( - (batch.stress - output["stress"]) - / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) - ) if output.get("virials") is not None and batch.virials is not None: self.virials_computed += 1.0 self.delta_virials.append(batch.virials - output["virials"]) @@ -497,10 +512,8 @@ def compute(self): aux["q95_f"] = compute_q95(delta_fs) if self.stress_computed: delta_stress = self.convert(self.delta_stress) - delta_stress_per_atom = self.convert(self.delta_stress_per_atom) aux["mae_stress"] = compute_mae(delta_stress) aux["rmse_stress"] = compute_rmse(delta_stress) - aux["rmse_stress_per_atom"] = compute_rmse(delta_stress_per_atom) aux["q95_stress"] = compute_q95(delta_stress) if self.virials_computed: delta_virials = self.convert(self.delta_virials) diff --git a/mace/tools/utils.py b/mace/tools/utils.py index 762d9880..28a77efe 100644 --- a/mace/tools/utils.py +++ b/mace/tools/utils.py @@ -121,26 +121,6 @@ def atomic_numbers_to_indices( return to_index_fn(atomic_numbers) -def get_optimizer( - name: str, - amsgrad: bool, - learning_rate: float, - weight_decay: float, - parameters: Iterable[torch.Tensor], -) -> torch.optim.Optimizer: - if name == "adam": - return torch.optim.Adam( - parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay - ) - - if name == "adamw": - return torch.optim.AdamW( - parameters, lr=learning_rate, amsgrad=amsgrad, weight_decay=weight_decay - ) - - raise RuntimeError(f"Unknown optimizer '{name}'") - - class UniversalEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, np.integer): @@ -161,7 +141,6 @@ def __init__(self, directory: str, tag: str) -> None: self.path = os.path.join(self.directory, self.filename) def log(self, d: Dict[str, Any]) -> None: - logging.debug(f"Saving info: {self.path}") os.makedirs(name=self.directory, exist_ok=True) with open(self.path, mode="a", encoding="utf-8") as f: f.write(json.dumps(d, cls=UniversalEncoder)) diff --git a/pyproject.toml b/pyproject.toml index 05037dc7..489bc6e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ disable = [ "not-callable", "logging-fstring-interpolation", "logging-not-lazy", + "logging-too-many-args", "invalid-name", "too-few-public-methods", "too-many-instance-attributes", diff --git a/setup.cfg b/setup.cfg index 81e0b661..13d55161 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,6 +27,7 @@ install_requires = python-hostlist configargparse GitPython + pyYAML tqdm # for plotting: matplotlib @@ -40,9 +41,11 @@ console_scripts = mace_plot_train = mace.cli.plot_train:main mace_run_train = mace.cli.run_train:main mace_prepare_data = mace.cli.preprocess_data:main + mace_finetuning = mace.cli.fine_tuning_select:main [options.extras_require] wandb = wandb +fpsample = fpsample dev = black isort diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 73019b4a..8ff87936 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -6,6 +6,7 @@ import ase.io import numpy as np import pytest +import torch from ase import build from ase.atoms import Atoms from ase.calculators.test import gradient_test @@ -381,12 +382,12 @@ def test_calculator_node_energy(fitting_configs, trained_model): trained_model.calculate(at) node_energies = trained_model.results["node_energy"] batch = trained_model._atoms_to_batch(at) # pylint: disable=protected-access + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) node_e0 = ( - trained_model.models[0] - .atomic_energies_fn(batch["node_attrs"]) - .detach() - .numpy() + trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach() ) + node_e0 = node_e0[num_atoms_arange, node_heads].cpu().numpy() energy_via_nodes = np.sum(node_energies + node_e0) energy = trained_model.results["energy"] np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6) diff --git a/tests/test_data.py b/tests/test_data.py index e893f03c..9e0c49e6 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -144,7 +144,7 @@ def test_basic(self): ] ) - indices, shifts, unit_shifts = get_neighborhood(positions, cutoff=1.5) + indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5) assert indices.shape == (2, 4) assert shifts.shape == (4, 3) assert unit_shifts.shape == (4, 3) @@ -158,7 +158,7 @@ def test_signs(self): ) cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - edge_index, shifts, unit_shifts = get_neighborhood( + edge_index, shifts, unit_shifts, _ = get_neighborhood( positions, cutoff=3.5, pbc=(True, False, False), cell=cell ) num_edges = 10 @@ -172,7 +172,7 @@ def test_periodic_edge(): atoms = ase.build.bulk("Cu", "fcc") dist = np.linalg.norm(atoms.cell[0]).item() config = config_from_atoms(atoms) - edge_index, shifts, _ = get_neighborhood( + edge_index, shifts, _, _ = get_neighborhood( config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell ) sender, receiver = edge_index @@ -190,7 +190,7 @@ def test_half_periodic(): atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0) assert all(atoms.pbc == (True, True, False)) config = config_from_atoms(atoms) # first shell dist is 2.864A - edge_index, shifts, _ = get_neighborhood( + edge_index, shifts, _, _ = get_neighborhood( config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell ) sender, receiver = edge_index diff --git a/tests/test_foundations.py b/tests/test_foundations.py index 69963c67..b1724629 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -9,7 +9,7 @@ from mace import data, modules, tools from mace.calculators import mace_mp, mace_off from mace.tools import torch_geometric -from mace.tools.finetuning_utils import load_foundations +from mace.tools.finetuning_utils import load_foundations_elements from mace.tools.scripts_utils import extract_config_mace_model from mace.tools.utils import AtomicNumberTable @@ -71,7 +71,7 @@ def test_foundations(): default_dtype="float64", ) model_foundations = calc.models[0] - model_loaded = load_foundations( + model_loaded = load_foundations_elements( model, model_foundations, table=table, @@ -96,6 +96,75 @@ def test_foundations(): assert torch.allclose(forces, forces_loaded) +def test_multi_reference(): + config_multi = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + forces=molecule("H2COH").positions, + energy=-1.5, + charges=molecule("H2COH").numbers, + dipole=np.array([-1.5, 1.5, 2.0]), + head="MP2", + ) + table_multi = tools.AtomicNumberTable([1, 6, 8]) + atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) + + # Create MACE model + model_config = dict( + r_max=6, + num_bessel=10, + num_polynomial_cutoff=5, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=3, + hidden_irreps=o3.Irreps("128x0e + 128x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies_multi, + avg_num_neighbors=61, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.0], + heads=["MP2", "DFT"], + ) + model = modules.ScaleShiftMACE(**model_config) + calc_foundation = mace_mp(device="cpu", default_dtype="float64") + model_loaded = load_foundations_elements( + model, + calc_foundation.models[0], + table=table, + load_readout=True, + use_shift=False, + max_L=1, + ) + atomic_data = data.AtomicData.from_config( + config_multi, z_table=table_multi, cutoff=6.0, heads=["MP2", "DFT"] + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + forces_loaded = model_loaded(batch)["forces"] + calc_foundation = mace_mp(device="cpu", default_dtype="float64") + atoms = molecule("H2COH") + atoms.calc = calc_foundation + forces = atoms.get_forces() + assert np.allclose( + forces, forces_loaded.detach().numpy()[:5, :], atol=1e-5, rtol=1e-5 + ) + + @pytest.mark.parametrize( "model", [ @@ -113,12 +182,8 @@ def test_extract_config(model): model_copy.load_state_dict(model.state_dict()) z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=6.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=z_table, cutoff=6.0 - ) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], + dataset=[atomic_data, atomic_data], batch_size=2, shuffle=True, drop_last=False, diff --git a/tests/test_models.py b/tests/test_models.py index 18ef536b..8e8c60da 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -196,3 +196,56 @@ def test_energy_dipole_mace(): np.array(rot @ output["dipole"][0].detach().numpy()), output["dipole"][1].detach().numpy(), ) + + +def test_mace_multi_reference(): + atomic_energies_multi = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("96x0e + 96x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies_multi, + avg_num_neighbors=8, + atomic_numbers=table.zs, + distance_transform=True, + pair_repulsion=True, + correlation=3, + heads=["Default", "dft"], + # radial_type="chebyshev", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.1], + ) + model = modules.ScaleShiftMACE(**model_config) + model_compiled = jit.compile(model) + config.head = "Default" + config_rotated.head = "dft" + atomic_data = data.AtomicData.from_config( + config, z_table=table, cutoff=3.0, heads=["Default", "dft"] + ) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0, heads=["Default", "dft"] + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output1 = model(batch.to_dict(), training=True) + output2 = model_compiled(batch.to_dict(), training=True) + assert torch.allclose(output1["energy"][0], output2["energy"][0]) + assert output2["energy"].shape[0] == 2 diff --git a/tests/test_modules.py b/tests/test_modules.py index a5166bd2..5539ceb1 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -106,7 +106,7 @@ def test_atomic_energies(self): ) batch = next(iter(data_loader)) - energies = energies_block(batch.node_attrs) + energies = energies_block(batch.node_attrs).squeeze(-1) out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") out = to_numpy(out) assert np.allclose(out, np.array([5.0, 5.0])) diff --git a/tests/test_run_train.py b/tests/test_run_train.py index 59f7c595..fe6c8c46 100644 --- a/tests/test_run_train.py +++ b/tests/test_run_train.py @@ -278,6 +278,114 @@ def test_run_train_no_stress(tmp_path, fitting_configs): assert np.allclose(Es, ref_Es) +def test_run_train_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + fitting_configs_ccd = [] + for _, c in enumerate(fitting_configs): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + + c_ccd = c.copy() + c_ccd.info["head"] = "CCD" + fitting_configs_ccd.append(c_ccd) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + ase.io.write(tmp_path / "fit_multihead_ccd.xyz", fitting_configs_ccd) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + "CCD": {"train_file": f"{str(tmp_path)}/fit_multihead_ccd.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["loss"] = "weighted" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["config"] = tmp_path / "config.yaml" + mace_params["batch_size"] = 2 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 02/09/2024 on develop branch + ref_Es = [ + 0.0, + 0.0, + 0.10637113905361611, + -0.012499594026624754, + 0.08983077108171753, + 0.21071322543112597, + -0.028921849222784398, + -0.02423359575741567, + 0.022923252188079057, + -0.02048334610058991, + 0.4349711162741364, + -0.04455577015569887, + -0.09765806785570091, + 0.16013134616829822, + 0.0758442928017698, + -0.05931856557011721, + 0.33964473532953265, + 0.134338442158641, + 0.18024119757783053, + -0.18914740992058765, + -0.06503477155294624, + 0.03436649147415213, + ] + assert np.allclose(Es, ref_Es) + + def test_run_train_foundation(tmp_path, fitting_configs): ase.io.write(tmp_path / "fit.xyz", fitting_configs) @@ -289,9 +397,13 @@ def test_run_train_foundation(tmp_path, fitting_configs): mace_params["foundation_model"] = "small" mace_params["hidden_irreps"] = "128x0e" mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float32" + mace_params["default_dtype"] = "float64" mace_params["num_radial_basis"] = 10 mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["multiheads_finetuning"] = False + print("mace_params", mace_params) + # mace_params["num_samples_pt"] = 50 + # mace_params["subselect_pt"] = "random" # make sure run_train.py is using the mace that is currently being tested run_env = os.environ.copy() sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -315,7 +427,7 @@ def test_run_train_foundation(tmp_path, fitting_configs): assert p.returncode == 0 calc = MACECalculator( - tmp_path / "MACE.model", device="cpu", default_dtype="float32" + tmp_path / "MACE.model", device="cpu", default_dtype="float64" ) Es = [] @@ -350,3 +462,112 @@ def test_run_train_foundation(tmp_path, fitting_configs): 0.7301387786865234, ] assert np.allclose(Es, ref_Es) + + +def test_run_train_foundation_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + fitting_configs_dft.append(c) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + elif i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1)