
permute results differ from the torch version #1323

Description

@hanlinxuy
import torch
from permute_unpermute import cuda_token_permute, cuda_token_unpermute, cuda_token_permute_torch, cuda_token_unpermute_torch

torch.manual_seed(1)
device = torch.device("cuda")
hidden_states = torch.randn(16, 64, device=device).to(torch.bfloat16)
router_logits = torch.randn(16, 32, device=device).to(torch.bfloat16)
top_k = 4

router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)  # (seq_len, top_k)

# Permute the same tokens with the same expert assignments through both implementations
hidden_states1, row_id_map1 = cuda_token_permute_torch(hidden_states, router_indices)  # torch version
hidden_states2, row_id_map2 = cuda_token_permute(hidden_states, router_indices)  # CUDA kernel

print(row_id_map1)
print(row_id_map2)
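One possible explanation, which I have not confirmed against either implementation, is that a permute op can encode `row_id_map` in more than one valid way. The pure-Python sketch below (no torch, CPU only) builds two such encodings from the same `router_indices`. The first is a gather-style map: for each permuted row, it lists which flattened `token * top_k + k` slot the row came from. The second is a scatter-style map: for each (token, k) slot, it lists which permuted row the slot was written to, stored as `[top_k, num_tokens]`. The helper names and both layouts are hypothetical illustrations, not the API of `permute_unpermute` or `grouped_gemm`.

```python
from typing import List


def slots_sorted_by_expert(indices: List[List[int]]) -> List[int]:
    """Flattened (token, k) slot ids ordered by expert id.

    Slot id = token * top_k + k. Python's sorted() is stable, so slots routed
    to the same expert keep their token-major order.
    """
    flat = [expert for row in indices for expert in row]
    return sorted(range(len(flat)), key=lambda slot: flat[slot])


def gather_map(indices: List[List[int]]) -> List[int]:
    # Hypothetical convention A: entry i is the source slot of permuted row i.
    return slots_sorted_by_expert(indices)


def scatter_map(indices: List[List[int]]) -> List[List[int]]:
    # Hypothetical convention B: entry [k][t] is the permuted row that
    # (token t, k-th choice) was written to, stored as [top_k, num_tokens].
    num_tokens, top_k = len(indices), len(indices[0])
    out = [[-1] * num_tokens for _ in range(top_k)]
    for dest, slot in enumerate(slots_sorted_by_expert(indices)):
        token, k = divmod(slot, top_k)
        out[k][token] = dest
    return out


router_indices = [[2, 0], [1, 2], [0, 1]]  # 3 tokens, top_k = 2
print(gather_map(router_indices))   # [1, 4, 2, 5, 0, 3]
print(scatter_map(router_indices))  # [[4, 2, 1], [0, 5, 3]]
```

Both maps describe the same permutation, since each is the inverse of the other up to layout, yet they do not match element-wise. If the two printed `row_id_map` tensors differ, it may be worth checking whether one is the inverse and/or a transpose of the other before treating the difference as a bug.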

This is a simple script, but the two functions return different row_id_map results for the same input. Could someone help me understand why this happens?
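If the two maps turn out to use the same layout but still disagree, another candidate is tie ordering. Tokens routed to the same expert can appear in any order within that expert's block, depending on whether the underlying sort is stable, and every such order is equally valid. A check that tolerates this tests each map for validity instead of comparing the two element-wise. The sketch below assumes a gather-style map, where entry i is the flattened `token * top_k + k` slot of permuted row i. That convention and the helper name are hypothetical, so the real `row_id_map` may need converting first. For the tensors in the script above, the flattened expert ids would be `router_indices.flatten().tolist()`.

```python
from typing import List


def is_valid_gather_map(row_map: List[int], flat_experts: List[int]) -> bool:
    """Check a hypothetical gather-style row map against flattened expert ids.

    flat_experts[s] is the expert chosen for slot s = token * top_k + k.
    The map is valid if it uses every slot exactly once and the permuted
    rows come out grouped by ascending expert id.
    """
    if sorted(row_map) != list(range(len(flat_experts))):
        return False
    experts_in_order = [flat_experts[s] for s in row_map]
    return experts_in_order == sorted(experts_in_order)


# 3 tokens, top_k = 2, routing [[2, 0], [1, 2], [0, 1]] flattened token-major
flat_experts = [2, 0, 1, 2, 0, 1]

stable_order = [1, 4, 2, 5, 0, 3]  # ties kept in token order
swapped_ties = [4, 1, 5, 2, 3, 0]  # same expert groups, ties reversed
wrong_group = [1, 2, 4, 5, 0, 3]   # an expert-1 slot placed inside expert 0's block

print(stable_order == swapped_ties)                     # False
print(is_valid_gather_map(stable_order, flat_experts))  # True
print(is_valid_gather_map(swapped_ties, flat_experts))  # True
print(is_valid_gather_map(wrong_group, flat_experts))   # False
```

Two maps that both pass this check feed each expert the same set of rows, so the grouped GEMM results should agree once the outputs are unpermuted, even though the maps themselves differ.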

The code was run on an NVIDIA L40 GPU with CUDA 12.8 (cu128).

The environment is listed below.

uv pip list
Package                  Version   Editable project location
------------------------ --------- ------------------------------------------------------------------------
absl-py                  2.3.1
filelock                 3.20.0
fsspec                   2025.10.0
grouped-gemm             1.1.4
jinja2                   3.1.6
markupsafe               3.0.3
mpmath                   1.3.0
networkx                 3.6
numpy                    2.3.5
nvidia-cublas-cu12       12.8.4.1
nvidia-cuda-cupti-cu12   12.8.90
nvidia-cuda-nvrtc-cu12   12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12        9.10.2.21
nvidia-cufft-cu12        11.3.3.83
nvidia-cufile-cu12       1.13.1.3
nvidia-curand-cu12       10.3.9.90
nvidia-cusolver-cu12     11.7.3.90
nvidia-cusparse-cu12     12.5.8.93
nvidia-cusparselt-cu12   0.7.1
nvidia-nccl-cu12         2.27.5
nvidia-nvjitlink-cu12    12.8.93
nvidia-nvshmem-cu12      3.3.20
nvidia-nvtx-cu12         12.8.90
setuptools               80.9.0
sympy                    1.14.0
torch                    2.9.1
triton                   3.5.1
typing-extensions        4.15.0
