Commit fc62e3f

bugfix: follow user-specified sm_scale for blackwell cutlass fmha (#1072)
## 📌 Description

Use the user-specified `sm_scale` instead of a hardcoded value for the Blackwell CUTLASS FMHA kernel. cc @nandor

---

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent 7bdadce commit fc62e3f
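In practice, the fix means `fmha_varlen` only computes the default `1.0 / math.sqrt(head_dim_qk)` when the caller supplies no scale. A minimal call sketch, assuming the parameter names visible in the diff below (`qo_segment_offsets`, `kv_segment_offsets`, `causal`, `sm_scale`) are keyword-accessible and that the tensor layouts follow the `(nnz, num_heads, head_dim)` convention from `flashinfer/prefill.py`; the full signature may differ:

```python
import torch
from flashinfer.prefill import fmha_varlen  # function and module visible in the diff

num_qo_heads, num_kv_heads, head_dim = 32, 4, 128
nnz = 1024  # total tokens across all packed sequences

# Requires a Blackwell-class GPU for the cutlass backend.
q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
# Two packed sequences of 512 tokens each.
offsets = torch.tensor([0, 512, 1024], dtype=torch.int32, device="cuda")

# The user-specified scale is now honored; before this fix the kernel
# always used 1/sqrt(head_dim_qk) regardless of this argument.
out = fmha_varlen(q, k, v, offsets, offsets, causal=True, sm_scale=1.0)

# Passing sm_scale=None keeps the previous default behavior.
out_default = fmha_varlen(q, k, v, offsets, offsets, causal=True, sm_scale=None)
```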

4 files changed: +7 −8 lines
flashinfer/prefill.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -2638,7 +2638,8 @@ def fmha_varlen(
     nnz_kv, num_kv_heads, head_dim_vo = v.shape
 
     mask_mode_code = 1 if causal else 0
-    sm_scale = 1.0 / math.sqrt(head_dim_qk)
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(head_dim_qk)
 
     qo_lens = qo_segment_offsets[1:] - qo_segment_offsets[:-1]
     kv_lens = kv_segment_offsets[1:] - kv_segment_offsets[:-1]
```
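The `None` sentinel resolves the default at the Python layer, which also frees `0.0` to be a legal explicit scale; the C++ side previously overloaded `0.0f` to mean "use the default" (see the mainloop diff below). A standalone sketch of the resolved behavior, using a hypothetical helper name:

```python
import math

def resolve_sm_scale(sm_scale, head_dim_qk):
    # None means "no user preference": fall back to the conventional
    # 1/sqrt(d_k) attention scaling. Any explicit float, including 0.0,
    # is preserved as-is.
    return 1.0 / math.sqrt(head_dim_qk) if sm_scale is None else sm_scale

assert resolve_sm_scale(None, 128) == 1.0 / math.sqrt(128)
assert resolve_sm_scale(1.0, 128) == 1.0
assert resolve_sm_scale(0.0, 128) == 0.0  # no longer swallowed by a sentinel
```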

include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp

Lines changed: 1 addition & 5 deletions
```diff
@@ -171,8 +171,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
   struct Arguments {
     typename Load::Arguments load;
 
-    // if zero, defaults to 1/sqrt(D)
-    float scale_softmax = 0.0f;
+    float scale_softmax;
 
     // scaling factors to dequantize QKV
     float scale_q = 1.0f;
@@ -201,9 +200,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
   static Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args,
                                         void* workspace) {
     float scale_softmax = args.scale_softmax;
-    if (scale_softmax == 0.0f) {
-      scale_softmax = 1.0f / (float)std::sqrt(get<2>(problem_shape));
-    }
     float log2_e = static_cast<float>(std::log2(std::exp(1.0)));
 
     return Params{Load::to_underlying_arguments(problem_shape, args.load, workspace),
```

include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh

Lines changed: 1 addition & 1 deletion
```diff
@@ -126,7 +126,7 @@ struct FwdRunner {
     typename Operation::Arguments arguments{
         problem_shape,
         {static_cast<Element*>(q.data_ptr()), layout_Q, static_cast<Element*>(k.data_ptr()),
-         layout_K, static_cast<Element*>(v.data_ptr()), layout_V},
+         layout_K, static_cast<Element*>(v.data_ptr()), layout_V, float(sm_scale)},
         {static_cast<ElementOut*>(o.data_ptr()) - max_qo_len * get<0>(stride_O), layout_O,
          static_cast<ElementAccumulatorPV*>(maybe_lse.value().data_ptr()), layout_LSE},
         hw_info};
```
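With the sentinel gone, the runner forwards `float(sm_scale)` unconditionally, and `scale_softmax` is simply the multiplier applied to the Q·Kᵀ logits before the softmax (the `log2_e` factor in the mainloop serves an exp2-based softmax and does not change the math). A plain PyTorch sketch of the computation, assuming a single dense sequence with matching query and key head counts:

```python
import torch

def scaled_attention(q, k, v, scale_softmax, causal=False):
    # q, k: [seq_len, num_heads, head_dim_qk]; v: [seq_len, num_heads, head_dim_vo]
    logits = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale_softmax
    if causal:
        sq, sk = logits.shape[-2:]
        future = torch.ones(sq, sk, dtype=torch.bool, device=logits.device).triu(1)
        logits = logits.masked_fill(future, float("-inf"))
    return torch.einsum("hqk,khd->qhd", logits.softmax(dim=-1), v.float())
```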

tests/test_blackwell_fmha.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -61,6 +61,7 @@ def attention_ref(
 @pytest.mark.parametrize("num_kv_heads", [4, 32])
 @pytest.mark.parametrize("head_dim_qk", [192, 128])
 @pytest.mark.parametrize("head_dim_vo", [128])
+@pytest.mark.parametrize("sm_scale", [1.0, 1.0 / math.sqrt(192), 1.0 / math.sqrt(128)])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
 def test_blackwell_cutlass_fmha(
@@ -71,6 +72,7 @@ def test_blackwell_cutlass_fmha(
     num_kv_heads,
     head_dim_qk,
     head_dim_vo,
+    sm_scale,
     causal,
     dtype,
 ):
@@ -102,7 +104,6 @@ def test_blackwell_cutlass_fmha(
         kv_layout="NHD",
         backend="cutlass",
     )
-    sm_scale = 1.0 / (head_dim_qk**0.5)
     wrapper.plan(
         qo_indptr,
         kv_indptr,
@@ -142,6 +143,7 @@ def test_blackwell_cutlass_fmha(
         4,
         192,
         128,
+        1,
         True,
         torch.bfloat16,
         # 3,
```
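Of the new parametrized scales, `sm_scale=1.0` is the discriminating case: under the old hardcoded path a caller-supplied 1.0 was silently replaced by 1/√d, so the kernel output would diverge from a correctly scaled reference, while the 1/√192 and 1/√128 cases happen to coincide with the default. A self-contained numeric illustration of why the outputs differ (plain PyTorch, not the test's actual plumbing):

```python
import math
import torch

torch.manual_seed(0)
d = 128
q, k, v = torch.randn(4, d), torch.randn(6, d), torch.randn(6, d)

def attn(sm_scale):
    return torch.softmax(q @ k.T * sm_scale, dim=-1) @ v

# A unit scale sharpens the softmax relative to 1/sqrt(128), so the outputs
# differ measurably — exactly the discrepancy the hardcoded scale used to hide.
assert not torch.allclose(attn(1.0), attn(1.0 / math.sqrt(d)), atol=1e-3)
```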
