
[Bug] FP8 scaling factors (k_scale/v_scale) not taking effect in BatchPrefillWithPagedKVCacheWrapper #1023

Open
@cscyuge

Description


When using an FP8 KV cache with BatchPrefillWithPagedKVCacheWrapper, the scaling factors (k_scale and v_scale) passed to forward_return_lse do not take effect: the output is identical regardless of the scale values used.
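
For reference, k_scale and v_scale are expected to act as dequantization factors on the FP8 cache, i.e. the kernel should compute softmax(q @ (k_fp8 * k_scale)^T / sqrt(d)) @ (v_fp8 * v_scale). A minimal sketch of that reference semantics in plain PyTorch (illustrative only: single request, no GQA, no paging):

import math
import torch

def ref_fp8_attention(q, k_fp8, v_fp8, k_scale, v_scale):
    # Dequantize the FP8 KV entries back to the query dtype first.
    k = k_fp8.to(q.dtype) * k_scale
    v = v_fp8.to(q.dtype) * v_scale
    # Plain attention over [seq_len, num_heads, head_dim] tensors.
    logits = torch.einsum("qhd,khd->hqk", q, k) / math.sqrt(q.shape[-1])
    return torch.einsum("hqk,khd->qhd", logits.softmax(dim=-1), v)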

Reproduction Steps

Here's a minimal reproduction script:

import torch
import flashinfer
import random

# set random seed
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# init params
fp8_max = 448  # max finite value of torch.float8_e4m3fn
num_qo_heads = 64
num_kv_heads = 16
head_dim = 128
max_num_pages = 128
page_size = 1
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")

prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", backend="fa2"
)

batch_size = 1
nnz_qo = 100
qo_indptr = torch.tensor(
    [0, nnz_qo], dtype=torch.int32, device="cuda:0"
)
paged_kv_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda:0")
paged_kv_indptr = torch.tensor(
    [0, 128], dtype=torch.int32, device="cuda:0"
)
paged_kv_last_page_len = torch.tensor(
    [1], dtype=torch.int32, device="cuda:0"
)  # one entry per request; page_size is 1, so the last page holds 1 token
q = torch.randn(nnz_qo, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda:0")
kv_cache = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda:0"
)

# Per-tensor scale; .item() converts the 0-dim tensor to the Python float the API expects.
kv_scale = (torch.abs(kv_cache).max() / fp8_max).item()
print(f"kv_scale: {kv_scale}")
kv_cache_fp8 = (kv_cache / kv_scale).to(torch.float8_e4m3fn)

# run bfloat16
prefill_wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
    custom_mask=None,
)

o, lse = prefill_wrapper.forward_return_lse(
    q,
    kv_cache,
    causal=False,
)
print("use bfloat16:")
print(o[0])
print(lse[0])
print("-" * 100)

# run fp8
prefill_wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    kv_data_type=torch.float8_e4m3fn,
    q_data_type=torch.bfloat16,
    custom_mask=None,
)

o, lse = prefill_wrapper.forward_return_lse(
    q,
    kv_cache_fp8,
    k_scale=kv_scale,
    v_scale=kv_scale,
    causal=False,
)
print(f"use fp8, k_scale={kv_scale}, v_scale={kv_scale}:")
print(o[0])
print(lse[0])
print("-" * 100)

o, lse = prefill_wrapper.forward_return_lse(
    q,
    kv_cache_fp8,
    k_scale=1.0,
    v_scale=1.0,
    causal=False,
)
print("use fp8, k_scale=1.0, v_scale=1.0:")
print(o[0])
print(lse[0])
print("-" * 100)

Expected Behavior

The two FP8 runs should produce different outputs, since the FP8 values need to be scaled back to their original range; in particular, the run with k_scale = v_scale = kv_scale should closely match the bfloat16 baseline.
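
A concrete reference (a sketch, reusing the tensors from the script above): dequantize the FP8 cache manually and run the bfloat16 path on it; a working FP8 path with k_scale = v_scale = kv_scale should approximate this result up to FP8 rounding.

# Re-plan for a bfloat16 KV cache, then run on the manually dequantized cache.
prefill_wrapper.plan(
    qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
)
kv_cache_deq = kv_cache_fp8.to(torch.bfloat16) * kv_scale
o_ref, lse_ref = prefill_wrapper.forward_return_lse(q, kv_cache_deq, causal=False)
# A working FP8 run with k_scale=v_scale=kv_scale should be close to (o_ref, lse_ref).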

Actual Behavior

The outputs are identical regardless of the scale values used, indicating that the scaling factors are not being applied in the computation.
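
This can be verified directly rather than by eyeballing the prints (variable names here are illustrative):

o_scaled, lse_scaled = prefill_wrapper.forward_return_lse(
    q, kv_cache_fp8, k_scale=kv_scale, v_scale=kv_scale, causal=False
)
o_unit, lse_unit = prefill_wrapper.forward_return_lse(
    q, kv_cache_fp8, k_scale=1.0, v_scale=1.0, causal=False
)
assert torch.equal(o_scaled, o_unit)      # passes: k_scale/v_scale are ignored
assert torch.equal(lse_scaled, lse_unit)  # passes: the lse is unscaled too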

Console output:

kv_scale: 0.01092529296875
use bfloat16: 
tensor([[ 0.0679,  0.0728,  0.0118,  ..., -0.1309,  0.0581,  0.2637],
        [ 0.3789,  0.0155,  0.2295,  ..., -0.0938,  0.2441,  0.1162],
        [ 0.1133, -0.0757,  0.0430,  ..., -0.1875, -0.2090,  0.1914],
        ...,
        [ 0.0649, -0.2812, -0.3047,  ..., -0.0200, -0.0131,  0.1895],
        [-0.1504, -0.3809, -0.0679,  ...,  0.0188,  0.0962,  0.0435],
        [-0.0928, -0.2637, -0.2109,  ...,  0.1455, -0.1650,  0.0356]],
       device='cuda:0', dtype=torch.bfloat16)
tensor([7.5064, 7.8416, 8.2345, 7.4729, 7.7151, 7.6594, 7.4227, 7.8959, 7.9379,
        7.6514, 7.5364, 7.7540, 7.5418, 8.0630, 7.7950, 7.6326, 7.9143, 7.4227,
        7.6645, 7.6050, 7.4975, 7.5238, 7.4255, 7.8637, 7.5372, 7.7846, 7.6967,
        7.3960, 7.3593, 7.9669, 7.7649, 7.4434, 7.5748, 7.6071, 7.7232, 7.4752,
        7.8065, 7.7136, 7.6856, 7.7621, 8.1139, 7.6196, 7.8465, 7.9204, 8.1066,
        7.4417, 7.4475, 7.7789, 7.9261, 7.8857, 7.8761, 7.8458, 7.4534, 7.7438,
        7.7164, 7.6664, 7.4687, 7.8115, 7.9794, 7.6866, 7.9497, 7.4951, 7.5980,
        7.4584], device='cuda:0')
----------------------------------------------------------------------------------------------------
use fp8, k_scale=0.01092529296875, v_scale=0.01092529296875: 
tensor([[  13.0000,  128.0000,  104.0000,  ...,  -12.0000,  -48.0000,
          -36.0000],
        [  15.0000,   40.0000,  176.0000,  ...,   24.0000,  120.0000,
         -112.0000],
        [ 104.0000,  -96.0000,  112.0000,  ..., -112.0000, -160.0000,
           20.0000],
        ...,
        [  35.5000, -127.0000,  -98.0000,  ...,   -6.4062,   28.6250,
          180.0000],
        [ -88.0000, -128.0000,   96.0000,  ...,  -72.0000,   48.0000,
           28.0000],
        [  30.0000, -144.0000, -120.0000,  ...,  -20.0000,   28.0000,
          192.0000]], device='cuda:0', dtype=torch.bfloat16)
tensor([255.6564, 456.0211, 421.9968, 351.9308, 354.5456, 346.6835, 290.8127,
        436.7100, 426.3302, 297.2573, 335.6914, 384.4396, 261.8149, 406.5260,
        453.6687, 362.3043, 363.5225, 236.3233, 354.5876, 290.7133, 317.7333,
        239.1209, 236.5329, 421.9930, 275.4621, 341.4602, 249.7285, 258.6279,
        251.2649, 416.0476, 352.6308, 267.1920, 291.0104, 286.1772, 288.0804,
        283.1631, 347.4054, 295.9063, 348.5504, 362.9382, 505.1582, 335.1806,
        337.2320, 377.6111, 538.7235, 283.2250, 292.1403, 345.3987, 394.2041,
        507.5838, 300.3797, 304.4101, 315.5672, 504.2701, 321.3320, 246.9149,
        267.9784, 333.2326, 428.1322, 321.7399, 356.3152, 286.4862, 366.9317,
        262.6851], device='cuda:0')
----------------------------------------------------------------------------------------------------
use fp8, k_scale=1.0, v_scale=1.0: 
tensor([[  13.0000,  128.0000,  104.0000,  ...,  -12.0000,  -48.0000,
          -36.0000],
        [  15.0000,   40.0000,  176.0000,  ...,   24.0000,  120.0000,
         -112.0000],
        [ 104.0000,  -96.0000,  112.0000,  ..., -112.0000, -160.0000,
           20.0000],
        ...,
        [  35.5000, -127.0000,  -98.0000,  ...,   -6.4062,   28.6250,
          180.0000],
        [ -88.0000, -128.0000,   96.0000,  ...,  -72.0000,   48.0000,
           28.0000],
        [  30.0000, -144.0000, -120.0000,  ...,  -20.0000,   28.0000,
          192.0000]], device='cuda:0', dtype=torch.bfloat16)
tensor([255.6564, 456.0211, 421.9968, 351.9308, 354.5456, 346.6835, 290.8127,
        436.7100, 426.3302, 297.2573, 335.6914, 384.4396, 261.8149, 406.5260,
        453.6687, 362.3043, 363.5225, 236.3233, 354.5876, 290.7133, 317.7333,
        239.1209, 236.5329, 421.9930, 275.4621, 341.4602, 249.7285, 258.6279,
        251.2649, 416.0476, 352.6308, 267.1920, 291.0104, 286.1772, 288.0804,
        283.1631, 347.4054, 295.9063, 348.5504, 362.9382, 505.1582, 335.1806,
        337.2320, 377.6111, 538.7235, 283.2250, 292.1403, 345.3987, 394.2041,
        507.5838, 300.3797, 304.4101, 315.5672, 504.2701, 321.3320, 246.9149,
        267.9784, 333.2326, 428.1322, 321.7399, 356.3152, 286.4862, 366.9317,
        262.6851], device='cuda:0')
----------------------------------------------------------------------------------------------------
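
Possible Workaround

Since k_scale only rescales the pre-softmax logits and v_scale only rescales the output, something along these lines may work in the meantime. This is a sketch that assumes plan()'s sm_scale parameter is honored by the fa2 FP8 path, which has not been verified:

import math

# Fold k_scale into the softmax scale at plan time
# (the default sm_scale is 1/sqrt(head_dim))...
prefill_wrapper.plan(
    qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.bfloat16,
    kv_data_type=torch.float8_e4m3fn,
    sm_scale=kv_scale / math.sqrt(head_dim),
)
o, lse = prefill_wrapper.forward_return_lse(q, kv_cache_fp8, causal=False)
# ...and apply v_scale to the output manually.
o = o * kv_scale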

Environment

  • flashinfer version: 0.2.3+cu124torch2.5
  • PyTorch version: 2.5.1
  • CUDA version: 12.4
  • GPU: NVIDIA H20
