Merge pull request #578 from ACEsuit/develop
multihead finetuning
ilyes319 committed Sep 12, 2024
2 parents 22a2e3e + 6d768c2 commit 50e5bd1
Showing 38 changed files with 2,515 additions and 826 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,46 @@
name: Linting and code formatting

on: []
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  # push:
  #   branches: []
  # pull_request:
  #   branches: []


jobs:
  build-linux:
    runs-on: ubuntu-latest

    steps:
      # Setup
      - name: Checkout
        uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.8.10
      - name: Get cache
        uses: actions/cache@v2
        with:
          path: /opt/hostedtoolcache/Python/3.8.10/x64/lib/python3.8/site-packages
          # Look to see if there is a cache hit for the corresponding requirements file
          key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}

      # Install packages
      - name: Install packages required for installation
        run: python -m pip install --upgrade pip setuptools wheel
      - name: Install dependencies
        run: pip install -r requirements.txt

      # Check code
      - name: Check formatting with yapf
        run: python -m yapf --style=.style.yapf --diff --recursive .
      # - name: Lint with flake8
      #   run: flake8 --config=.flake8 .
      # - name: Check type annotations with mypy
      #   run: mypy --config-file=.mypy.ini .

      - name: Test with pytest
        run: python -m pytest tests
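
The same checks can be run locally before pushing; a minimal sketch in Python (assuming yapf, pytest, and the other dependencies from requirements.txt are installed):

# Local reproduction of the workflow's check steps (sketch, not part of the repository).
import subprocess

subprocess.run(
    ["python", "-m", "yapf", "--style=.style.yapf", "--diff", "--recursive", "."],
    check=True,  # a non-zero exit code means formatting differences were found
)
subprocess.run(["python", "-m", "pytest", "tests"], check=True)
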
3 changes: 3 additions & 0 deletions .gitignore
@@ -27,6 +27,9 @@ dist/

# DS_Store
.DS_Store
*.models
*.pt
/wandb
*.xyz
/checkpoints
*.model
4 changes: 3 additions & 1 deletion mace/__version__.py
@@ -1 +1,3 @@
__version__ = "0.3.6"
__version__ = "0.3.7"

__all__ = ["__version__"]
6 changes: 3 additions & 3 deletions mace/calculators/foundations_models.py
@@ -62,9 +62,9 @@ def mace_mp(
     try:
         # checkpoints release: https://github.com/ACEsuit/mace-mp/releases/tag/mace_mp_0
         urls = dict(
-            small="https://tinyurl.com/46jrkm3v",  # 2023-12-10-mace-128-L0_energy_epoch-249.model
-            medium="https://tinyurl.com/5yyxdm76",  # 2023-12-03-mace-128-L1_epoch-199.model
-            large="https://tinyurl.com/5f5yavf3",  # MACE_MPtrj_2022.9.model
+            small="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",  # 2023-12-10-mace-128-L0_energy_epoch-249.model
+            medium="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",  # 2023-12-03-mace-128-L1_epoch-199.model
+            large="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",  # MACE_MPtrj_2022.9.model
         )
         checkpoint_url = (
             urls.get(model, urls["medium"])
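
With the checkpoint URLs now pointing at direct GitHub release downloads, loading a foundation model is unchanged on the user's side; a minimal usage sketch (assumes the mace package is installed and the machine can reach github.com):

from mace.calculators import mace_mp

# Downloads the "medium" checkpoint from the release URL on first use and
# returns an ASE-compatible calculator built from it.
calc = mace_mp(model="medium", device="cpu", default_dtype="float64")
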
91 changes: 91 additions & 0 deletions mace/calculators/foundations_models/mp_vasp_e0.json
@@ -0,0 +1,91 @@
{
"pbe": {
"1": -1.11734008,
"2": 0.00096759,
"3": -0.29754725,
"4": -0.01781697,
"5": -0.26885011,
"6": -1.26173507,
"7": -3.12438806,
"8": -1.54838784,
"9": -0.51882044,
"10": -0.01241601,
"11": -0.22883163,
"12": -0.00951015,
"13": -0.21630193,
"14": -0.8263903,
"15": -1.88816619,
"16": -0.89160769,
"17": -0.25828273,
"18": -0.04925973,
"19": -0.22697913,
"20": -0.0927795,
"21": -2.11396364,
"22": -2.50054871,
"23": -3.70477179,
"24": -5.60261985,
"25": -5.32541181,
"26": -3.52004933,
"27": -1.93555024,
"28": -0.9351969,
"29": -0.60025846,
"30": -0.1651332,
"31": -0.32990651,
"32": -0.77971828,
"33": -1.68367812,
"34": -0.76941032,
"35": -0.22213843,
"36": -0.0335879,
"37": -0.1881724,
"38": -0.06826294,
"39": -2.17084228,
"40": -2.28579303,
"41": -3.13429753,
"42": -4.60211419,
"43": -3.45201492,
"44": -2.38073513,
"45": -1.46855515,
"46": -1.4773126,
"47": -0.33954585,
"48": -0.16843877,
"49": -0.35470981,
"50": -0.83642657,
"51": -1.41101987,
"52": -0.65740879,
"53": -0.18964571,
"54": -0.00857582,
"55": -0.13771876,
"56": -0.03457659,
"57": -0.45580806,
"58": -1.3309175,
"59": -0.29671824,
"60": -0.30391193,
"61": -0.30898427,
"62": -0.25470891,
"63": -8.38001538,
"64": -10.38896525,
"65": -0.3059505,
"66": -0.30676216,
"67": -0.30874667,
"69": -0.25190039,
"70": -0.06431414,
"71": -0.31997586,
"72": -3.52770927,
"73": -3.54492209,
"75": -4.70108713,
"76": -2.88257209,
"77": -1.46779304,
"78": -0.50269936,
"79": -0.28801193,
"80": -0.12454674,
"81": -0.31737194,
"82": -0.77644932,
"83": -1.32627283,
"89": -0.26827152,
"90": -0.90817426,
"91": -2.47653193,
"92": -4.90438537,
"93": -7.63378961,
"94": -10.77237713
}
}
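
The new JSON file stores isolated-atom reference energies keyed by functional ("pbe") and then by atomic number as a string; a small sketch for inspecting it (path assumed relative to the repository root):

import json

with open("mace/calculators/foundations_models/mp_vasp_e0.json") as f:
    e0s = json.load(f)["pbe"]

print(e0s["1"])  # reference energy for hydrogen (atomic number 1): -1.11734008
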
13 changes: 12 additions & 1 deletion mace/calculators/lammps_mace.py
@@ -8,12 +8,22 @@

 @compile_mode("script")
 class LAMMPS_MACE(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, **kwargs):
         super().__init__()
         self.model = model
         self.register_buffer("atomic_numbers", model.atomic_numbers)
         self.register_buffer("r_max", model.r_max)
         self.register_buffer("num_interactions", model.num_interactions)
+        if not hasattr(model, "heads"):
+            model.heads = [None]
+        self.register_buffer(
+            "head",
+            torch.tensor(
+                self.model.heads.index(kwargs.get("head", self.model.heads[-1])),
+                dtype=torch.long,
+            ).unsqueeze(0),
+        )
 
         for param in self.model.parameters():
             param.requires_grad = False

@@ -27,6 +37,7 @@ def forward(
         compute_displacement = False
         if compute_virials:
             compute_displacement = True
+        data["head"] = self.head
         out = self.model(
             data,
             training=False,
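
A standalone sketch of the head-selection logic registered above, using hypothetical head names; in the wrapper the resulting tensor is stored as the "head" buffer and copied into data["head"] before the model is called:

import torch

heads = ["mp_pbe", "finetune"]  # hypothetical head names stored on a model
requested = "finetune"          # what kwargs.get("head", heads[-1]) would return
head = torch.tensor(heads.index(requested), dtype=torch.long).unsqueeze(0)
print(head)  # tensor([1])
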
12 changes: 10 additions & 2 deletions mace/calculators/mace.py
@@ -145,6 +145,10 @@ def __init__(
             [int(z) for z in self.models[0].atomic_numbers]
         )
         self.charges_key = charges_key
+        try:
+            self.heads = self.models[0].heads
+        except AttributeError:
+            self.heads = ["Default"]
         model_dtype = get_model_dtype(self.models[0])
         if default_dtype == "":
             print(
@@ -198,7 +202,7 @@ def _atoms_to_batch(self, atoms):
         data_loader = torch_geometric.dataloader.DataLoader(
             dataset=[
                 data.AtomicData.from_config(
-                    config, z_table=self.z_table, cutoff=self.r_max
+                    config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads
                 )
             ],
             batch_size=1,
@@ -231,7 +235,11 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):

         if self.model_type in ["MACE", "EnergyDipoleMACE"]:
             batch = self._clone_batch(batch_base)
-            node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])
+            node_heads = batch["head"][batch["batch"]]
+            num_atoms_arange = torch.arange(batch["positions"].shape[0])
+            node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[
+                num_atoms_arange, node_heads
+            ]
             compute_stress = not self.use_compile
         else:
             compute_stress = False
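
The new E0 lookup gathers one reference energy per atom from a per-head table; a sketch with made-up shapes illustrating the advanced indexing:

import torch

num_atoms, num_heads = 3, 2
e0_table = torch.randn(num_atoms, num_heads)  # stand-in for the atomic_energies_fn output
node_heads = torch.tensor([0, 1, 1])          # head index assigned to each atom
node_e0 = e0_table[torch.arange(num_atoms), node_heads]
print(node_e0.shape)  # torch.Size([3]): one reference energy per atom
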
1 change: 0 additions & 1 deletion mace/cli/active_learning_md.py
@@ -144,7 +144,6 @@ def main() -> None:


 def run(args: argparse.Namespace) -> None:
-
     mace_fname = args.model
     atoms_fname = args.config
     atoms_index = args.config_index
65 changes: 60 additions & 5 deletions mace/cli/create_lammps_model.py
@@ -1,18 +1,73 @@
-import sys
+import argparse
 
 import torch
 from e3nn.util import jit
 
 from mace.calculators import LAMMPS_MACE
 
 
-def main():
-    assert len(sys.argv) == 2, f"Usage: {sys.argv[0]} model_path"
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "model_path",
+        type=str,
+        help="Path to the model to be converted to LAMMPS",
+    )
+    parser.add_argument(
+        "--head",
+        type=str,
+        nargs="?",
+        help="Head of the model to be converted to LAMMPS",
+        default=None,
+    )
+    return parser.parse_args()
+
+
+def select_head(model):
+    if hasattr(model, "heads"):
+        heads = model.heads
+    else:
+        heads = [None]
+
+    if len(heads) == 1:
+        print(f"Only one head found in the model: {heads[0]}. Skipping selection.")
+        return heads[0]
+
+    print("Available heads in the model:")
+    for i, head in enumerate(heads):
+        print(f"{i + 1}: {head}")
+
+    # Ask the user to select a head
+    selected = input(
+        f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): "
+    )
 
-    model_path = sys.argv[1]  # takes model name as command-line input
+    if selected.isdigit() and 1 <= int(selected) <= len(heads):
+        return heads[int(selected) - 1]
+    if selected == "":
+        print("No head selected. Proceeding without specifying a head.")
+        return None
+    print(f"No valid selection made. Defaulting to the last head: {heads[-1]}")
+    return heads[-1]
+
+
+def main():
+    args = parse_args()
+    model_path = args.model_path  # takes model name as command-line input
     model = torch.load(model_path)
     model = model.double().to("cpu")
-    lammps_model = LAMMPS_MACE(model)
+
+    if args.head is None:
+        head = select_head(model)
+    else:
+        head = args.head
+        print(
+            f"Selected head: {head} from command line in the list available heads: {model.heads}"
+        )
+
+    lammps_model = (
+        LAMMPS_MACE(model, head=head) if head is not None else LAMMPS_MACE(model)
+    )
     lammps_model_compiled = jit.compile(lammps_model)
     lammps_model_compiled.save(model_path + "-lammps.pt")

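
A hedged usage sketch of the updated converter; the model file and head name below are placeholders, and main() simply reads whatever is on sys.argv through parse_args():

import sys

from mace.cli.create_lammps_model import main

# Convert a (possibly multihead) model non-interactively; this writes
# my_finetuned.model-lammps.pt next to the input file.
sys.argv = ["create_lammps_model", "my_finetuned.model", "--head", "Default"]
main()
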
1 change: 0 additions & 1 deletion mace/cli/eval_configs.py
@@ -62,7 +62,6 @@ def main() -> None:


 def run(args: argparse.Namespace) -> None:
-
     torch_tools.set_default_dtype(args.default_dtype)
     device = torch_tools.init_device(args.device)
