csrc/flash_attn_v3/epilogue_bwd.hpp (7 changes: 4 additions & 3 deletions)
@@ -358,7 +358,7 @@ struct CollectiveEpilogueBwdGQA {
 };
 using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;

-using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_k_rounded * d, head, batch)
+using ShapedKV = cute::Shape<int64_t, int32_t, int32_t>;  // (seqlen_k_rounded * d, head, batch)
 using StridedKV = cute::Stride<_1, int64_t, int64_t>;

 // Host side kernel arguments
@@ -429,9 +429,10 @@ struct CollectiveEpilogueBwdGQA {
 flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
 bool const is_varlen = Varlen && params.cu_seqlens;
 Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
+
 Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
-Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block));  // (M * K)
-Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block));  // (M * K)
+Tensor gdKaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block));  // (M * K)
+Tensor gdVaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block));  // (M * K)

 R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
 auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
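Note on the ShapedKV change above: the first mode is the flattened seqlen_k_rounded * d extent, so it can exceed INT32_MAX on long (especially varlen) batches even though each factor fits easily in 32 bits. A minimal standalone sketch of the failure mode, with hypothetical sizes not taken from this PR:

#include <cstdint>
#include <cstdio>

int main() {
    int32_t seqlen_k_rounded = 1 << 24;  // assumed ~16.8M total rounded tokens
    int32_t d_rounded = 256;
    int64_t extent = int64_t{seqlen_k_rounded} * d_rounded;  // 4294967296
    int32_t wrapped = static_cast<int32_t>(extent);          // what an int32_t mode would hold
    std::printf("64-bit extent: %lld, truncated to int32_t: %d\n",
                static_cast<long long>(extent), wrapped);    // prints 4294967296, 0
    return 0;
}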
csrc/flash_attn_v3/flash_bwd_launch_template.h (22 changes: 11 additions & 11 deletions)
@@ -61,8 +61,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 static_cast<float*>(params.softmax_lse_log2_ptr),
 {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2
 static_cast<ElementAccum*>(params.dq_accum_ptr),
-{seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
-{_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0},  // stride_dQaccum
+{int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q},  // shape_dQaccum
+{_1{}, int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, !is_varlen_q ? int64_t{params.d_rounded} * int64_t{seqlen_q_rounded} * int64_t{params.h} : 0},  // stride_dQaccum
 params.b,
 params.dq_semaphore,
 params.cu_seqlens_q,
@@ -114,7 +114,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 {seqlen_q, params.dv, params.h, batch_q},  // shape_dO
 {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO
 static_cast<ElementAccum*>(params.dq_accum_ptr),
-{seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
+{int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q},  // shape_dQaccum
 {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
 static_cast<float*>(params.softmax_lse_log2_ptr),
 {seqlen_q_rounded, params.h, batch_q},  // shape_LSE
@@ -136,29 +136,29 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 if constexpr (!GQA) {
 return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k};  // shape_dK
 } else {
-return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k};  // shape_dKaccum
+return typename CollectiveEpilogue::ShapedKV {int64_t{seqlen_k_rounded} * int64_t{params.d_rounded}, params.h_k, batch_k};  // shape_dKaccum
 }
 }(),
 [&] {
 if constexpr (!GQA) {
 return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0};  // stride_dK
 } else {
-return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0};  // stride_dKaccum
+return typename CollectiveEpilogue::StridedKV {_1{}, int64_t{params.d_rounded} * int64_t{seqlen_k_rounded}, !is_varlen_k ? int64_t{params.h_k} * int64_t{params.d_rounded} * int64_t{params.seqlen_k_rounded} : 0};  // stride_dKaccum
 }
 }(),
 static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
 [&] {
 if constexpr (!GQA) {
 return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k};  // shape_dV
 } else {
-return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k};  // shape_dVaccum
+return typename CollectiveEpilogue::ShapedKV {int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, params.h_k, batch_k};  // shape_dVaccum
 }
 }(),
 [&] {
 if constexpr (!GQA) {
 return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0};  // stride_dV
 } else {
-return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0};  // stride_dVaccum
+return typename CollectiveEpilogue::StridedKV {_1{}, int64_t{params.dv_rounded} * int64_t{seqlen_k_rounded}, !is_varlen_k ? int64_t{params.h_k} * int64_t{params.dv_rounded} * int64_t{params.seqlen_k_rounded} : 0};  // stride_dVaccum
 }
 }(),
 params.h,
@@ -225,7 +225,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 >;
 typename PostprocessKernel::Arguments postprocess_args {
 static_cast<ElementAccum const*>(params.dq_accum_ptr),
-{seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum
+{int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q},  // shape_dQaccum
 {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQaccum
 static_cast<Element*>(params.dq_ptr),
 {seqlen_q, params.d, params.h, batch_q},  // shape_dQ
Expand Down Expand Up @@ -254,7 +254,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 typename PostprocessKerneldKV::Arguments postprocess_dK_args {
 static_cast<ElementAccum const*>(params.dk_accum_ptr),
 {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k},  // shape_dKaccum
-{_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dKaccum
+{_1{}, int64_t{seqlen_k_rounded} * int64_t{params.d_rounded}, !is_varlen_k ? int64_t{params.d_rounded} * int64_t{params.seqlen_k_rounded} * int64_t{params.h_k} : 0},  // stride_dKaccum
 static_cast<Element*>(params.dk_ptr),
 {seqlen_k, params.d, params.h_k, batch_k},  // shape_dK
 {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride},  // stride_dK
@@ -265,8 +265,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
 typename PostprocessKerneldKV::Arguments postprocess_dV_args {
 static_cast<ElementAccum const*>(params.dv_accum_ptr),
-{seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k},  // shape_dVaccum
-{_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dVaccum
+{int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, params.h_k, batch_k},  // shape_dVaccum
+{_1{}, int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, !is_varlen_k ? int64_t{params.dv_rounded} * int64_t{params.seqlen_k_rounded} * int64_t{params.h_k} : 0},  // stride_dVaccum
 static_cast<Element*>(params.dv_ptr),
 {seqlen_k, params.dv, params.h_k, batch_k},  // shape_dV
 {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride},  // stride_dV
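Note on the recurring pattern in this file: C++ evaluates int * int in int regardless of the destination type, so the promotion has to land on an operand, not on the result. A short self-contained sketch with hypothetical values (the variable names mirror the launcher's, but nothing here is taken from the PR):

#include <cstdint>
#include <cstdio>

int main() {
    int32_t seqlen_q_rounded = 1 << 20;
    int32_t d_rounded = 128;
    int32_t h = 32;  // 2^20 * 2^7 * 2^5 = 2^32, past INT32_MAX

    // The fix: promote the first operand, so every subsequent '*' is 64-bit.
    int64_t batch_stride = int64_t{seqlen_q_rounded} * d_rounded * h;

    // What unpromoted 32-bit evaluation would store (computed in uint32_t
    // here only to keep the demo free of signed-overflow UB): it wraps to 0.
    int32_t wrapped = static_cast<int32_t>(
        uint32_t(seqlen_q_rounded) * uint32_t(d_rounded) * uint32_t(h));

    std::printf("64-bit: %lld, wrapped 32-bit: %d\n",
                static_cast<long long>(batch_stride), wrapped);
    return 0;
}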
csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h (4 changes: 2 additions & 2 deletions)
@@ -104,7 +104,7 @@ class FlashAttnBwdPostprocessConvertdQ {

 using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
 using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
-using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
+using ShapedQaccum = cute::Shape<int64_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
 using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;

 // Device side arguments
@@ -174,7 +174,7 @@ class FlashAttnBwdPostprocessConvertdQ {
 // Step 1: load dQaccum from gmem to smem
 Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
                               params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
-Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));  // (M * K)
+Tensor gdQaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));  // (M * K)
 if constexpr (IsSm90) {  // Use BulkCopy
 static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
 auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
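Note on the domain_offset change: the coordinate is a flattened element offset into dQaccum, and in varlen mode seqlen_info.offset_padded accumulates the padded tokens of every earlier sequence in the batch. With kHeadDim = 128, an int32_t product already overflows past about 16.8M accumulated tokens. A hedged sketch of the promoted arithmetic (the helper name and numbers are hypothetical):

#include <cstdint>

constexpr int32_t kHeadDim = 128;

// Flattened element offset of a sequence's dQaccum slab, promoted so the
// multiply happens in 64 bits. int32_t math breaks once offset_padded
// exceeds 2^31 / 128 = 16'777'216 tokens.
int64_t dqaccum_elem_offset(int32_t offset_padded) {
    return int64_t{offset_padded} * int64_t{kHeadDim};
}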
csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h (4 changes: 2 additions & 2 deletions)
@@ -63,7 +63,7 @@ class FlashAttnBwdPreprocess {
 using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
 using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q, head, batch)
 using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
-using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
+using ShapedQaccum = cute::Shape<int64_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
 using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;

 // Device side arguments
@@ -230,7 +230,7 @@ class FlashAttnBwdPreprocess {

 if constexpr (Clear_dQaccum) {
 Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
-Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
+Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
 GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
 auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
 Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp (6 changes: 3 additions & 3 deletions)
@@ -215,7 +215,7 @@ struct CollectiveMainloopBwdSm90 {
 using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
 using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)
 using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)
-using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
+using ShapedQaccum = cute::Shape<int64_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)
 using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;

 using TMA_QdO = decltype(make_tma_copy_A_sm90(
@@ -613,7 +613,7 @@ struct CollectiveMainloopBwdSm90 {
 bool const is_varlen = Varlen && params.cu_seqlens_q;
 Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),
                               params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
-Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_));  // (M * K, _)
+Tensor gdQaccum_ = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_q_padded} * int64_t{kHeadDim}), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_));  // (M * K, _)
 Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int<kBlockM * kHeadDim / NumMmaWarpGroups>{});  // (M * K / WG, WG, _)

 int const num_batch = params.num_batch;
@@ -790,7 +790,7 @@ struct CollectiveMainloopBwdSm90 {
 bool const is_varlen = Varlen && params.cu_seqlens_q;
 Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),
                               params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
-Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_));  // (M * K, _)
+Tensor gdQaccum_ = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_q_padded} * int64_t{kHeadDim}), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_));  // (M * K, _)
 Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int<kBlockM * kHeadDim / NumMmaWarpGroups>{});  // (M * K / WG, WG, _)
 // We can reuse r2s_thr_copy_dQaccum for this partitioning
 Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum);
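For scale, the mainloop applies the same operand promotion at both dQaccum store paths. A quick hypothetical calculation (not from the PR) of where the old 32-bit flattened extents actually start to break, for a few common rounded head dims:

#include <cstdint>
#include <cstdio>

int main() {
    for (int32_t d_rounded : {64, 128, 256}) {
        // First rounded-token count whose flattened extent exceeds INT32_MAX.
        int64_t limit = (int64_t{1} << 31) / d_rounded;
        std::printf("d_rounded=%3d: overflow starts at %lld rounded tokens\n",
                    d_rounded, static_cast<long long>(limit));
    }
    return 0;
}
// d_rounded= 64: overflow starts at 33554432 rounded tokens
// d_rounded=128: overflow starts at 16777216 rounded tokens
// d_rounded=256: overflow starts at 8388608 rounded tokens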