Open
Description
Python version: 3.10
PyTorch version: 2.6.0
FlashInfer version: 0.2.3+cu124torch2.6
Minimal repro:
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_logits

def sample(x):
    return top_k_top_p_sampling_from_logits(
        logits=x,
        top_k=0,
        top_p=1.0,
    )

sample = torch.compile(sample, mode='reduce-overhead', fullgraph=True)
sample(torch.randn([1, 32]).cuda())

Produces the following error:
torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable(<class 'types.SimpleNamespace'>) [] {'sampling_from_probs': NestedUserFunctionVariable(), 'top_p_sampling_from_probs': NestedUserFunctionVariable(), 'top_k_sampling_from_probs': NestedUserFunctionVariable(), 'min_p_sampling_from_probs': NestedUserFunctionVariable(), 'top_k_top_p_sampling_from_probs': NestedUserFunctionVariable(), 'top_p_renorm_probs': NestedUserFunctionVariable(), 'top_k_renorm_probs': NestedUserFunctionVariable(), 'top_k_mask_logits': NestedUserFunctionVariable(), 'chain_speculative_sampling': NestedUserFunctionVariable()}
from user code:
  File ".../compile.py", line 7, in sample
    return top_k_top_p_sampling_from_logits(
  File "/home/sharvil/.conda/envs/lmnt/lib/python3.10/site-packages/flashinfer/sampling.py", line 810, in top_k_top_p_sampling_from_logits
    masked_logits = top_k_mask_logits(logits, top_k)
  File "/home/sharvil/.conda/envs/lmnt/lib/python3.10/site-packages/flashinfer/sampling.py", line 1129, in top_k_mask_logits
    return get_sampling_module().top_k_mask_logits(
  File "/home/sharvil/.conda/envs/lmnt/lib/python3.10/site-packages/flashinfer/sampling.py", line 393, in get_sampling_module
    _sampling_module = SimpleNamespace(
Setting fullgraph=False avoids the error, and the call succeeds.