
top_k_top_p_sampling_from_logits incompatible with torch.compile + CUDAGraph #978

@sharvil


Python version: 3.10
PyTorch version: 2.6.0
FlashInfer version: 0.2.3+cu124torch2.6

Minimal repro:

```python
import torch

from flashinfer.sampling import top_k_top_p_sampling_from_logits


def sample(x):
  return top_k_top_p_sampling_from_logits(
    logits=x,
    top_k=0,
    top_p=1.0,
  )

# mode='reduce-overhead' uses CUDA graphs; fullgraph=True turns any graph break into an error.
sample = torch.compile(sample, mode='reduce-overhead', fullgraph=True)
sample(torch.randn([1, 32]).cuda())
```

This produces the following error:

```
torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable(<class 'types.SimpleNamespace'>) [] {'sampling_from_probs': NestedUserFunctionVariable(), 'top_p_sampling_from_probs': NestedUserFunctionVariable(), 'top_k_sampling_from_probs': NestedUserFunctionVariable(), 'min_p_sampling_from_probs': NestedUserFunctionVariable(), 'top_k_top_p_sampling_from_probs': NestedUserFunctionVariable(), 'top_p_renorm_probs': NestedUserFunctionVariable(), 'top_k_renorm_probs': NestedUserFunctionVariable(), 'top_k_mask_logits': NestedUserFunctionVariable(), 'chain_speculative_sampling': NestedUserFunctionVariable()}

from user code:
  File ".../compile.py", line 7, in sample
    return top_k_top_p_sampling_from_logits(
  File "/home/sharvil/.conda/envs/lmnt/lib/python3.10/site-packages/flashinfer/sampling.py", line 810, in top_k_top_p_sampling_from_logits
    masked_logits = top_k_mask_logits(logits, top_k)
  File "/home/sharvil/.conda/envs/lmnt/lib/python3.10/site-packages/flashinfer/sampling.py", line 1129, in top_k_mask_logits
    return get_sampling_module().top_k_mask_logits(
  File "/home/sharvil/.conda/envs/lmnt/lib/python3.10/site-packages/flashinfer/sampling.py", line 393, in get_sampling_module
    _sampling_module = SimpleNamespace(
```
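The frames above suggest a lazy-initialization pattern: `get_sampling_module` appears to build a `SimpleNamespace` of nested helper functions when it is first called, and in this repro that first call happens while Dynamo is tracing `sample`. The sketch below is a hypothetical reconstruction inferred only from the traceback, not FlashInfer's actual source; the module-level cache and the placeholder body of the nested `top_k_mask_logits` are assumptions for illustration.

```python
from types import SimpleNamespace

# Hypothetical module-level cache, modeled on the frames in the traceback.
_sampling_module = None


def get_sampling_module():
    """Build the namespace of sampling helpers on first use, then reuse it."""
    global _sampling_module
    if _sampling_module is None:
        # Nested functions like this one show up in the Dynamo error as
        # NestedUserFunctionVariable entries.
        def top_k_mask_logits(logits, top_k):
            # Placeholder body (assumption); the real helper calls a kernel.
            return logits

        # This SimpleNamespace(...) call is the step Dynamo reported as
        # Unsupported: on the first call it runs inside the traced function.
        _sampling_module = SimpleNamespace(top_k_mask_logits=top_k_mask_logits)
    return _sampling_module


def top_k_mask_logits(logits, top_k):
    # Public wrapper: every call goes through the lazy getter.
    return get_sampling_module().top_k_mask_logits(logits, top_k)
```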

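Setting `fullgraph=False` makes the call succeed, presumably because Dynamo is then allowed to insert a graph break around the unsupported `SimpleNamespace` construction instead of raising.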
