diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100644 index c6d93f06..00000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "image": "mcr.microsoft.com/devcontainers/universal:2", - "waitFor": "onCreateCommand", - "updateContentCommand": "pip install torch==2.1.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu && pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cpu.html && pip install -e .", - "customizations": { - "codespaces": { - "openFiles": ["examples/notebooks/wren-example.ipynb"] - }, - "vscode": { - "extensions": ["ms-toolsai.jupyter", "ms-python.python"] - } - } -} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0dbacb40..5254cb4f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,6 @@ jobs: - name: Install dependencies run: | pip install torch==2.2.1 --index-url https://download.pytorch.org/whl/cpu - pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cpu.html uv pip install .[test] --system - name: Run Tests diff --git a/README.md b/README.md index 483e77dc..4a02fc2b 100644 --- a/README.md +++ b/README.md @@ -15,15 +15,7 @@ The aim of `aviary` is to contain multiple models for materials discovery under ## Installation -Aviary requires [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter). `pip install` it with - -```sh -pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cpu.html -``` - -Make sure you replace `2.2.1` with your actual `torch.__version__` (`python -c 'import torch; print(torch.__version__)'`) and `cpu` with your CUDA version if applicable. 
import torch


def scatter_reduce(src, index, dim=-1, dim_size=None, reduce="sum"):
    """Reduce the values of ``src`` into groups selected by ``index`` along ``dim``.

    All entries of ``src`` that share the same index along ``dim`` are combined
    with the requested reduction and written to that position of the output.
    The implementation only uses differentiable torch primitives, so gradients
    flow back to ``src``.

    NOTE this function was written by Claude 3.5 Sonnet.

    Args:
        src (torch.Tensor): The source tensor.
        index (torch.Tensor): Integer (int64) indices of the output slots. Must
            either have the same number of dimensions as src, or be 1D with
            ``src.size(dim)`` elements, in which case it is broadcast along all
            other dimensions of src.
        dim (int, optional): The axis along which to index. Negative values count
            from the last dimension. Defaults to -1.
        dim_size (int, optional): The size of the output tensor's dimension `dim`.
            If None, it's inferred as index.max().item() + 1 (0 for an empty
            index). Defaults to None.
        reduce (str, optional): The reduction operation to perform.
            Options: "sum", "mean", "amax", "max", "amin", "min", "prod".
            Defaults to "sum".

    Returns:
        torch.Tensor: Tensor shaped like src except that dimension ``dim`` has
            size ``dim_size``. Slots that receive no element hold the identity
            of the reduction: 0 for "sum"/"mean", 1 for "prod", -inf for
            "amax"/"max" and +inf for "amin"/"min" (the dtype's minimum/maximum
            for integer tensors).

    Raises:
        ValueError: If an unsupported reduction method is specified.
        IndexError: If dim is out of range for src.
        RuntimeError: If index and src tensors are incompatible.

    Example:
        >>> src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        >>> index = torch.tensor([0, 1, 0, 1, 2])
        >>> scatter_reduce(src, index, dim=0, reduce="sum")
        tensor([4., 6., 5.])
    """
    # Map the accepted aliases onto the names torch.Tensor.scatter_reduce uses.
    canonical_names = {
        "sum": "sum",
        "mean": "mean",
        "amax": "amax",
        "max": "amax",
        "amin": "amin",
        "min": "amin",
        "prod": "prod",
    }
    if reduce not in canonical_names:
        raise ValueError(f"Unsupported reduction method: {reduce}")
    reduce = canonical_names[reduce]

    # Normalise negative dims so the broadcasting below can place the index axis.
    if dim < 0:
        dim += src.dim()
    if not 0 <= dim < src.dim():
        raise IndexError(
            f"Dimension out of range for src with {src.dim()} dimension(s): {dim=}"
        )

    # Ensure index has the same number of dimensions as src
    if index.dim() != src.dim():
        if index.dim() != 1:
            raise RuntimeError(
                "Index tensor must be 1D or have the same number of dimensions "
                f"as src tensor. {index.shape=} != {src.shape=}"
            )
        # Put the 1D index on axis `dim` and broadcast it over every other axis.
        # expand_as raises RuntimeError if len(index) != src.size(dim).
        view_shape = [1] * src.dim()
        view_shape[dim] = -1
        index = index.reshape(view_shape).expand_as(src)

    if dim_size is None:
        # max() of an empty tensor raises, so treat "no indices" as "no slots".
        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0

    # Prepare the output tensor shape
    shape = list(src.shape)
    shape[dim] = dim_size

    if reduce in ("sum", "mean"):
        out = src.new_zeros(shape).scatter_add(dim, index, src)
        if reduce == "mean":
            count = src.new_zeros(shape).scatter_add(dim, index, torch.ones_like(src))
            # clamp keeps src's dtype and avoids dividing empty slots by zero
            out = out / count.clamp(min=1)
        return out

    # Remaining reductions start from their identity element, so slots that
    # receive no element keep that identity and include_self=True is harmless.
    if reduce == "prod":
        identity = 1
    elif src.is_floating_point():
        identity = float("-inf") if reduce == "amax" else float("inf")
    else:
        limits = torch.iinfo(src.dtype)
        identity = limits.min if reduce == "amax" else limits.max

    out = torch.full(shape, identity, dtype=src.dtype, device=src.device)
    # A plain Tensor.scatter keeps one arbitrary value per slot when indices
    # repeat; scatter_reduce actually combines all of them.
    return out.scatter_reduce(dim, index, src, reduce=reduce, include_self=True)
