Skip to content

Commit

Permalink
[debuged] element depend radii
Browse files Browse the repository at this point in the history
  • Loading branch information
birdyLinch committed Sep 10, 2024
1 parent 5b10f63 commit 9a170c2
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
16 changes: 11 additions & 5 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,22 @@ def __repr__(self):
class RadialEmbeddingBlock(torch.nn.Module):
def __init__(
self,
r_max: Union[float, np.ndarray],
r_max: Union[float, np.ndarray, torch.Tensor],
num_bessel: int,
num_polynomial_cutoff: int,
radial_type: str = "bessel",
distance_transform: str = "None",
):
super().__init__()
self.elem_dept = not isinstance(r_max, float)
if self.elem_dept:
bessel_fn_r_max = 6.0 # TODO: change to element dependent
if radial_type == "bessel":
self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel)
self.bessel_fn = BesselBasis(r_max=bessel_fn_r_max, num_basis=num_bessel)
elif radial_type == "gaussian":
self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel)
self.bessel_fn = GaussianBasis(r_max=bessel_fn_r_max, num_basis=num_bessel)
elif radial_type == "chebyshev":
self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel)
self.bessel_fn = ChebychevBasis(r_max=bessel_fn_r_max, num_basis=num_bessel)
if distance_transform == "Agnesi":
self.distance_transform = AgnesiTransform()
elif distance_transform == "Soft":
Expand All @@ -236,7 +239,10 @@ def forward(
edge_index: torch.Tensor,
atomic_numbers: torch.Tensor,
):
cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1]
if getattr(self, "elem_dept", False):
cutoff = self.cutoff_fn(edge_lengths, node_attrs, edge_index, atomic_numbers) # [n_edges, 1]
else:
cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1]
if hasattr(self, "distance_transform"):
edge_lengths = self.distance_transform(
edge_lengths, node_attrs, edge_index, atomic_numbers
Expand Down
2 changes: 1 addition & 1 deletion mace/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
irreps_in=node_attr_irreps, irreps_out=node_feats_irreps
)
self.radial_embedding = RadialEmbeddingBlock(
r_max=r_max,
r_max=self.r_max,
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
radial_type=radial_type,
Expand Down
14 changes: 8 additions & 6 deletions mace/modules/radial.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,17 @@ def forward(self,
node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
-1
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]

r_max = self.r_max[Z_u][Z_v]
Z_u = node_atomic_numbers[sender].view(-1)
Z_v = node_atomic_numbers[receiver].view(-1)

r_max = self.r_max[Z_v, Z_u].unsqueeze(-1)

envelope = (
1.0
- ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p)
+ self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1)
- (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2)
- ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p)
+ self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1)
- (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2)
)
return envelope * (x < r_max)

Expand Down
7 changes: 4 additions & 3 deletions multihead_config/jz_mp_config_eledp_r.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ heads:
E0s: /lustre/fsn1/projects/rech/gax/unh55hx/data/e0s.json
config_type_weights:
Default: 1.0
#avg_num_neighbors: 61.9649349317854
#mean: 0.1634233391135065
#std: 0.7735790334431056
ratio: 0.01
avg_num_neighbors: 8.301559722246491
mean: -4.473531217546253
std: 0.7996570603329817
#test_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/multihead_datasets/test/spice
#statistics_file: /lustre/fsn1/projects/rech/gax/unh55hx/data/statistics.json
2 changes: 1 addition & 1 deletion scripts_tuning_for_md/run_multihead_5arg.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ mace_run_train \
--save_cpu \
--config="multihead_config/${CONF}" \
--device=cuda \
--num_workers=0 \
--num_workers=8 \
--distributed \
--agnostic_int False False \
--agnostic_con False False \
Expand Down

0 comments on commit 9a170c2

Please sign in to comment.