Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove torch scatter dep #88

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions .devcontainer/devcontainer.json

This file was deleted.

1 change: 0 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ jobs:
- name: Install dependencies
run: |
pip install torch==2.2.1 --index-url https://download.pytorch.org/whl/cpu
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cpu.html
uv pip install .[test] --system

- name: Run Tests
Expand Down
10 changes: 1 addition & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,7 @@ The aim of `aviary` is to contain multiple models for materials discovery under

## Installation

Aviary requires [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter). `pip install` it with

```sh
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cpu.html
```

Make sure you replace `2.2.1` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable.

Then install `aviary` from source with
Users can install `aviary` from source with

```sh
pip install -U git+https://github.com/CompRhys/aviary
Expand Down
6 changes: 3 additions & 3 deletions aviary/cgcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_add, scatter_mean

from aviary.core import BaseModelClass
from aviary.networks import SimpleNetwork
from aviary.scatter import scatter_reduce

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -114,7 +114,7 @@ def forward(
"""
atom_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)

crys_fea = scatter_mean(atom_fea, crystal_atom_idx, dim=0)
crys_fea = scatter_reduce(atom_fea, crystal_atom_idx, dim=0, reduce="mean")

# NOTE required to match the reference implementation
crys_fea = nn.functional.softplus(crys_fea)
Expand Down Expand Up @@ -236,7 +236,7 @@ def forward(

# take the elementwise product of the filter and core
nbr_msg = filter_fea * core_fea
nbr_summed = scatter_add(nbr_msg, self_idx, dim=0)
nbr_summed = scatter_reduce(nbr_msg, self_idx, dim=0, reduce="sum")

nbr_summed = self.bn2(nbr_summed)
return self.softplus2(atom_in_fea + nbr_summed)
76 changes: 76 additions & 0 deletions aviary/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch


def scatter_reduce(src, index, dim=-1, dim_size=None, reduce="sum"):
"""Performs a scatter-reduce operation on the input tensor.

This function scatters the elements from the source tensor (src) into a new tensor
of shape determined by dim_size along the specified dimension (dim), using the
given reduction method. It's compatible with autograd for gradient computation.

NOTE this function was written by Claude 3.5 Sonnet.

Args:
src (torch.Tensor): The source tensor.
index (torch.Tensor): The indices of elements to scatter. Must be 1D or have
the same number of dimensions as src.
dim (int, optional): The axis along which to index. Defaults to -1.
dim_size (int, optional): The size of the output tensor's dimension `dim`.
If None, it's inferred as index.max().item() + 1. Defaults to None.
reduce (str, optional): The reduction operation to perform.
Options: "sum", "mean", "amax", "max", "amin", "min", "prod".
Defaults to "sum".

Returns:
torch.Tensor: The output tensor after the scatter-reduce operation.

Raises:
ValueError: If an unsupported reduction method is specified.
RuntimeError: If index and src tensors are incompatible.

Example:
>>> src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
>>> index = torch.tensor([0, 1, 0, 1, 2])
>>> scatter_reduce(src, index, dim=0, reduce="sum")
tensor([4., 6., 5.])
"""
if dim_size is None:
dim_size = index.max().item() + 1

# Prepare the output tensor shape
shape = list(src.shape)
shape[dim] = dim_size

# Ensure index has the same number of dimensions as src
if index.dim() != src.dim():
if index.dim() != 1:
raise RuntimeError(
"Index tensor must be 1D or have the same number of dimensions "
f"as src tensor. {index.shape=} != {src.shape=}"
)
# Expand index to match src dimensions
repeat_shape = [1] * src.dim()
repeat_shape[dim] = src.size(dim)
index = index.view(-1, *[1] * (src.dim() - 1)).expand_as(src)

# Perform scatter_reduce operation
if reduce in ["sum", "mean"]:
out = torch.zeros(shape, dtype=src.dtype, device=src.device)
out = out.scatter_add(dim, index, src)
if reduce == "mean":
count = torch.zeros(shape, dtype=src.dtype, device=src.device)
count = count.scatter_add(dim, index, torch.ones_like(src))
out = out / (count + (count == 0).float()) # avoid division by zero
elif reduce in ["amax", "max"]:
out = torch.full(shape, float("-inf"), dtype=src.dtype, device=src.device)
out = torch.max(out, out.scatter(dim, index, src))
elif reduce in ["amin", "min"]:
out = torch.full(shape, float("inf"), dtype=src.dtype, device=src.device)
out = torch.min(out, out.scatter(dim, index, src))
elif reduce == "prod":
out = torch.ones(shape, dtype=src.dtype, device=src.device)
out = out.scatter(dim, index, src, reduce="multiply")
else:
raise ValueError(f"Unsupported reduction method: {reduce}")

return out
14 changes: 7 additions & 7 deletions aviary/segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import torch
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_add, scatter_max

from aviary.networks import SimpleNetwork
from aviary.scatter import scatter_reduce

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -38,12 +38,12 @@ def forward(self, x: Tensor, index: Tensor) -> Tensor:
"""
gate = self.gate_nn(x)

gate -= scatter_max(gate, index, dim=0)[0][index]
gate -= scatter_reduce(gate, index, dim=0, reduce="amax")[index]
gate = gate.exp()
gate /= scatter_add(gate, index, dim=0)[index] + 1e-10
gate /= scatter_reduce(gate, index, dim=0, reduce="sum")[index] + 1e-10

x = self.message_nn(x)
return scatter_add(gate * x, index, dim=0)
return scatter_reduce(gate * x, index, dim=0, reduce="sum")

def __repr__(self) -> str:
gate_nn, message_nn = self.gate_nn, self.message_nn
Expand Down Expand Up @@ -78,12 +78,12 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor:
"""
gate = self.gate_nn(x)

gate -= scatter_max(gate, index, dim=0)[0][index]
gate -= scatter_reduce(gate, index, dim=0, reduce="amax")[index]
gate = (weights**self.pow) * gate.exp()
gate /= scatter_add(gate, index, dim=0)[index] + 1e-10
gate /= scatter_reduce(gate, index, dim=0, reduce="sum")[index] + 1e-10

x = self.message_nn(x)
return scatter_add(gate * x, index, dim=0)
return scatter_reduce(gate * x, index, dim=0, reduce="sum")

def __repr__(self) -> str:
pow, gate_nn, message_nn = float(self.pow), self.gate_nn, self.message_nn
Expand Down
6 changes: 4 additions & 2 deletions aviary/wren/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch.nn.functional as F
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn
from torch_scatter import scatter_mean

from aviary.core import BaseModelClass
from aviary.networks import ResidualNetwork, SimpleNetwork
from aviary.scatter import scatter_reduce
from aviary.segments import MessageLayer, WeightedAttentionPooling

if TYPE_CHECKING:
Expand Down Expand Up @@ -261,7 +261,9 @@ def forward(
for attnhead in self.cry_pool
]

return scatter_mean(torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0)
return scatter_reduce(
torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0, reduce="mean"
)

def __repr__(self) -> str:
return (
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/Roost.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"\n",
"print(f\"{TORCH_VERSION=}\")\n",
"\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}.html\n",
"!pip install -U git+https://github.com/CompRhys/aviary.git # install aviary\n",
"!wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997"
]
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/Wren.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"\n",
"print(f\"{TORCH_VERSION=}\")\n",
"\n",
"!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}.html\n",
"!pip install -U git+https://github.com/CompRhys/aviary.git # install aviary\n",
"!wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997"
]
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "aviary"
version = "1.0.0"
version = "1.1.0"
description = "A collection of machine learning models for materials discovery"
authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }]
readme = "README.md"
Expand All @@ -27,7 +27,6 @@ classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Chemistry",
Expand All @@ -42,7 +41,6 @@ dependencies = [
"scikit_learn",
"tensorboard",
"torch",
"torch_scatter",
"tqdm",
"wandb",
]
Expand Down Expand Up @@ -114,6 +112,7 @@ ignore = [
"D105", # Missing docstring in magic method
"D205", # 1 blank line required between summary line and description
"E731", # Do not assign a lambda expression, use a def
"ISC001",
"PD901", # pandas-df-variable-name
"PLR", # pylint refactor
"PT006", # pytest-parametrize-names-wrong-type
Expand Down