return scatter_reduce(gate * x, index, dim=0, reduce="sum") def __repr__(self) -> str: gate_nn, message_nn = self.gate_nn, self.message_nn @@ -78,12 +78,12 @@ def forward(self, x: Tensor, index: Tensor, weights: Tensor) -> Tensor: """ gate = self.gate_nn(x) - gate -= scatter_max(gate, index, dim=0)[0][index] + gate -= scatter_reduce(gate, index, dim=0, reduce="amax")[index] gate = (weights**self.pow) * gate.exp() - gate /= scatter_add(gate, index, dim=0)[index] + 1e-10 + gate /= scatter_reduce(gate, index, dim=0, reduce="sum")[index] + 1e-10 x = self.message_nn(x) - return scatter_add(gate * x, index, dim=0) + return scatter_reduce(gate * x, index, dim=0, reduce="sum") def __repr__(self) -> str: pow, gate_nn, message_nn = float(self.pow), self.gate_nn, self.message_nn diff --git a/aviary/wren/model.py b/aviary/wren/model.py index a4b33b9d..f17764b1 100644 --- a/aviary/wren/model.py +++ b/aviary/wren/model.py @@ -6,10 +6,10 @@ import torch.nn.functional as F from pymatgen.util.due import Doi, due from torch import LongTensor, Tensor, nn -from torch_scatter import scatter_mean from aviary.core import BaseModelClass from aviary.networks import ResidualNetwork, SimpleNetwork +from aviary.scatter import scatter_reduce from aviary.segments import MessageLayer, WeightedAttentionPooling if TYPE_CHECKING: @@ -261,7 +261,9 @@ def forward( for attnhead in self.cry_pool ] - return scatter_mean(torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0) + return scatter_reduce( + torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0, reduce="mean" + ) def __repr__(self) -> str: return ( diff --git a/examples/notebooks/Roost.ipynb b/examples/notebooks/Roost.ipynb index 4b3941c6..7f53ec1b 100644 --- a/examples/notebooks/Roost.ipynb +++ b/examples/notebooks/Roost.ipynb @@ -12,7 +12,6 @@ "\n", "print(f\"{TORCH_VERSION=}\")\n", "\n", - "!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}.html\n", "!pip install -U 
git+https://github.com/CompRhys/aviary.git # install aviary\n", "!wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997" ] diff --git a/examples/notebooks/Wren.ipynb b/examples/notebooks/Wren.ipynb index 43900b53..5be4cfa2 100644 --- a/examples/notebooks/Wren.ipynb +++ b/examples/notebooks/Wren.ipynb @@ -12,7 +12,6 @@ "\n", "print(f\"{TORCH_VERSION=}\")\n", "\n", - "!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH_VERSION}.html\n", "!pip install -U git+https://github.com/CompRhys/aviary.git # install aviary\n", "!wget -O taata.json.gz https://figshare.com/ndownloader/files/34423997" ] diff --git a/pyproject.toml b/pyproject.toml index fa0e42d4..fa6be25e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "aviary" -version = "1.0.0" +version = "1.1.0" description = "A collection of machine learning models for materials discovery" authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }] readme = "README.md" @@ -27,7 +27,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Chemistry", @@ -42,7 +41,6 @@ dependencies = [ "scikit_learn", "tensorboard", "torch", - "torch_scatter", "tqdm", "wandb", ] @@ -114,6 +112,7 @@ ignore = [ "D105", # Missing docstring in magic method "D205", # 1 blank line required between summary line and description "E731", # Do not assign a lambda expression, use a def + "ISC001", "PD901", # pandas-df-variable-name "PLR", # pylint refactor "PT006", # pytest-parametrize-names-wrong-type