Skip to content

Commit 69d7197

Browse files
Enable batch tests for streamK (intel#258)
This PR enables the StreamK unit tests to run with batch sizes other than 1. They were previously disabled because of an alignment issue in the copy operations used by the unit tests. This PR re-enables them by removing the alignment adjustment (`- max_alignment`) from the problem_size_k calculation. The underlying alignment issue will be fixed once the changes to use the SPIR-V copy functions are complete.
Co-authored-by: jiyang1011 <[email protected]>
1 parent 463a3a3 commit 69d7197

7 files changed

+20
-21
lines changed

test/unit/gemm/device/gemm_testbed_3x.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4060,7 +4060,7 @@ bool TestXe(
40604060

40614061
// Use larger K sizes for stream-K tests
40624062
static constexpr int min_tiles_per_sk_unit = cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::min_iters_per_sk_unit_;
4063-
problem_size_k = {TileShapeK * min_tiles_per_sk_unit, TileShapeK * 3 * min_tiles_per_sk_unit - max_alignment};
4063+
problem_size_k = {TileShapeK * min_tiles_per_sk_unit, TileShapeK * 3 * min_tiles_per_sk_unit};
40644064
}
40654065

40664066
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
@@ -4076,7 +4076,7 @@ bool TestXe(
40764076
for (auto raster_order : raster_orders) {
40774077
for (auto max_swizzle_size : max_swizzle_sizes) {
40784078
for (DecompositionMode decomp_mode : decomposition_modes) {
4079-
std::vector problem_splits = {detail::Splits{1}};
4079+
std::vector problem_splits = {detail::Splits{1}};
40804080
if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) {
40814081
auto max_splits = (k + TileShapeK - 1) / TileShapeK;
40824082
if (max_splits > 2) {

test/unit/gemm/device/gemm_universal_s8t_bf16n_f32t_mixed_input_tensor_op_f32_xe.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ TEST(XE_Device_GemmUniversal_s8t_bf16n_f32t_mixed_input_tensor_op_f32, 128x128x6
131131

132132
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
133133

134-
// TODO(Codeplay): gemm batch doesn't work for mixed type
135-
bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0);
134+
bool passed = test::gemm::device::TestXe<Gemm>();
136135
EXPECT_TRUE(passed);
137136
}
138137
////////////////////////////////////////////////////////////////////////////////

test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_fp32_cooperative.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,28 +76,28 @@ TEST(XE_Device_Gemm_bf16t_bf16t_f32t_tensor_op_f32_cooperative, 256x256x32) {
7676
using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_f32_cooperative<
7777
layout::RowMajor, layout::RowMajor>::Gemm;
7878
// TODO(Codeplay): Enable batch tests
79-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
79+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
8080
}
8181

8282
TEST(XE_Device_Gemm_bf16n_bf16t_f32t_tensor_op_f32_cooperative, 256x256x32) {
8383
using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_f32_cooperative<
8484
layout::ColumnMajor, layout::RowMajor>::Gemm;
8585
// TODO(Codeplay): Enable batch tests
86-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
86+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
8787
}
8888

8989
TEST(XE_Device_Gemm_bf16t_bf16n_f32t_tensor_op_f32_cooperative, 256x256x32) {
9090
using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_f32_cooperative<
9191
layout::RowMajor, layout::ColumnMajor>::Gemm;
9292
// TODO(Codeplay): Enable batch tests
93-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
93+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
9494
}
9595

9696
TEST(XE_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32_cooperative, 256x256x32) {
9797
using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_f32_cooperative<
9898
layout::ColumnMajor, layout::ColumnMajor>::Gemm;
9999
// TODO(Codeplay): Enable batch tests
100-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
100+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
101101
}
102102
}
103103
} // namespace cutlass

test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_fp32_evt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ TEST(XE_Device_Gemm_bf16t_bf16t_f32_tensor_op_gmma_f32_epilogue, 256x256x32_LinC
164164

165165
using Gemm = XE_Device_Gemm_bf16_bf16_f32_tensor_op_gmma_f32_epilogue<CollectiveEpilogue>::Gemm;
166166

167-
bool passed = test::gemm::device::TestXe<Gemm>(1.0, 0.0);
167+
bool passed = test::gemm::device::TestXe<Gemm>();
168168
EXPECT_TRUE(passed);
169169
}
170170

test/unit/gemm/device/xe_gemm_fp16_fp16_fp32_tensor_op_fp32_cooperative.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,31 +78,31 @@ TEST(XE_Device_Gemm_fp16t_fp16t_f32t_tensor_op_f32_cooperative, 256x256x32) {
7878
using LayoutB = layout::RowMajor;
7979
using Gemm = XE_Device_Gemm_fp16_fp16_f32_tensor_op_f32_cooperative<LayoutA, LayoutB>::Gemm;
8080
// TODO(Codeplay): Enable batch tests
81-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
81+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
8282
}
8383

8484
TEST(XE_Device_Gemm_fp16n_fp16t_f32t_tensor_op_f32_cooperative, 256x256x32) {
8585
using LayoutA = layout::ColumnMajor;
8686
using LayoutB = layout::RowMajor;
8787
using Gemm = XE_Device_Gemm_fp16_fp16_f32_tensor_op_f32_cooperative<LayoutA, LayoutB>::Gemm;
8888
// TODO(Codeplay): Enable batch tests
89-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
89+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
9090
}
9191

9292
TEST(XE_Device_Gemm_fp16t_fp16n_f32t_tensor_op_f32_cooperative, 256x256x32) {
9393
using LayoutA = layout::RowMajor;
9494
using LayoutB = layout::ColumnMajor;
9595
using Gemm = XE_Device_Gemm_fp16_fp16_f32_tensor_op_f32_cooperative<LayoutA, LayoutB>::Gemm;
9696
// TODO(Codeplay): Enable batch tests
97-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
97+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
9898
}
9999

100100
TEST(XE_Device_Gemm_fp16n_fp16n_f32t_tensor_op_f32_cooperative, 256x256x32) {
101101
using LayoutA = layout::ColumnMajor;
102102
using LayoutB = layout::ColumnMajor;
103103
using Gemm = XE_Device_Gemm_fp16_fp16_f32_tensor_op_f32_cooperative<LayoutA, LayoutB>::Gemm;
104104
// TODO(Codeplay): Enable batch tests
105-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
105+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
106106
}
107107
}
108108
} // namespace cutlass

test/unit/gemm/device/xe_gemm_s8_s8_s32_tensor_op_s32_cooperative.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ TEST(XE_Device_Gemm_s8t_s8t_s32t_tensor_op_s32_cooperative, 64x128x32) {
7777
using LayoutB = layout::RowMajor;
7878
using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32_cooperative<LayoutA, LayoutB>::Gemm;
7979
// TODO(Codeplay): Enable batch tests
80-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
80+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
8181
}
8282

8383
/* TODO(Codeplay): Transposed copy are not implemented
@@ -86,23 +86,23 @@ TEST(XE_Device_Gemm_s8n_s8t_s32t_tensor_op_s32_cooperative, 64x128x32) {
8686
using LayoutB = layout::RowMajor;
8787
using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32_cooperative<LayoutA, LayoutB>::Gemm;
8888
// TODO(Codeplay): Enable batch tests
89-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
89+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
9090
}
9191
9292
TEST(XE_Device_Gemm_s8t_s8n_s32t_tensor_op_s32_cooperative, 64x128x32) {
9393
using LayoutA = layout::RowMajor;
9494
using LayoutB = layout::ColumnMajor;
9595
using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32_cooperative<LayoutA, LayoutB>::Gemm;
9696
// TODO(Codeplay): Enable batch tests
97-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
97+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
9898
}
9999
100100
TEST(XE_Device_Gemm_s8n_s8n_s32t_tensor_op_s32_cooperative, 64x128x32) {
101101
using LayoutA = layout::ColumnMajor;
102102
using LayoutB = layout::ColumnMajor;
103103
using Gemm = XE_Device_Gemm_s8_s8_s32_tensor_op_s32_cooperative<LayoutA, LayoutB>::Gemm;
104104
// TODO(Codeplay): Enable batch tests
105-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
105+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
106106
}
107107
*/
108108
}

