
Commit

element-dependent radii
birdyLinch committed Sep 10, 2024
1 parent 172706c commit 5b10f63
Showing 13 changed files with 257 additions and 27 deletions.
34 changes: 32 additions & 2 deletions mace/cli/run_train.py
@@ -12,8 +12,9 @@
from pathlib import Path
from typing import Optional
import urllib.request
import itertools


import ase
import numpy as np
import torch.distributed
import torch.nn.functional
@@ -143,6 +144,19 @@ def main() -> None:
if args.heads is not None:
args.heads = Box(ast.literal_eval(args.heads)) # using box container for both dict and namespace access

if args.r_max == "covalent_radii":
    # element-dependent cutoff: r_max[i, j] = (R_cov[i] + R_cov[j]) * scale
    scale = 0.5 * 10  # half the summed covalent radii, times an empirical factor of 10
    covalent_radii = torch.tensor(ase.data.covalent_radii)
    ne = covalent_radii.size(0)
    r_max_matrix = (covalent_radii.repeat(ne, 1) + covalent_radii.repeat(ne, 1).t()).cpu().numpy() * scale
    r_max_dict = {(i, j): r_max_matrix[i][j] for i in range(ne) for j in range(ne)}
    args.r_max_matrix = r_max_matrix
    args.r_max_dict = r_max_dict
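A quick sanity check of the construction above (not part of the commit; radii from ase.data.covalent_radii, the Cordero et al. 2008 table):

import ase.data
h, c = ase.data.covalent_radii[1], ase.data.covalent_radii[6]  # H ~ 0.31, C ~ 0.76 (Angstrom)
print((h + c) * 0.5 * 10)  # ~5.35, the (1, 6) entry of r_max_matrix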

for head, head_args in args.heads.items():
logging.info(f"============= Processing head {head} ===========")

@@ -198,6 +212,15 @@ def main() -> None:

# overwrite args.r_max with the head-specific r_max
head_args.r_max = head_args.get('r_max', args.r_max)

if isinstance(head_args.r_max, str):
if head_args.r_max == "covalent_radii":
assert args.r_max == head_args.r_max
head_args.r_max = args.r_max_dict
head_args.r_max_matrix = args.r_max_matrix
else:
raise NotImplementedError(f"r_max option {head_args.r_max!r} is not supported")


# Data preparation
atomic_energies_dict = {k: v.E0s for k, v in args.heads.items()}
@@ -285,6 +308,10 @@ def main() -> None:

logging.info(f"Dataset {head} subsampled size --> {format_number(len(head_args.train_set))}")

if not isinstance(head_args.r_max, float):
    # TODO: plot the neighbor-distance distribution for element-dependent cutoffs
    pass

# head specific train_sampler
head_args.train_sampler = None
if args.distributed:
@@ -528,8 +555,11 @@ def main() -> None:
), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L"

logging.info(f"Hidden irreps: {args.hidden_irreps}")

# element-dependent radial cutoff

model_config = dict(
r_max=args.r_max, # TODO: different r_max for heads
r_max=args.r_max if isinstance(args.r_max, float) else args.r_max_matrix, # TODO: different r_max for heads
num_bessel=args.num_radial_basis,
num_polynomial_cutoff=args.num_cutoff_basis,
max_ell=args.max_ell,
4 changes: 2 additions & 2 deletions mace/data/atomic_data.py
@@ -4,7 +4,7 @@
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

from typing import Optional, Sequence
from typing import Optional, Sequence, Union, Mapping, Tuple

import torch.utils.data

@@ -114,7 +114,7 @@ def from_config(
cls,
config: Configuration,
z_table: AtomicNumberTable,
cutoff: float,
cutoff: Union[float, Mapping],
heads: Optional[list] = ["Default"],
) -> "AtomicData":
edge_index, shifts, unit_shifts = get_neighborhood(
4 changes: 2 additions & 2 deletions mace/data/neighborhood.py
@@ -1,12 +1,12 @@
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import numpy as np
from matscipy.neighbours import neighbour_list


def get_neighborhood(
positions: np.ndarray, # [num_positions, 3]
cutoff: float,
cutoff: Union[float, dict],
pbc: Optional[Tuple[bool, bool, bool]] = None,
cell: Optional[np.ndarray] = None, # [3, 3]
true_self_interaction=False,
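The widened cutoff type matches matscipy's own API: neighbour_list accepts a dict of per-element-pair cutoffs keyed by atomic-number tuples, which is presumably how get_neighborhood passes the matrix through. A minimal sketch (not part of the commit; cutoff values illustrative):

import numpy as np
from ase.build import molecule
from matscipy.neighbours import neighbour_list

atoms = molecule("CH4")
atoms.set_cell(np.eye(3) * 10.0)
atoms.set_pbc(True)
# one cutoff (in Angstrom) per (Z1, Z2) pair
i, j, S = neighbour_list("ijS", atoms, {(1, 1): 3.1, (1, 6): 5.35, (6, 6): 7.6})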
2 changes: 1 addition & 1 deletion mace/modules/blocks.py
@@ -209,7 +209,7 @@ def __repr__(self):
class RadialEmbeddingBlock(torch.nn.Module):
def __init__(
self,
r_max: float,
r_max: Union[float, np.ndarray],
num_bessel: int,
num_polynomial_cutoff: int,
radial_type: str = "bessel",
3 changes: 2 additions & 1 deletion mace/modules/models.py
@@ -42,7 +42,7 @@
class MACE(torch.nn.Module):
def __init__(
self,
r_max: float,
r_max: Union[float, np.ndarray],
num_bessel: int,
num_polynomial_cutoff: int,
max_ell: int,
@@ -72,6 +72,7 @@ def __init__(
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)

self.register_buffer(
"num_interactions", torch.tensor(num_interactions, dtype=torch.int64)
)
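Worth noting (an assumption-checking sketch, not in the commit): torch.tensor wraps either a float or the [n_elements, n_elements] numpy matrix, so the unchanged buffer registration covers both cutoff modes, and downstream code can branch on the buffer's rank:

import numpy as np
import torch

r_scalar = torch.tensor(5.0, dtype=torch.get_default_dtype())
r_matrix = torch.tensor(np.full((3, 3), 5.0), dtype=torch.get_default_dtype())
print(len(r_scalar.shape), len(r_matrix.shape))  # 0 vs 2, the test radial.py uses for elem_dept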
54 changes: 39 additions & 15 deletions mace/modules/radial.py
@@ -9,6 +9,8 @@
import torch
from e3nn.util.jit import compile_mode

from typing import Optional, Union

from mace.tools.compile import simplify_if_compile
from mace.tools.scatter import scatter_sum

@@ -117,25 +119,47 @@ class PolynomialCutoff(torch.nn.Module):
p: torch.Tensor
r_max: torch.Tensor

def __init__(self, r_max: float, p=6):
def __init__(self, r_max: Union[float, np.ndarray], p=6):
super().__init__()
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
self.elem_dept = (len(self.r_max.shape) == 2)  # a 2D r_max buffer is an element-dependent cutoff matrix

    def forward(
        self,
        x: torch.Tensor,
        node_attrs: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        atomic_numbers: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if not self.elem_dept:
            r_max = self.r_max
        else:
            # recover each edge's element pair and look up its cutoff
            sender = edge_index[0]
            receiver = edge_index[1]
            node_atomic_numbers = atomic_numbers[
                torch.argmax(node_attrs, dim=1)
            ].unsqueeze(-1)
            Z_u = node_atomic_numbers[sender]
            Z_v = node_atomic_numbers[receiver]
            # joint indexing gives one cutoff per edge; self.r_max[Z_u][Z_v]
            # would index the rows twice and is a bug
            r_max = self.r_max[Z_u, Z_v]
        # yapf: disable
        envelope = (
            1.0
            - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p)
            + self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1)
            - (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2)
        )
        # yapf: enable
        # noinspection PyUnresolvedReferences
        return envelope * (x < r_max)

def __repr__(self):
return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"
@@ -146,12 +170,12 @@ class ZBLBasis(torch.nn.Module):
"""
Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
"""

p: torch.Tensor
r_max: torch.Tensor

def __init__(self, r_max: float, p=6, trainable=False):
def __init__(self, r_max: Union[float, np.ndarray], p=6, trainable=False):
super().__init__()

self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
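A boundary check of the envelope (not part of the commit; assumes the module imports as mace.modules.radial): the polynomial switches smoothly from 1 at r = 0 to 0 at r = r_max, with p = 6 giving ~0.855 at the midpoint.

import torch
from mace.modules.radial import PolynomialCutoff

fc = PolynomialCutoff(r_max=5.0, p=6)
x = torch.tensor([[0.0], [2.5], [5.0]])
print(fc(x))  # ~[[1.0], [0.855], [0.0]]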
1 change: 1 addition & 0 deletions mace/modules/utils.py
@@ -22,6 +22,7 @@
from functools import partial
tqdm = partial(tqdm, ncols=55)
import torch.distributed as dist
import torch_geometric

def compute_forces(
energy: torch.Tensor, positions: torch.Tensor, training: bool = True
20 changes: 16 additions & 4 deletions mace/tools/arg_parser.py
@@ -6,7 +6,7 @@

import argparse
import os
from typing import Optional
from typing import Optional, Union


def build_default_arg_parser() -> argparse.ArgumentParser:
@@ -106,16 +106,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--agnostic_int",
default=[False, False],
nargs='+'
nargs='+',
type=list_of_bools,
)
parser.add_argument(
"--agnostic_con",
default=[False, False],
nargs='+'
nargs='+',
type=list_of_bools,
)

parser.add_argument(
"--r_max", help="distance cutoff (in Ang)", type=float, default=5.0
"--r_max", help="distance cutoff (in Ang)", type=float_or_str, default=5.0
)
parser.add_argument(
"--radial_type",
@@ -870,3 +872,13 @@ def check_float_or_none(value: str) -> Optional[float]:
f"{value} is an invalid value (float or None)"
) from None
return None

def float_or_str(value: str) -> Union[float, str]:
try:
return float(value)
except ValueError:
return value
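Behaviour of the parser hook above, for reference (not part of the commit):

>>> float_or_str("5.0")
5.0
>>> float_or_str("covalent_radii")
'covalent_radii'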

def list_of_bools(value: str) -> bool:
    # argparse applies this per token (nargs='+'), so it returns a single bool;
    # parse explicitly rather than eval-ing user input
    if value in ("True", "true", "1"):
        return True
    if value in ("False", "false", "0"):
        return False
    raise argparse.ArgumentTypeError(f"{value} is not a boolean")

5 changes: 5 additions & 0 deletions multihead_config/get_r_max.py
@@ -0,0 +1,5 @@
from ase.data import covalent_radii as radii  # array of covalent radii, indexed by atomic number

def radii_from_z_pair(z1, z2):
    return (radii[z1] + radii[z2]) / 2.0
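For example (not part of the commit; ase's Cordero table gives H = 0.31 A, C = 0.76 A):

print(radii_from_z_pair(1, 6))  # (0.31 + 0.76) / 2 = 0.535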
15 changes: 15 additions & 0 deletions multihead_config/jz_mp_config_eledp_r.yaml
@@ -0,0 +1,15 @@
avg_num_neighbor_head: mp_pbe
device: cuda
multi_processed_test: True
heads:
mp_pbe:
train_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/train/MatProj
valid_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/valid/MatProj
E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json
config_type_weights:
Default: 1.0
#avg_num_neighbors: 61.9649349317854
#mean: 0.1634233391135065
#std: 0.7735790334431056
#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice
#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json
1 change: 1 addition & 0 deletions scripts_tuning_for_md/debug.sh
@@ -0,0 +1 @@
srun --pty --account gax@h100 -C h100 --nodes=1 --ntasks-per-node=1 --cpus-per-task=12 --gres=gpu:1 --time=20:00:00 --hint=nomultithread bash
61 changes: 61 additions & 0 deletions scripts_tuning_for_md/run_multihead_5arg.sh
@@ -0,0 +1,61 @@
#!/bin/bash
DATA_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_dataset
module load pytorch-gpu/py3/2.3.1
export PATH="$PATH:/linkhome/rech/genrre01/unh55hx/.local/bin"
REAL_BATCH_SIZE=$(($1 * $3))
CONF=$4
R=$5
NUM_CHANNEL=$6
NUM_RADIAL=$7
MLP_IRREPS=$8
SEED=$9
ROOT_DIR=/lustre/fsn1/projects/rech/gax/unh55hx/mace_multi_head_interface
conf_str="${CONF%.yaml}"
stress=${10}
cd $ROOT_DIR
mace_run_train \
--name="MACE_medium_stress${stress}_nc${NUM_CHANNEL}_nr${NUM_RADIAL}_MLP${MLP_IRREPS}_agnesi_b${REAL_BATCH_SIZE}_lr$2_${conf_str}_md_tune_nonagnostic" \
--loss='universal' \
--energy_weight=1 \
--forces_weight=10 \
--compute_stress=True \
--stress_weight=${stress} \
--eval_interval=1 \
--error_table='PerAtomMAE' \
--model="MACE" \
--interaction_first="RealAgnosticInteractionBlock" \
--interaction="RealAgnosticResidualInteractionBlock" \
--num_interactions=2 \
--correlation=3 \
--max_ell=3 \
--r_max=${R} \
--max_L=1 \
--num_channels=${NUM_CHANNEL} \
--num_radial_basis=${NUM_RADIAL} \
--MLP_irreps=${MLP_IRREPS} \
--scaling='rms_forces_scaling' \
--lr=$2 \
--weight_decay=1e-8 \
--ema \
--ema_decay=0.995 \
--scheduler_patience=5 \
--batch_size=$1 \
--valid_batch_size=32 \
--pair_repulsion \
--distance_transform="Agnesi" \
--max_num_epochs=300 \
--patience=40 \
--amsgrad \
--seed=${SEED} \
--clip_grad=100 \
--keep_checkpoints \
--restart_latest \
--save_cpu \
--config="multihead_config/${CONF}" \
--device=cuda \
--num_workers=0 \
--distributed \
--agnostic_int False False \
--agnostic_con False False

# --name="MACE_medium_agnesi_b32_origin_mponly" \
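A hypothetical invocation, reading the positionals off the script ($1 batch size, $2 learning rate, $3 a multiplier folded into REAL_BATCH_SIZE, presumably ranks or accumulation steps, $4 config yaml, $5 r_max, $6 channels, $7 radial basis size, $8 MLP irreps, $9 seed, $10 stress weight):

bash scripts_tuning_for_md/run_multihead_5arg.sh \
    16 0.005 4 jz_mp_config_eledp_r.yaml covalent_radii 128 8 16x0e 123 1.0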