Description
When using the FP8 data type with BatchPrefillWithPagedKVCacheWrapper, the scaling factors (k_scale and v_scale) passed to forward_return_lse do not take effect: the output is the same regardless of the scale values used.
Reproduction Steps
Here's a minimal reproduction script:
import torch
import flashinfer
import random

# set random seed for reproducibility
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# init params
fp8_max = 448  # max representable value of float8_e4m3fn
num_qo_heads = 64
num_kv_heads = 16
head_dim = 128
max_num_pages = 128
page_size = 1

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", backend="fa2"
)

batch_size = 1
nnz_qo = 100
qo_indptr = torch.tensor(
    [0, nnz_qo], dtype=torch.int32, device="cuda:0"
)
paged_kv_indices = torch.arange(max_num_pages).int().to("cuda:0")
paged_kv_indptr = torch.tensor(
    [0, 128], dtype=torch.int32, device="cuda:0"
)
# one entry per request; with page_size=1 the last page holds a single token
paged_kv_last_page_len = torch.tensor(
    [1], dtype=torch.int32, device="cuda:0"
)
q = torch.randn(nnz_qo, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda:0")
kv_cache = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda:0"
)

# per-tensor quantization: divide by kv_scale so the values fit the FP8 E4M3 range
kv_scale = torch.max(torch.abs(kv_cache)) / fp8_max
print(f"kv_scale: {kv_scale}")
kv_cache_fp8 = (kv_cache / kv_scale).to(torch.float8_e4m3fn)

# run bfloat16 (reference)
prefill_wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
    custom_mask=None,
)
o, lse = prefill_wrapper.forward_return_lse(
    q,
    kv_cache,
    causal=False,
)
print("use bfloat16:")
print(o[0])
print(lse[0])
print("-" * 100)

# run fp8 with the real scales
prefill_wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    q_data_type=torch.bfloat16,
    kv_data_type=torch.float8_e4m3fn,
    custom_mask=None,
)
o, lse = prefill_wrapper.forward_return_lse(
    q,
    kv_cache_fp8,
    k_scale=kv_scale,
    v_scale=kv_scale,
    causal=False,
)
print(f"use fp8, k_scale={kv_scale}, v_scale={kv_scale}:")
print(o[0])
print(lse[0])
print("-" * 100)

# run fp8 again with scales of 1.0 -- the output should differ, but does not
o, lse = prefill_wrapper.forward_return_lse(
    q,
    kv_cache_fp8,
    k_scale=1.0,
    v_scale=1.0,
    causal=False,
)
print("use fp8, k_scale=1.0, v_scale=1.0:")
print(o[0])
print(lse[0])
print("-" * 100)
Expected Behavior
The outputs should be different when using different scale values, as the FP8 values need to be properly scaled back to their original range.
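For reference, here is a minimal pure-PyTorch check (a sketch, reusing the tensors from the reproduction script above): dequantizing the FP8 cache with kv_scale and computing attention directly should roughly match the bfloat16 run, up to FP8 quantization error, and this is what a correct FP8 kernel with k_scale and v_scale applied should reproduce.

import math

# dequantize the FP8 cache using the same per-tensor scale
kv_deq = kv_cache_fp8.to(torch.float32) * kv_scale
k = kv_deq[:, 0].reshape(-1, num_kv_heads, head_dim)  # [128, 16, 128]
v = kv_deq[:, 1].reshape(-1, num_kv_heads, head_dim)
# expand KV heads to the 64 query heads (GQA, group size 4)
group = num_qo_heads // num_kv_heads
k = k.repeat_interleave(group, dim=1)  # [128, 64, 128]
v = v.repeat_interleave(group, dim=1)
scores = torch.einsum("qhd,khd->hqk", q.float(), k) / math.sqrt(head_dim)
ref = torch.einsum("hqk,khd->qhd", torch.softmax(scores, dim=-1), v)
print(ref[0])  # should be close to the bfloat16 output above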
Actual Behavior
The outputs are identical regardless of the scale values used, indicating that the scaling factors are not applied in the computation (see the programmatic check sketched after the console output below).
Console output:
kv_scale: 0.01092529296875
use bfloat16:
tensor([[ 0.0679, 0.0728, 0.0118, ..., -0.1309, 0.0581, 0.2637],
[ 0.3789, 0.0155, 0.2295, ..., -0.0938, 0.2441, 0.1162],
[ 0.1133, -0.0757, 0.0430, ..., -0.1875, -0.2090, 0.1914],
...,
[ 0.0649, -0.2812, -0.3047, ..., -0.0200, -0.0131, 0.1895],
[-0.1504, -0.3809, -0.0679, ..., 0.0188, 0.0962, 0.0435],
[-0.0928, -0.2637, -0.2109, ..., 0.1455, -0.1650, 0.0356]],
device='cuda:0', dtype=torch.bfloat16)
tensor([7.5064, 7.8416, 8.2345, 7.4729, 7.7151, 7.6594, 7.4227, 7.8959, 7.9379,
7.6514, 7.5364, 7.7540, 7.5418, 8.0630, 7.7950, 7.6326, 7.9143, 7.4227,
7.6645, 7.6050, 7.4975, 7.5238, 7.4255, 7.8637, 7.5372, 7.7846, 7.6967,
7.3960, 7.3593, 7.9669, 7.7649, 7.4434, 7.5748, 7.6071, 7.7232, 7.4752,
7.8065, 7.7136, 7.6856, 7.7621, 8.1139, 7.6196, 7.8465, 7.9204, 8.1066,
7.4417, 7.4475, 7.7789, 7.9261, 7.8857, 7.8761, 7.8458, 7.4534, 7.7438,
7.7164, 7.6664, 7.4687, 7.8115, 7.9794, 7.6866, 7.9497, 7.4951, 7.5980,
7.4584], device='cuda:0')
----------------------------------------------------------------------------------------------------
use fp8, k_scale=0.01092529296875, v_scale=0.01092529296875:
tensor([[ 13.0000, 128.0000, 104.0000, ..., -12.0000, -48.0000,
-36.0000],
[ 15.0000, 40.0000, 176.0000, ..., 24.0000, 120.0000,
-112.0000],
[ 104.0000, -96.0000, 112.0000, ..., -112.0000, -160.0000,
20.0000],
...,
[ 35.5000, -127.0000, -98.0000, ..., -6.4062, 28.6250,
180.0000],
[ -88.0000, -128.0000, 96.0000, ..., -72.0000, 48.0000,
28.0000],
[ 30.0000, -144.0000, -120.0000, ..., -20.0000, 28.0000,
192.0000]], device='cuda:0', dtype=torch.bfloat16)
tensor([255.6564, 456.0211, 421.9968, 351.9308, 354.5456, 346.6835, 290.8127,
436.7100, 426.3302, 297.2573, 335.6914, 384.4396, 261.8149, 406.5260,
453.6687, 362.3043, 363.5225, 236.3233, 354.5876, 290.7133, 317.7333,
239.1209, 236.5329, 421.9930, 275.4621, 341.4602, 249.7285, 258.6279,
251.2649, 416.0476, 352.6308, 267.1920, 291.0104, 286.1772, 288.0804,
283.1631, 347.4054, 295.9063, 348.5504, 362.9382, 505.1582, 335.1806,
337.2320, 377.6111, 538.7235, 283.2250, 292.1403, 345.3987, 394.2041,
507.5838, 300.3797, 304.4101, 315.5672, 504.2701, 321.3320, 246.9149,
267.9784, 333.2326, 428.1322, 321.7399, 356.3152, 286.4862, 366.9317,
262.6851], device='cuda:0')
----------------------------------------------------------------------------------------------------
use fp8, k_scale=1.0, v_scale=1.0:
tensor([[ 13.0000, 128.0000, 104.0000, ..., -12.0000, -48.0000,
-36.0000],
[ 15.0000, 40.0000, 176.0000, ..., 24.0000, 120.0000,
-112.0000],
[ 104.0000, -96.0000, 112.0000, ..., -112.0000, -160.0000,
20.0000],
...,
[ 35.5000, -127.0000, -98.0000, ..., -6.4062, 28.6250,
180.0000],
[ -88.0000, -128.0000, 96.0000, ..., -72.0000, 48.0000,
28.0000],
[ 30.0000, -144.0000, -120.0000, ..., -20.0000, 28.0000,
192.0000]], device='cuda:0', dtype=torch.bfloat16)
tensor([255.6564, 456.0211, 421.9968, 351.9308, 354.5456, 346.6835, 290.8127,
436.7100, 426.3302, 297.2573, 335.6914, 384.4396, 261.8149, 406.5260,
453.6687, 362.3043, 363.5225, 236.3233, 354.5876, 290.7133, 317.7333,
239.1209, 236.5329, 421.9930, 275.4621, 341.4602, 249.7285, 258.6279,
251.2649, 416.0476, 352.6308, 267.1920, 291.0104, 286.1772, 288.0804,
283.1631, 347.4054, 295.9063, 348.5504, 362.9382, 505.1582, 335.1806,
337.2320, 377.6111, 538.7235, 283.2250, 292.1403, 345.3987, 394.2041,
507.5838, 300.3797, 304.4101, 315.5672, 504.2701, 321.3320, 246.9149,
267.9784, 333.2326, 428.1322, 321.7399, 356.3152, 286.4862, 366.9317,
262.6851], device='cuda:0')
----------------------------------------------------------------------------------------------------
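The identical outputs can also be confirmed programmatically (a sketch; o_scaled and o_unscaled are hypothetical names for the results of the two FP8 runs above, using the same plan as in the reproduction script):

o_scaled, lse_scaled = prefill_wrapper.forward_return_lse(
    q, kv_cache_fp8, k_scale=kv_scale, v_scale=kv_scale, causal=False
)
o_unscaled, lse_unscaled = prefill_wrapper.forward_return_lse(
    q, kv_cache_fp8, k_scale=1.0, v_scale=1.0, causal=False
)
# both comparisons print True, i.e. k_scale / v_scale have no effect at all
print(torch.equal(o_scaled, o_unscaled))
print(torch.equal(lse_scaled, lse_unscaled))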
Environment
- flashinfer version: 0.2.3+cu124torch2.5
- PyTorch version: 2.5.1
- CUDA version: 12.4
- GPU: NVIDIA H20