Skip to content

Commit 7bdadce

Browse files
authored
bugfix: adding lse output to blackwell fmha kernels (#1071)
1 parent 21ea1d2 commit 7bdadce

File tree

5 files changed

+69
-41
lines changed

5 files changed

+69
-41
lines changed

include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
4949
using LayoutO = cute::Layout<ShapeT, StrideO>;
5050

5151
using ShapeLSE = cute::Shape<int32_t, cute::Shape<int32_t, int32_t>>;
52-
using StrideLSE = cute::Shape<_1, cute::Shape<int32_t, int32_t>>;
52+
using StrideLSE = cute::Shape<int32_t, cute::Shape<_1, int32_t>>;
5353
using LayoutLSE = cute::Layout<ShapeLSE, StrideLSE>;
5454

5555
// using SmemLayoutO = decltype(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{})));
@@ -103,6 +103,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
103103
cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor());
104104
}
105105

106+
const Params& params;
107+
108+
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
109+
106110
template <class BlkCoord, class ProblemShape, class ParamsProblemShape>
107111
CUTLASS_DEVICE auto store(BlkCoord const& blk_coord, ProblemShape const& problem_shape,
108112
Params const& params, ParamsProblemShape const& params_problem_shape,
@@ -120,10 +124,6 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
120124
int o0_index = 2 * get<0>(blk_coord);
121125
int o1_index = 2 * get<0>(blk_coord) + 1;
122126

123-
// Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), params.layout_LSE);
124-
// Tensor gLSE = get_lse_local_tile_tensor(mLSE, Shape<Int<CTA_Q>>{}, qo_head_idx, qo_indptr,
125-
// qo_len)(_, qo_tile_idx);
126-
127127
int max_length_q = get<0>(params_problem_shape).max_length;
128128
int offs_0 = max_length_q - qo_len;
129129
int offs_2_1 = qo_segment_offset + qo_len;

include/flashinfer/attention/blackwell/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -906,14 +906,17 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
906906
}
907907
}
908908

909-
template <class BlkCoord, class ProblemShape, class TensorStorageEpi>
909+
template <class BlkCoord, class ParamsProblemShape, class ProblemShape, class TensorStorageEpi,
910+
class CollectiveEpilogue>
910911
CUTLASS_DEVICE auto correction(
911-
BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape,
912+
BlkCoord const& blk_coord, Params const& params,
913+
ParamsProblemShape const& params_problem_shape, ProblemShape const& problem_shape,
912914
TensorStorageEpi& shared_storage_epi, PipelineC& pipeline_s0_c,
913915
typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s1_c,
914916
typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, PipelineO& pipeline_o,
915917
typename PipelineO::PipelineState& pipeline_o_consumer_state, PipelineE& pipeline_epi,
916-
typename PipelineE::PipelineState& pipeline_epi_producer_state) {
918+
typename PipelineE::PipelineState& pipeline_epi_producer_state,
919+
CollectiveEpilogue& epilogue) {
917920
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
918921

919922
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
@@ -1024,7 +1027,23 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
10241027
// store to smem
10251028
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()),
10261029
typename TensorStorageEpi::SmemLayoutO{});
1030+
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), epilogue.params.layout_LSE);
10271031
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
1032+
if (epilogue.params.ptr_LSE != nullptr) {
1033+
int qo_tile_idx = get<0>(blk_coord);
1034+
int qo_head_idx = get<2, 0>(blk_coord);
1035+
int batch_idx = get<2, 1>(blk_coord);
1036+
int qo_len = get<0>(problem_shape);
1037+
int segment_offset = get<0>(params_problem_shape).segment_offsets[batch_idx];
1038+
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * qo_tile_idx;
1039+
1040+
ElementPV lse = __log2f(tTMEM_LOADVrS(kIdxFinalRowSum)) +
1041+
params.scale_softmax_log2 * tTMEM_LOADVrS(kIdxFinalRowMax);
1042+
1043+
if (row_idx < qo_len) {
1044+
gLSE(segment_offset + row_idx, qo_head_idx) = lse;
1045+
}
1046+
}
10281047
// correction_epilogue(params.scale_output, _0{}, sO);
10291048

10301049
cutlass::arch::fence_view_async_tmem_load();
@@ -1047,6 +1066,23 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
10471066
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
10481067

10491068
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
1069+
1070+
if (epilogue.params.ptr_LSE != nullptr) {
1071+
int qo_tile_idx = get<0>(blk_coord);
1072+
int qo_head_idx = get<2, 0>(blk_coord);
1073+
int batch_idx = get<2, 1>(blk_coord);
1074+
int qo_len = get<0>(problem_shape);
1075+
int segment_offset = get<0>(params_problem_shape).segment_offsets[batch_idx];
1076+
int row_idx =
1077+
get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * qo_tile_idx + get<0>(TileShapeQK{});
1078+
1079+
ElementPV lse = __log2f(tTMEM_LOADVrS(kIdxFinalRowSum)) +
1080+
params.scale_softmax_log2 * tTMEM_LOADVrS(kIdxFinalRowMax);
1081+
1082+
if (row_idx < qo_len) {
1083+
gLSE(segment_offset + row_idx, qo_head_idx) = lse;
1084+
}
1085+
}
10501086
// correction_epilogue(params.scale_output, _1{}, sO);
10511087
cutlass::arch::fence_view_async_tmem_load();
10521088

