Hello,
I can't compile any model that uses scatter or scatter_min from torch_scatter.
For example, in this script:
import torch
import torch_geometric
from torch_scatter import scatter_min

print("the version of torch", torch.__version__)
print("torch_geometric version", torch_geometric.__version__)


def get_x(n_points=100):
    # Random 3-D points; each [lo, hi] pair below gives the range of one coordinate.
    x_min = [0, 10]
    y_min = [0, 10]
    z_min = [0, 10]
    x = torch.rand((n_points, 3))
    x[:, 0] = x[:, 0] * (x_min[1] - x_min[0]) + x_min[0]
    x[:, 1] = x[:, 1] * (y_min[1] - y_min[0]) + y_min[0]
    x[:, 2] = x[:, 2] * (z_min[1] - z_min[0]) + z_min[0]
    return x


device = "cuda"  # defined but not used below; x and se are created without a device argument
x = get_x(n_points=10)
se = torch.randint(low=0, high=10, size=(10,))

model = scatter_min
compiled_model = torch.compile(model)

# scatter_min returns a (values, argmin) tuple, so unpack before comparing.
expected, expected_arg = model(x, se, dim=0)
out, out_arg = compiled_model(x, se, dim=0)
assert torch.allclose(out, expected, atol=1e-6)
The code fails with:
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_scatter.scatter_min(*(FakeTensor(..., size=(10, 3)), FakeTensor(..., size=(10,), dtype=torch.int64), 0, None, None), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
from user code:
line 65, in scatter_min
return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
My torch version is 2.2.0, torch_geometric is 2.5.2, and torch_scatter is 2.1.2.
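
A possible workaround, sketched here and not verified against this exact setup, is to compute the minimum with PyTorch's built-in `Tensor.scatter_reduce` instead of the `torch_scatter` custom op, since built-in ops are ones `torch.compile` can trace with fake tensors. The helper name `scatter_min_native` and its `dim_size` argument are hypothetical, introduced only for illustration. It handles only `dim=0` with a 1-D index and returns only the values, not the argmin.

```python
import torch


def scatter_min_native(src, index, dim_size):
    """Hypothetical helper: scatter-min along dim 0 using only built-in PyTorch ops.

    src:      tensor of shape (N, ...)
    index:    int64 tensor of shape (N,), group id of each row of src
    dim_size: number of output groups (passed explicitly so the output
              shape does not depend on the values in index)
    """
    # scatter_reduce needs one index per element of src, so broadcast
    # the 1-D index across the trailing dimensions.
    idx = index.view(-1, *([1] * (src.dim() - 1))).expand_as(src)
    # Start from +inf so that every real value replaces the initial one.
    out = torch.full(
        (dim_size, *src.shape[1:]),
        float("inf"),
        dtype=src.dtype,
        device=src.device,
    )
    return out.scatter_reduce(0, idx, src, reduce="amin", include_self=True)
```

With the tensors from the script above, this would be called as `scatter_min_native(x, se, dim_size=10)`, and the function could then be passed to `torch.compile`. Note that groups with no members stay at `inf` here, which may differ from how `torch_scatter.scatter_min` fills empty groups, so the two outputs are only comparable on groups that actually receive values.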