Skip to content
21 changes: 13 additions & 8 deletions csrc/flashmask_v2/flash_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,20 @@ void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool ena
#endif
}

inline bool is_short_seqlen(Flash_fwd_params const& params) {
return params.seqlen_k < 128 && params.seqlen_q < 128;
// Buckets a forward problem by sequence length. The resulting tag is what the
// callers in this file hand to tile_size_fwd_sm90 as its last argument.
//
// A bucket applies only when BOTH seqlen_q and seqlen_k are under its limit:
//   TinySeq  : both < 128
//   ShortSeq : both < 16384 (and not TinySeq)
//   LongSeq  : everything else
inline SeqlenDispatchTag get_seqlen_dispatch_tag(Flash_fwd_params const& params) {
    bool const fits_tiny = params.seqlen_q < 128 && params.seqlen_k < 128;
    bool const fits_short = params.seqlen_q < 16384 && params.seqlen_k < 16384;
    // The tiny test is checked first, so a problem that satisfies both limits is TinySeq.
    return fits_tiny  ? SeqlenDispatchTag::TinySeq
         : fits_short ? SeqlenDispatchTag::ShortSeq
                      : SeqlenDispatchTag::LongSeq;
}

inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
// This needs to match the kernel configs
bool const short_seqlen = is_short_seqlen(params);
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f, short_seqlen);
auto const seqlen_dispatch = get_seqlen_dispatch_tag(params);
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f, seqlen_dispatch);
int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
// Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
Expand All @@ -191,8 +196,8 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) {
// params.page_table must already be set
if (params.h == params.h_k) { return false; }
// This needs to match the kernel configs
bool const short_seqlen = is_short_seqlen(params);
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, short_seqlen);
auto const seqlen_dispatch = get_seqlen_dispatch_tag(params);
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, seqlen_dispatch);
int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
#endif
Expand All @@ -206,8 +211,8 @@ inline int get_num_splits(Flash_fwd_params const& params) {
// params.page_table must already be set
// This needs to match the kernel configs
bool const varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
bool const short_seqlen = is_short_seqlen(params);
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, short_seqlen);
auto const seqlen_dispatch = get_seqlen_dispatch_tag(params);
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, seqlen_dispatch);
// Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
// has not been set here. It's OK though because we might just underestimate kBlockN a bit
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
Expand Down
64 changes: 27 additions & 37 deletions csrc/flashmask_v2/flash_bwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ using namespace cute;

template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
bool Is_flashmask, bool Has_lt_end, bool Has_ut_start, bool Is_blockmask,
bool Has_lt_end, bool Has_ut_start, bool Is_blockmask,
int Stages_dO=2, int Stages_dS_or_QSm80=2,
bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
Expand Down Expand Up @@ -93,7 +93,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
Arch >= 90,
flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
SdP_swapAB, dKV_swapAB, dQ_swapAB, Is_flashmask, Has_lt_end, Has_ut_start,Is_blockmask, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
SdP_swapAB, dKV_swapAB, dQ_swapAB, Has_lt_end, Has_ut_start, Is_blockmask, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
Expand All @@ -114,9 +114,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
>;

if constexpr (Is_flashmask) {
flash::flashmask::prepare_block_maxmin<kBlockN>(params, stream);
}
flash::flashmask::prepare_block_maxmin<kBlockN>(params, stream);

if constexpr (Arch >= 90) {
prepare_preemptive_scheduler(params, stream, params.num_sm);
Expand Down Expand Up @@ -352,7 +350,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
}

template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
bool Is_flashmask_, bool Has_lt_end_, bool Has_ut_start_, bool Is_blockmask_,
bool Has_lt_end_, bool Has_ut_start_, bool Is_blockmask_,
int Stages_dO=2, int Stages_dS_or_QSm80=2,
bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
Expand All @@ -362,7 +360,7 @@ void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
// BOOL_SWITCH(params.deterministic, Deterministic, [&] {
// run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Is_flashmask_, Has_lt_end_, Has_ut_start_,Is_blockmask_, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Has_lt_end_, Has_ut_start_, Is_blockmask_, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
// });
});
});
Expand All @@ -373,26 +371,22 @@ template<int Arch, typename T, bool Has_softcap, bool Is_causal>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
// printf("point2-1\n");
static constexpr bool Is_local = false;
static constexpr bool Is_flashmask_ = true;
BOOL_SWITCH(params.block_mask_ptr != nullptr, Is_blockmask_, [&]{
FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] {
if constexpr (Arch >= 90) {
if constexpr (Is_flashmask_ && !Is_causal) {
run_mha_bwd_dispatch<Arch, T, 64, 96, 64, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
} else if constexpr (Is_causal && Has_softcap || Is_flashmask_) {
// register spill with 128 x 128
run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
if constexpr (!Is_causal) {
run_mha_bwd_dispatch<Arch, T, 64, 96, 64, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
} else {
// With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
// register spill with 128 x 128
run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
}
} else if constexpr (Arch == 86 || Arch == 89) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true, true>(params, stream);
// run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
// run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
// run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false, true>(params, stream);
}
});
});
Expand All @@ -401,38 +395,36 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
// Launches the backward pass for head dimension 96 (every call below passes 96 as kHeadDim).
// FLASH_MASK_SWITCH receives the null-ness of params.lt_end_ptr / params.ut_start_ptr and makes the
// names Has_lt_end / Has_ut_start available as template arguments inside the lambda (macro is defined
// elsewhere — presumably it materialises them as compile-time bools; confirm).
// The tile configuration forwarded to run_mha_bwd_dispatch is then chosen per architecture.
// NOTE(review): this listing is a diff view without +/- markers; the lines that mention Is_flashmask_
// appear to be the removed side of the change, each followed by its replacement — confirm against the real file.
template<int Arch, typename T, bool Has_softcap, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
// Local (sliding-window) attention is fixed off for this launcher.
static constexpr bool Is_local = false;
static constexpr bool Is_flashmask_ = true;
FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] {
if constexpr (Arch >= 90) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
// The literal `false` right after Has_ut_start sits in the Is_blockmask_ position of the
// run_mha_bwd_dispatch parameter list visible in this listing; this launcher does not switch on
// params.block_mask_ptr, unlike the hdim64/128/256 launchers.
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, false, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
} else if constexpr (Arch == 86 || Arch == 89) {
// NOTE(review): in the two non-sm90 branches the arguments after Has_softcap start with integers,
// whereas the run_mha_bwd_dispatch parameter list shown in this listing expects
// Has_lt_end_, Has_ut_start_, Is_blockmask_ (bools) there. The argument lists look stale relative to
// that signature — verify these branches are ever instantiated / still compile.
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true, true>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false, true>(params, stream);
}
});
}

