diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index ac58d0af1f..c7986ce477 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -391,9 +391,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); if (Is_attn_mask) { - flash::apply_attn_mask(scores, tPgMask, tPcMask, - m_block == m_block_max - 1 ? m_residue : params.seqlen_q, - n_block == n_block_max - 1 ? n_residue : params.seqlen_k, + flash::apply_attn_mask(scores, tPgMask, tPcMask, + params.seqlen_q, + params.seqlen_k, params.unscale_softmax); tPgMask.data() = tPgMask.data() + (-kBlockN); } @@ -519,8 +519,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi if (Is_attn_mask) { flash::apply_attn_mask(scores, tPgMask, tPcMask, - m_block == m_block_max - 1 ? m_residue : params.seqlen_q, - n_block == n_block_max - 1 ? n_residue : params.seqlen_k, + params.seqlen_q, + params.seqlen_k, params.unscale_softmax); tPgMask.data() = tPgMask.data() + (-kBlockN); }