```bash
pip install torch_scatter
```

The version of `torch_scatter` is `2.1.2`.

```python
import torch
from torch_scatter import scatter_max

src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) * -1
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
out, argmax = scatter_max(src, index, out=out)
print(out)
print(argmax)
```

The result of `out` is:

```python
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
```
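The all-zero result is consistent with the values already stored in `out` taking part in the max reduction: because `src` is multiplied by `-1`, every element is `<= 0`, so none of them can beat the initial `0.0`. The pure-Python sketch below models that assumption only; it is not torch_scatter's implementation, it ignores `argmax`, and the helper `scatter_max_rows` is hypothetical. It replays the same `src` and `index` once with a zero-initialised `out` and once with `-inf`.

```python
import math


def scatter_max_rows(src, index, out):
    """Hypothetical row-wise scatter-max along the last dimension.

    Assumes (as the observed output suggests) that values already in `out`
    participate in the max, so `out` is updated in place and returned.
    """
    for row_src, row_idx, row_out in zip(src, index, out):
        for value, slot in zip(row_src, row_idx):
            # Only overwrite the slot if the incoming value is strictly larger.
            if value > row_out[slot]:
                row_out[slot] = value
    return out


# Same data as the reproducer: every value is negated, so all are <= 0.
src = [[-float(v) for v in row] for row in [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]

# Zero-initialised output: no negative value exceeds 0.0, so nothing changes.
zeros_out = [[0.0] * 6 for _ in range(2)]
print(scatter_max_rows(src, index, zeros_out))

# -inf-initialised output: each touched slot receives the true maximum of src,
# while slots that no index points to remain -inf.
neg_inf_out = [[-math.inf] * 6 for _ in range(2)]
print(scatter_max_rows(src, index, neg_inf_out))
```

Under this model, seeding `out` with `-inf` (for a real tensor, something like `src.new_full((2, 6), float("-inf"))`) lets only `src` determine the maxima; whether torch_scatter itself behaves exactly this way is an assumption worth checking against its documentation.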