
Commit 6809d29

IwakuraRein authored and christian-pinto committed
Sm100 blockwise fp8 swap ab (vllm-project#18564)
1 parent 3ae735d commit 6809d29

3 files changed: +140 -84 lines changed
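
For context: this commit adds a "swap AB" path to the SM100 blockwise-FP8 CUTLASS GEMM. The trick rests on the identity (A*B)^T = B^T * A^T: for shapes the kernel previously rejected (small M, or M not a multiple of 4), it instead runs the transposed N x M problem with the operands exchanged, then writes the output through a transposed epilogue layout. Below is a minimal standalone sketch of that identity in plain C++ (illustration only, not vLLM code; all names are invented for the example):

#include <cassert>
#include <cstdio>
#include <vector>

// Naive reference GEMM: D (m x n) = A (m x k) * B (k x n), all row-major.
void gemm(int m, int n, int k, const std::vector<float>& A,
          const std::vector<float>& B, std::vector<float>& D) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      D[i * n + j] = acc;
    }
}

int main() {
  // Deliberately pick m = 3: not a multiple of 4, i.e. a shape that
  // previously hit the TORCH_CHECK and now takes the swap-AB path.
  int m = 3, n = 5, k = 4;
  std::vector<float> A(m * k), B(k * n);
  for (int i = 0; i < m * k; ++i) A[i] = 0.25f * i;
  for (int i = 0; i < k * n; ++i) B[i] = 0.5f * i - 1.f;

  // Direct product: D = A * B (an m x n problem).
  std::vector<float> D(m * n);
  gemm(m, n, k, A, B, D);

  // Swapped product: Dt = B^T * A^T (an n x m problem). A row-major
  // transpose is only a relabeling of a column-major matrix, so the real
  // kernel swaps layouts instead of moving data; here we transpose
  // explicitly for clarity.
  std::vector<float> Bt(n * k), At(k * m), Dt(n * m);
  for (int p = 0; p < k; ++p)
    for (int j = 0; j < n; ++j) Bt[j * k + p] = B[p * n + j];
  for (int i = 0; i < m; ++i)
    for (int p = 0; p < k; ++p) At[p * m + i] = A[i * k + p];
  gemm(n, m, k, Bt, At, Dt);

  // Dt is exactly D transposed.
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) assert(D[i * n + j] == Dt[j * m + i]);
  std::puts("swap-AB identity holds");
}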

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu

Lines changed: 0 additions & 4 deletions
@@ -9,10 +9,6 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(
-      a.size(0) % 4 == 0,
-      "Input tensor must have a number of rows that is a multiple of 4. ",
-      "but got: ", a.size(0), " rows.");
   if (out.dtype() == torch::kBFloat16) {
     cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
         out, a, b, a_scales, b_scales);
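
Note: the hard error on a.size(0) % 4 is removed rather than relaxed because the dispatch heuristic in the next file now routes exactly those shapes (m < 16, or m % 4 != 0) to the swap-AB kernel, which accepts any row count.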

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh

Lines changed: 140 additions & 66 deletions
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "cuda_utils.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
 
@@ -22,49 +23,49 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
-          class ClusterShape, typename EpilogueScheduler,
-          typename MainloopScheduler>
+// clang-format off
+template <class OutType, int ScaleGranularityM,
+          int ScaleGranularityN, int ScaleGranularityK,
+          class MmaTileShape, class ClusterShape,
+          class EpilogueScheduler, class MainloopScheduler,
+          bool swap_ab_ = false>
 struct cutlass_3x_gemm_fp8_blockwise {
+  static constexpr bool swap_ab = swap_ab_;
   using ElementAB = cutlass::float_e4m3_t;
 
   using ElementA = ElementAB;
   using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
   static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
 
   using ElementB = ElementAB;
   using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
   static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
 
-  using ElementC = void;
   using ElementD = OutType;
   using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
 
+  using ElementC = void;  // TODO: support bias
   using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
   static constexpr int AlignmentC = AlignmentD;
 
   using ElementAccumulator = float;
   using ElementCompute = float;
   using ElementBlockScale = float;
 
-  // MMA and Cluster Tile Shapes
-  // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
-  // Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
-  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
-  static constexpr int ScaleGranularityM =
-      size<0>(MmaTileShape{}) / ScaleMsPerTile;
-  static constexpr int ScaleGranularityN =
-      size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
-  static constexpr int ScaleGranularityK =
-      size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
-
-  // Shape of the threadblocks in a cluster
-  using ClusterShape_MNK = ClusterShape;
-
-  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
-      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+  using ScaleConfig = conditional_t<swap_ab,
+      cutlass::detail::Sm100BlockwiseScaleConfig<
+          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+          cute::UMMA::Major::K, cute::UMMA::Major::MN>,
+      cutlass::detail::Sm100BlockwiseScaleConfig<
+          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+          cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
+
+  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
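
A note on the ScaleConfig above: the two UMMA majors trade places in the swap branch because the operand occupying each slot changes, so each scale tensor now describes the other operand's mode. The swap_ab flag threads through the struct purely at compile time via conditional_t; here is a toy restatement of that selection pattern with stand-in layout tags (not CUTLASS types):

#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};
template <class L> struct LayoutTranspose;
template <> struct LayoutTranspose<RowMajor> { using type = ColumnMajor; };
template <> struct LayoutTranspose<ColumnMajor> { using type = RowMajor; };

template <bool swap_ab_>
struct gemm_config {
  static constexpr bool swap_ab = swap_ab_;
  using LayoutD = RowMajor;
  // In the swap path the kernel really computes D^T, so the epilogue is
  // built with the transposed output layout; the normal path is unchanged.
  using LayoutOut = std::conditional_t<
      swap_ab, typename LayoutTranspose<LayoutD>::type, LayoutD>;
};

static_assert(std::is_same_v<gemm_config<false>::LayoutOut, RowMajor>);
static_assert(std::is_same_v<gemm_config<true>::LayoutOut, ColumnMajor>);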

@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
   using ElementScalar = float;
-  // clang-format off
   using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
@@ -84,33 +84,47 @@ struct cutlass_3x_gemm_fp8_blockwise {
       ElementAccumulator,
       ElementCompute,
       ElementC,
-      LayoutC,
+      conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
       AlignmentC,
       ElementD,
-      LayoutD,
+      conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
       AlignmentD,
       EpilogueScheduler,
       DefaultOperation
   >::CollectiveOp;
 
   using StageCountType = cutlass::gemm::collective::StageCountAuto;
-  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-      ArchTag,
-      OperatorClass,
-      ElementA,
-      cute::tuple<LayoutA, LayoutSFA>,
-      AlignmentA,
-      ElementB,
-      cute::tuple<LayoutB, LayoutSFB>,
-      AlignmentB,
-      ElementAccumulator,
-      MmaTileShape,
-      ClusterShape,
-
+  using CollectiveMainloop = conditional_t<swap_ab,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementB,
+          cute::tuple<LayoutB_Transpose, LayoutSFA>,
+          AlignmentB,
+          ElementA,
+          cute::tuple<LayoutA_Transpose, LayoutSFB>,
+          AlignmentA,
+          ElementAccumulator,
+          MmaTileShape,
+          ClusterShape,
       cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
-      MainloopScheduler
-  >::CollectiveOp;
-  // clang-format on
+          MainloopScheduler
+      >::CollectiveOp,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementA,
+          cute::tuple<LayoutA, LayoutSFA>,
+          AlignmentA,
+          ElementB,
+          cute::tuple<LayoutB, LayoutSFB>,
+          AlignmentB,
+          ElementAccumulator,
+          MmaTileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          MainloopScheduler
+      >::CollectiveOp>;
 
   using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
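
In the swapped CollectiveBuilder branch, ElementB with LayoutB_Transpose occupies the A slot and ElementA with LayoutA_Transpose the B slot, while LayoutSFA/LayoutSFB keep their deduced positions (as the comment above notes, they cannot be swapped); the caller compensates by feeding each position the other operand's scale pointer, as the MainloopArguments construction below shows.
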
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
 
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   using ElementD = typename Gemm::ElementD;
 
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
-  auto prob_shape = cute::make_shape(m, n, k, 1);
 
   StrideA a_stride;
   StrideB b_stride;
 
@@ -146,21 +160,36 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   b_stride =
       cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
   c_stride =
-      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+      cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
 
-  LayoutSFA layout_SFA =
+  LayoutSFA layout_SFA = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
-  LayoutSFB layout_SFB =
+  LayoutSFB layout_SFB = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
 
   auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
   auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
   auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
   auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
 
-  typename GemmKernel::MainloopArguments mainloop_args{
-      a_ptr, a_stride, b_ptr, b_stride,
-      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+  auto mainloop_args = [&](){
+    // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+    if (swap_ab) {
+      return typename GemmKernel::MainloopArguments{
+          b_ptr, b_stride, a_ptr, a_stride,
+          b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
+      };
+    }
+    else {
+      return typename GemmKernel::MainloopArguments{
+          a_ptr, a_stride, b_ptr, b_stride,
+          a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+      };
+    }
+  }();
+  auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
 
   auto c_ptr = static_cast<ElementD*>(out.data_ptr());
   typename GemmKernel::EpilogueArguments epilogue_args{
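
The mainloop_args construction above uses an immediately-invoked lambda so that one aggregate can be initialized from a branch without a default constructor or a mutable temporary. A minimal generic sketch of the idiom (toy Args struct, not the kernel's real argument type):

#include <cstdio>

struct Args {
  const char* first;
  const char* second;
};

Args make_args(bool swap_ab) {
  const char* a = "A";
  const char* b = "B";
  auto args = [&]() {
    if (swap_ab) {
      return Args{b, a};  // operands exchanged, as in the swap-AB path
    } else {
      return Args{a, b};
    }
  }();  // invoked immediately: args is fully initialized from the branch
  return args;
}

int main() {
  std::printf("swapped: %s, normal: %s\n",
              make_args(true).first, make_args(false).first);
}
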
@@ -175,29 +204,74 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                                torch::Tensor const& b,
                                                torch::Tensor const& a_scales,
                                                torch::Tensor const& b_scales) {
-  auto m = a.size(0);
-  auto k = a.size(1);
-  auto n = b.size(1);
-  int sms;
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
 
-  auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
-    return std::ceil(static_cast<float>(m) / tile1SM) *
-               std::ceil(static_cast<float>(n) / tile1SM) >=
-           sms;
-  };
-  bool use_2sm = should_use_2sm(m, n);
-  if (use_2sm) {
-    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
-        Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
-        out, a, b, a_scales, b_scales);
+  constexpr int TILE_K = 128;
+  // TODO: better heuristics
+  bool swap_ab = (m < 16) || (m % 4 != 0);
+  bool use_tma_epilogue = (m * n) % 4 == 0;
+  if (!swap_ab) {
+    constexpr int TILE_N = 128;
+    int tile_m = 256;
+    if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
+      tile_m = 64;
+    }
+    else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
+      tile_m = 128;
+    }
+    if (tile_m == 64) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else if (tile_m == 128) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else { // tile_m == 256
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+            OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
+            Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    }
   } else {
+    // TODO: Test more tile N configs
+    constexpr int TILE_M = 128;
+    constexpr int TILE_N = 16;
+    // TMA epilogue isn't compatible with Swap A/B
     cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
-        OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
-        Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+        OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
+        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
         out, a, b, a_scales, b_scales);
   }
 }
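
The dispatch above picks the smallest 1-SM tile_m whose grid still covers the device and otherwise falls back to the 2-SM 256-row tile; the swap-AB branch instead fixes a narrow 128x16 tile for tiny M. A standalone restatement of the tile_m choice, with an assumed SM count for illustration (the real code queries cudaDeviceGetAttribute):

#include <cstdio>

int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Mirrors the dispatch logic: prefer tile_m = 64, then 128, when the
// resulting grid of ceil(m / tile_m) x ceil(n / 128) CTAs still fills the GPU.
int pick_tile_m(int m, int n, int sms) {
  constexpr int TILE_N = 128;
  if (ceil_div(n, TILE_N) * ceil_div(m, 64) <= sms) return 64;
  if (ceil_div(n, TILE_N) * ceil_div(m, 128) <= sms) return 128;
  return 256;  // 2-SM cluster path
}

int main() {
  const int sms = 148;  // hypothetical SM count for an SM100-class part
  std::printf("m=64,   n=4096 -> tile_m=%d\n", pick_tile_m(64, 4096, sms));    // 64
  std::printf("m=2048, n=1024 -> tile_m=%d\n", pick_tile_m(2048, 1024, sms));  // 128
  std::printf("m=8192, n=8192 -> tile_m=%d\n", pick_tile_m(8192, 8192, sms));  // 256
}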

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 0 additions & 14 deletions
@@ -136,24 +136,10 @@ def ceil_div(x: int, y: int) -> int:
         use_cutlass, use_aiter_and_is_supported)
 
     if use_cutlass:
-        rows, cols = input_2d.shape
-        # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
-        # optimal tensor core usage. Can be removed when targeting platforms
-        # without this constraint.
-        should_pad = current_platform.has_device_capability(
-            100) and rows % 4 != 0
-        if should_pad:
-            input_2d = torch.nn.functional.pad(input_2d,
-                                               (0, 0, 0, 4 - (rows % 4)),
-                                               value=0).contiguous()
-
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=use_cutlass)
-
         output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
                                       block_size, input.dtype)
-        if should_pad:
-            output = output[:rows, :]
 
     else:
         q_input, x_scale = per_token_group_quant_fp8(
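
With the kernel-side row restriction gone, the CUTLASS fast path no longer pads the activation rows up to a multiple of 4 and slices the output back afterwards, removing one pad and one copy from every affected forward pass on SM100.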