include/flashinfer/attention/blackwell/fmha_cutlass_sm100.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct FwdRunner {
5656
// NOTE(Zihao): use markus's trick for tma store
5757
using StrideO =
5858
cute::tuple<int, _1, cute::tuple<cute::tuple<int, int>, int>>; // Q D (H_G H_R) CUMULATIVE_Q
59-
using StrideLSE = cute::tuple<_1, cute::tuple<int, int>>; // Q (H_G H_R)
59+
using StrideLSE = cute::tuple<int, cute::tuple<_1, int>>; // Q (H_G H_R)
6060

6161
using Mainloop = cutlass::fmha::collective::Sm100FmhaFwdMainloopTmaWarpspecialized<
6262
Element, ElementAccumulatorQK, ElementAccumulatorPV, TileShapeQK, TileShapePV, StrideQ,
@@ -108,7 +108,7 @@ struct FwdRunner {
108108
make_stride(make_stride(head_dim_vo, h_r * head_dim_vo), num_qo_heads * head_dim_vo));
109109
stride_K = make_stride(num_kv_heads * head_dim_qk, _1{}, make_stride(_0{}, head_dim_qk));
110110
stride_V = make_stride(_1{}, num_kv_heads * head_dim_vo, make_stride(_0{}, head_dim_vo));
111-
stride_LSE = make_stride(_1{}, make_stride(total_qo_len, total_qo_len * h_r));
111+
stride_LSE = make_stride(num_qo_heads, make_stride(_1{}, h_r));
112112

113113
auto shape_Q = make_shape(total_qo_len, head_dim_qk, make_shape(h_r, num_kv_heads));
114114
auto shape_O = make_shape(max_qo_len, head_dim_vo,

include/flashinfer/attention/blackwell/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
360360
cutlass::make_producer_start_state<typename CollectiveMainloop::PipelineO>();
361361

362362
CollectiveMainloop mainloop;
363-
CollectiveEpilogue epilogue;
363+
CollectiveEpilogue epilogue{params.epilogue};
364364

365365
if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) {
366366
warpgroup_reg_set<NumRegsSoftmax>();
@@ -400,11 +400,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
400400
continue;
401401
}
402402

403-
mainloop.correction(
404-
blk_coord, params.mainloop, logical_problem_shape, shared_storage.epilogue,
405-
pipeline_s0_corr, pipeline_s0_corr_consumer_state, pipeline_s1_corr,
406-
pipeline_s1_corr_consumer_state, pipeline_mma_corr, pipeline_mma_corr_consumer_state,
407-
pipeline_corr_epi, pipeline_corr_epi_producer_state);
403+
mainloop.correction(blk_coord, params.mainloop, params.problem_shape, logical_problem_shape,
404+
shared_storage.epilogue, pipeline_s0_corr,
405+
pipeline_s0_corr_consumer_state, pipeline_s1_corr,
406+
pipeline_s1_corr_consumer_state, pipeline_mma_corr,
407+
pipeline_mma_corr_consumer_state, pipeline_corr_epi,
408+
pipeline_corr_epi_producer_state, epilogue);
408409
}
409410

410411
if constexpr (NumWarpsEpilogue == 0) {

tests/test_blackwell_fmha.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def attention_ref(
6262
@pytest.mark.parametrize("head_dim_qk", [192, 128])
6363
@pytest.mark.parametrize("head_dim_vo", [128])
6464
@pytest.mark.parametrize("causal", [False, True])
65-
@pytest.mark.parametrize("dtype", [torch.half])
65+
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
6666
def test_blackwell_cutlass_fmha(
6767
batch_size,
6868
qo_len,
@@ -117,41 +117,32 @@ def test_blackwell_cutlass_fmha(
117117
)
118118
o, lse = wrapper.run(q, k, v, return_lse=True)
119119

120-
# gqa_group_ratio = num_qo_heads // num_kv_heads
121-
# k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
122-
# v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
123-
# o_ref, lse_ref = attention_ref(
124-
# batch_size, q, k_repeated, v_repeated, causal, sm_scale
125-
# )
126-
127-
# lse_ref = lse_ref.flatten(0, 1)
128-
# if dtype == torch.half:
129-
# torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
130-
# else:
131-
# torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
120+
gqa_group_ratio = num_qo_heads // num_kv_heads
121+
k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
122+
v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
123+
o_ref, lse_ref = attention_ref(
124+
batch_size, q, k_repeated, v_repeated, causal, sm_scale
125+
)
132126

133-
# torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
127+
lse_ref = lse_ref.flatten(0, 1)
128+
if dtype == torch.half:
129+
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
130+
else:
131+
torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
134132

135-
# test with pre-allocated output
136-
# o_buffer = torch.empty_like(o)
137-
# lse_buffer = torch.empty_like(lse)
138-
# flashinfer.prefill.fmha(
139-
# q, k, v, qo_lens, kv_lens, out=o_buffer, lse=lse_buffer, causal=causal
140-
# )
141-
# torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
142-
# torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3)
133+
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
143134

144135

145136
if __name__ == "__main__":
146137
test_blackwell_cutlass_fmha(
147138
1,
148-
1,
149-
1,
150139
32,
140+
32,
141+
4,
151142
4,
152143
192,
153144
128,
154-
False,
145+
True,
155146
torch.bfloat16,
156147
# 3,
157148
# 999,

0 commit comments

Comments
 (0)