
scatter or scatter_min fails when using torch.compile #440

Open
gardiens opened this issue May 2, 2024 · 4 comments

Comments

@gardiens

gardiens commented May 2, 2024

Hello,

I can't compile any model that uses scatter or scatter_min from torch_scatter.
For example, with this script:

import torch
import torch_geometric
from torch_scatter import scatter_min

print("the version of torch", torch.__version__)
print("torch_geometric version", torch_geometric.__version__)


def get_x(n_points=100):
    # Sample n_points points uniformly inside the box [0, 10)^3.

    x_min = [0, 10]
    y_min = [0, 10]
    z_min = [0, 10]

    x = torch.rand((n_points, 3))
    x[:, 0] = x[:, 0] * (x_min[1] - x_min[0]) + x_min[0]
    x[:, 1] = x[:, 1] * (y_min[1] - y_min[0]) + y_min[0]
    x[:, 2] = x[:, 2] * (z_min[1] - z_min[0]) + z_min[0]

    return x


x = get_x(n_points=10)
se = torch.randint(low=0, high=10, size=(10,))

model = scatter_min
compiled_model = torch.compile(model)

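# scatter_min returns a (values, argmin) tuple, so compare the values only.
expected, _ = model(x, se, dim=0)
out, _ = compiled_model(x, se, dim=0)  # this call fails under torch.compile
assert torch.allclose(out, expected, atol=1e-6)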

The code fails with:

 torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_scatter.scatter_min(*(FakeTensor(..., size=(10, 3)), FakeTensor(..., size=(10,), dtype=torch.int64), 0, None, None), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
 line 65, in scatter_min
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)

My versions are torch 2.2.0, torch_geometric 2.5.2, and torch_scatter 2.1.2.

@rusty1s
Owner

rusty1s commented May 7, 2024

This is currently expected, since the custom ops provided by torch-scatter are not supported by torch.compile. There are two options:

@rusty1s
Owner

rusty1s commented May 7, 2024

For this, we added utils.scatter (torch_geometric.utils.scatter), which picks the best computation path depending on your input arguments. It also works with torch.compile.

@gardiens
Author


If I understand correctly, you suggest that we should use utils.scatter by default, rather than calling torch_sum or torch_scatter functions such as scatter_min or scatter_max directly?

@rusty1s
Owner

rusty1s commented May 10, 2024

Yes, if you want torch.compile support, then this is the recommended way.
