Sm100 blockwise fp8 swap ab #18564

Merged · 3 commits · Jun 4, 2025

@@ -9,10 +9,6 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(
a.size(0) % 4 == 0,
"Input tensor must have a number of rows that is a multiple of 4. ",
"but got: ", a.size(0), " rows.");
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
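
The check removed above is no longer needed: the dispatcher further down routes small or unaligned M to a swapped-operand (swap A/B) kernel instead of requiring M % 4 == 0. The swap relies on the transpose identity C = A·B, hence Cᵀ = Bᵀ·Aᵀ: run the GEMM with the operands exchanged and the problem shape as (n, m, k), then write the result through a transposed output layout. A minimal host-side sketch (plain C++, not the CUTLASS kernel) that checks this identity:

```cpp
// Hedged sketch (not the vLLM kernel): verifies the swap-A/B identity
// C = A * B  <=>  C^T = B^T * A^T on plain float arrays, which is what the
// swapped CUTLASS path exploits so that small or odd M no longer needs padding.
#include <cmath>
#include <cstdio>
#include <vector>

// Row-major reference GEMM: C[m x n] = A[m x k] * B[k x n]
static void gemm_rm(int m, int n, int k, const std::vector<float>& A,
                    const std::vector<float>& B, std::vector<float>& C) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      C[i * n + j] = acc;
    }
}

int main() {
  int m = 3, n = 8, k = 4;  // m deliberately not a multiple of 4
  std::vector<float> A(m * k), B(k * n);
  for (int i = 0; i < m * k; ++i) A[i] = 0.01f * i;
  for (int i = 0; i < k * n; ++i) B[i] = 0.02f * i - 0.3f;

  // Direct product: C = A * B
  std::vector<float> C(m * n);
  gemm_rm(m, n, k, A, B, C);

  // Swapped product: Ct = B^T * A^T, computed with the same row-major kernel
  // by materializing the transposes (the real kernel only reinterprets layouts).
  std::vector<float> At(k * m), Bt(n * k), Ct(n * m);
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p) At[p * m + i] = A[i * k + p];
  for (int p = 0; p < k; ++p)
    for (int j = 0; j < n; ++j) Bt[j * k + p] = B[p * n + j];
  gemm_rm(n, m, k, Bt, At, Ct);

  // Reading Ct transposed must reproduce C.
  float max_err = 0.f;
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      max_err = std::fmax(max_err, std::fabs(C[i * n + j] - Ct[j * m + i]));
  std::printf("max |C - (B^T A^T)^T| = %g\n", max_err);  // expect 0
  return 0;
}
```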
@@ -1,5 +1,6 @@
#pragma once

#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

@@ -22,49 +23,49 @@ namespace vllm {

using namespace cute;

template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
class ClusterShape, typename EpilogueScheduler,
typename MainloopScheduler>
// clang-format off
template <class OutType, int ScaleGranularityM,
int ScaleGranularityN, int ScaleGranularityK,
class MmaTileShape, class ClusterShape,
class EpilogueScheduler, class MainloopScheduler,
bool swap_ab_ = false>
struct cutlass_3x_gemm_fp8_blockwise {
static constexpr bool swap_ab = swap_ab_;
using ElementAB = cutlass::float_e4m3_t;

using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

using ElementC = void;
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

using ElementC = void; // TODO: support bias
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
static constexpr int AlignmentC = AlignmentD;

using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;

// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
static constexpr int ScaleGranularityM =
size<0>(MmaTileShape{}) / ScaleMsPerTile;
static constexpr int ScaleGranularityN =
size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
static constexpr int ScaleGranularityK =
size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});

// Shape of the threadblocks in a cluster
using ClusterShape_MNK = ClusterShape;

using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
using ScaleConfig = conditional_t<swap_ab,
cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::K, cute::UMMA::Major::MN>,
cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>>;

// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
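
The rewritten struct takes the scale granularities as explicit template parameters instead of deriving them from a ScalesPerTile shape, and selects a swapped-major ScaleConfig when swap_ab is set. For orientation, here is a rough sketch of how the per-block scale counts fall out, under the assumption of the 1x128x128 granularity the dispatch below passes in; it ignores any padding the real Sm100BlockwiseScaleConfig layout may add, so treat the counts as approximate:

```cpp
// Rough sketch of how many per-block scale factors the blockwise scheme
// produces; granularities mirror the regular (non-swapped) dispatch below
// (ScaleGranularityM = 1, N = 128, K = 128).
#include <cstdio>

static int ceil_div(int x, int y) { return (x + y - 1) / y; }

int main() {
  int m = 4096, n = 7168, k = 2048;   // arbitrary example problem
  int gM = 1, gN = 128, gK = 128;     // regular-path granularities

  // Operand A (m x k): one float scale per (gM x gK) block, i.e. per-row groups.
  int sfa = ceil_div(m, gM) * ceil_div(k, gK);
  // Operand B (k x n): one float scale per (gN x gK) block, i.e. 128x128 tiles.
  int sfb = ceil_div(n, gN) * ceil_div(k, gK);
  std::printf("A scales: %d, B scales: %d\n", sfa, sfb);

  // In the swap-A/B path the GEMM is launched as (n, m, k) with granularities
  // 128x1x128, so the same counting applies with the operands' roles reversed.
  return 0;
}
```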

@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {

static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementScalar = float;
// clang-format off
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
@@ -84,33 +84,47 @@ struct cutlass_3x_gemm_fp8_blockwise {
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
AlignmentC,
ElementD,
LayoutD,
conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;

using StageCountType = cutlass::gemm::collective::StageCountAuto;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,

using CollectiveMainloop = conditional_t<swap_ab,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementB,
cute::tuple<LayoutB_Transpose, LayoutSFA>,
AlignmentB,
ElementA,
cute::tuple<LayoutA_Transpose, LayoutSFB>,
AlignmentA,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
// clang-format on
MainloopScheduler
>::CollectiveOp,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp>;

using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
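
With swap_ab set, the mainloop builder receives B in the A slot (under LayoutB_Transpose) and A in the B slot (under LayoutA_Transpose). Nothing is copied: a transposed layout is a reinterpretation of the same buffer, as the tiny sketch below illustrates with ordinary row/column-major addressing rather than CuTe layouts.

```cpp
// Sketch: the swap-A/B mainloop never moves data. A row-major M x K buffer,
// viewed with a column-major K x M layout, already is the transpose; that is
// all LayoutA_Transpose / LayoutB_Transpose express.
#include <cassert>
#include <vector>

// Element (r, c) of a row-major matrix with leading dimension ld.
static int at_row_major(const std::vector<int>& buf, int r, int c, int ld) {
  return buf[r * ld + c];
}
// Element (r, c) of a column-major matrix with leading dimension ld.
static int at_col_major(const std::vector<int>& buf, int r, int c, int ld) {
  return buf[r + c * ld];
}

int main() {
  int m = 5, k = 3;
  std::vector<int> A(m * k);
  for (int i = 0; i < (int)A.size(); ++i) A[i] = i;

  // Interpreting the same buffer as a column-major k x m matrix (ld = k)
  // yields exactly A^T: element (p, i) of the transposed view equals A(i, p).
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p)
      assert(at_col_major(A, p, i, k) == at_row_major(A, i, p, k));
  return 0;
}
```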
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
using ElementD = typename Gemm::ElementD;

int32_t m = a.size(0), n = b.size(1), k = a.size(1);
auto prob_shape = cute::make_shape(m, n, k, 1);

StrideA a_stride;
StrideB b_stride;
@@ -146,21 +160,36 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));

LayoutSFA layout_SFA =
LayoutSFA layout_SFA = swap_ab ?
ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB =
LayoutSFB layout_SFB = swap_ab ?
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride,
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
auto mainloop_args = [&](){
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
if (swap_ab) {
return typename GemmKernel::MainloopArguments{
b_ptr, b_stride, a_ptr, a_stride,
b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
};
}
else {
return typename GemmKernel::MainloopArguments{
a_ptr, a_stride, b_ptr, b_stride,
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
};
}
}();
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);

auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
@@ -175,29 +204,74 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto m = a.size(0);
auto k = a.size(1);
auto n = b.size(1);
int sms;
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
return std::ceil(static_cast<float>(m) / tile1SM) *
std::ceil(static_cast<float>(n) / tile1SM) >=
sms;
};
bool use_2sm = should_use_2sm(m, n);
if (use_2sm) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
constexpr int TILE_K = 128;
// TODO: better heuristics
bool swap_ab = (m < 16) || (m % 4 != 0);
bool use_tma_epilogue = (m * n) % 4 == 0;
if (!swap_ab) {
constexpr int TILE_N = 128;
int tile_m = 256;
if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
tile_m = 64;
}
else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
tile_m = 128;
}
if (tile_m == 64) {
if (use_tma_epilogue) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
} else if (tile_m == 128) {
if (use_tma_epilogue) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
} else { // tile_m == 256
if (use_tma_epilogue) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
}
}
} else {
// TODO: Test more tile N configs
constexpr int TILE_M = 128;
constexpr int TILE_N = 16;
// TMA epilogue isn't compatible with Swap A/B
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
out, a, b, a_scales, b_scales);
}
}
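
For orientation, here is a standalone sketch of the dispatch heuristic above, with the kernel instantiations left out; the SM count in main() is an assumption for the demo, not something the kernel fixes.

```cpp
// Standalone sketch of the dispatch heuristic (mirrors the diff above).
// Useful for eyeballing which configuration a given (m, n) would take on a
// device with `sms` streaming multiprocessors.
#include <cstdio>

static int ceil_div(int x, int y) { return (x + y - 1) / y; }

static void describe_dispatch(int m, int n, int sms) {
  bool swap_ab = (m < 16) || (m % 4 != 0);
  bool use_tma_epilogue = (m * n) % 4 == 0;
  if (swap_ab) {
    // Swapped path: 128 x 16 x 128 tile, 1-SM cluster, no-smem epilogue
    // (the TMA epilogue is not used together with swap A/B).
    std::printf("m=%d n=%d -> swap_ab, tile 128x16x128, 1SM\n", m, n);
    return;
  }
  constexpr int TILE_N = 128;
  int tile_m = 256;  // default: 2-SM cluster
  if (ceil_div(n, TILE_N) * ceil_div(m, 64) <= sms) tile_m = 64;
  else if (ceil_div(n, TILE_N) * ceil_div(m, 128) <= sms) tile_m = 128;
  std::printf("m=%d n=%d -> tile %dx%dx128, %s, %s epilogue\n", m, n, tile_m,
              TILE_N, tile_m == 256 ? "2SM" : "1SM",
              use_tma_epilogue ? "TMA" : "no-smem");
}

int main() {
  int sms = 148;  // assumed SM count for the demo, not queried from a device
  describe_dispatch(7, 4096, sms);     // small m -> swap A/B
  describe_dispatch(4096, 4096, sms);  // large m and n -> 256x128 2-SM tile
  describe_dispatch(64, 7168, sms);    // modest m -> 64x128 1-SM tile
  return 0;
}
```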
14 changes: 0 additions & 14 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -135,24 +135,10 @@ def ceil_div(x: int, y: int) -> int:
use_cutlass, use_aiter_and_is_supported)

if use_cutlass:
rows, cols = input_2d.shape
# Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
# optimal tensor core usage. Can be removed when targeting platforms
# without this constraint.
should_pad = current_platform.has_device_capability(
100) and rows % 4 != 0
if should_pad:
input_2d = torch.nn.functional.pad(input_2d,
(0, 0, 0, 4 - (rows % 4)),
value=0).contiguous()

q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)

output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype)
if should_pad:
output = output[:rows, :]

else:
q_input, x_scale = per_token_group_quant_fp8(