Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions csrc/flashmask_v2/flash_fwd_kernel_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,32 @@ class FlashAttnFwdSm90 {
// static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);

static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32);
static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160);
// static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32);
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160);

// Head dimension re-exported from the mainloop collective; the register
// requirement constants below branch on whether it is <= 64.
static constexpr int kHeadDim = CollectiveMainloop::kHeadDim;

// Register count handed to cutlass::arch::warpgroup_reg_dealloc by the
// n_block generator warps.
//   kHeadDim <= 64 : always 56
//   otherwise      : 56 for one MMA warp group, 24 for two, 32 for any other count
static constexpr uint32_t NBlockRegisterRequirement =
    (kHeadDim <= 64)
        ? 56
        : (NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32));
// Register requirement for the load path (presumably the producer/load warp;
// its use site is not visible in this hunk -- confirm against the kernel body).
//   kHeadDim <= 64 : always 32
//   otherwise      : 56 for one MMA warp group, 24 for two, 32 for any other count
static constexpr uint32_t LoadRegisterRequirement =
    (kHeadDim <= 64)
        ? 32
        : (NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32));
// Register requirement for the MMA path (presumably the consumer warp groups;
// its use site is not visible in this hunk -- confirm against the kernel body).
//   kHeadDim <= 64 : always 224
//   otherwise      : 256 for one MMA warp group, 240 for two, 160 for any other count
static constexpr uint32_t MmaRegisterRequirement =
    (kHeadDim <= 64)
        ? 224
        : (NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160));

// If you want to print from the producer warp, you'd need to increase the number of registers
// Otherwise you'll get CUDA error.
Expand Down Expand Up @@ -272,7 +296,7 @@ class FlashAttnFwdSm90 {
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));

if (warp_group_idx == 0 && warp_idx_in_warpgroup != 0) { // n_block generator
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
cutlass::arch::warpgroup_reg_dealloc<NBlockRegisterRequirement>();
cutlass::PipelineState<CollectiveMainloop::kNBlockStages> n_block_pipe_write = cutlass::make_producer_start_state<MainloopPipelineNBlock>();
// Manually specify the scheduler role: producer. For StaticPersistentTileSch, passing template args won't change the behavior
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
Expand Down Expand Up @@ -556,4 +580,4 @@ class FlashAttnFwdSm90 {

};

} // namespace flash
} // namespace flash
7 changes: 4 additions & 3 deletions csrc/flashmask_v2/tile_size.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
return {64, 64, false, true};
}
if (headdim <= 64) {
bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512
// bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512
// return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim};
// With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why
// https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131
// Switch to tile size 192 x 192 for now
bool const use_blockN_128 = is_causal || is_local;
// bool const use_blockN_128 = is_causal || is_local;
// return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim};
return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, same_hdim};
// return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, same_hdim};
// Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen
// return {192, is_causal || is_local ? 192 : 176, true, false};
return {128, 128, true, true};
} else if (headdim <= 96) {
return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true};
} else if (headdim <= 128) {
Expand Down