test/unit/gemm/device/xe_gemm_tf32_tf32_fp32_tensor_op_fp32_cooperative.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ TEST(XE_Device_Gemm_tf32t_tf32t_f32t_tensor_op_f32_cooperative, 256x256x32) {
7979
using LayoutB = layout::RowMajor;
8080
using Gemm = XE_Device_Gemm_tf32_tf32_f32_tensor_op_f32_cooperative<LayoutA, LayoutB>::Gemm;
8181
// TODO(Codeplay): Enable batch tests
82-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
82+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
8383
}
8484

8585
/* TODO(Codeplay): missing copy transpose builtin and prefetch builtin
@@ -88,23 +88,23 @@ TEST(XE_Device_Gemm_tf32n_tf32t_f32t_tensor_op_f32_cooperative, 256x256x32) {
8888
using LayoutB = layout::RowMajor;
8989
using Gemm = XE_Device_Gemm_tf32_tf32_f32_tensor_op_f32_cooperative<LayoutA, LayoutB>::Gemm;
9090
// TODO(Codeplay): Enable batch tests
91-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
91+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
9292
}
9393
9494
TEST(XE_Device_Gemm_tf32t_tf32n_f32t_tensor_op_f32_cooperative, 256x256x32) {
9595
using LayoutA = layout::RowMajor;
9696
using LayoutB = layout::ColumnMajor;
9797
using Gemm = XE_Device_Gemm_tf32_tf32_f32_tensor_op_f32_cooperative<LayoutA, LayoutB>::Gemm;
9898
// TODO(Codeplay): Enable batch tests
99-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
99+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
100100
}
101101
102102
TEST(XE_Device_Gemm_tf32n_tf32n_f32t_tensor_op_f32_cooperative, 256x256x32) {
103103
using LayoutA = layout::ColumnMajor;
104104
using LayoutB = layout::ColumnMajor;
105105
using Gemm = XE_Device_Gemm_tf32_tf32_f32_tensor_op_f32_cooperative<LayoutA, LayoutB>::Gemm;
106106
// TODO(Codeplay): Enable batch tests
107-
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>(1.0, 0.0, false));
107+
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
108108
}
109109
*/
110110
}

0 commit comments

Comments
 (0)