template<int Arch, typename T, bool Has_softcap, bool Is_causal>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
static constexpr bool Is_local = false;
static constexpr bool Is_flashmask_ = true;
BOOL_SWITCH(params.block_mask_ptr != nullptr, Is_blockmask_, [&]{
FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] {
if constexpr (Arch >= 90) {
if constexpr (Is_causal || Is_local || Has_softcap) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
} else {
if ((params.seqlen_q >= 1024 || params.seqlen_k >= 1024) && !(Has_lt_end && Has_ut_start)) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 64, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 64, 128, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, Is_blockmask_, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
}
}
} else if constexpr (Arch == 86 || Arch == 89) {
run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true, true>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false, true>(params, stream);
}
});
});
Expand All @@ -441,38 +433,36 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
// Launches the backward pass for head dimension 192 (every call below passes 192 as kHeadDim).
// FLASH_MASK_SWITCH receives the null-ness of params.lt_end_ptr / params.ut_start_ptr and makes the
// names Has_lt_end / Has_ut_start available as template arguments inside the lambda (macro is defined
// elsewhere — presumably it materialises them as compile-time bools; confirm).
// NOTE(review): this listing is a diff view without +/- markers; the lines that mention Is_flashmask_
// appear to be the removed side of the change, each followed by its replacement — confirm against the real file.
template<int Arch, typename T, bool Has_softcap, bool Is_causal>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
// Local (sliding-window) attention is fixed off for this launcher.
static constexpr bool Is_local = false;
static constexpr bool Is_flashmask_ = true;
FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] {
if constexpr (Arch >= 90) {
// With both lt_end and ut_start present the sm90 path uses kBlockN = 48; otherwise kBlockN = 96.
// NOTE(review): this is a plain `if`, not `if constexpr`, so both tile configurations are
// instantiated for every Has_lt_end / Has_ut_start combination. If the two names are constant
// expressions (they are used as template arguments), `if constexpr` would avoid the unused
// instantiation — confirm before changing.
if (Has_lt_end && Has_ut_start) {
run_mha_bwd_dispatch<Arch, T, 64, 48, 192, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
// The literal `false` right after Has_ut_start sits in the Is_blockmask_ position of the
// run_mha_bwd_dispatch parameter list visible in this listing; this launcher does not switch on
// params.block_mask_ptr.
run_mha_bwd_dispatch<Arch, T, 64, 48, 192, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, false, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, false, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
}
} else if constexpr (Arch == 86 || Arch == 89) {
// NOTE(review): in the two non-sm90 branches the arguments after Has_softcap start with integers,
// whereas the run_mha_bwd_dispatch parameter list shown in this listing expects
// Has_lt_end_, Has_ut_start_, Is_blockmask_ (bools) there. The argument lists look stale relative to
// that signature — verify these branches are ever instantiated / still compile.
run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true, true>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false, true>(params, stream);
}
});
}

template<int Arch, typename T, bool Has_softcap, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
static constexpr bool Is_local = false;
static constexpr bool Is_flashmask_ = true;
BOOL_SWITCH(params.block_mask_ptr != nullptr, Is_blockmask_, [&]{
FLASH_MASK_SWITCH(params.lt_end_ptr != nullptr, params.ut_start_ptr != nullptr, Has_lt_end, Has_ut_start, [&] {
if constexpr (Arch >= 90) {
if (Has_lt_end && Has_ut_start) {
run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Is_blockmask_, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, Is_blockmask_, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Is_blockmask_, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, Has_lt_end, Has_ut_start, Is_blockmask_, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
}
} else if constexpr (Arch == 86 || Arch == 89) {
run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true, true>(params, stream);
} else {
run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false, Is_flashmask_>(params, stream);
run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false, true>(params, stream);
}
});
});
Expand Down
Loading