Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Commit 1cad09d

Browse files
Fixed minor issue in Gemm & updated Gemm doc (#456)
This PR fixes a type issue in make_gemm signature & updates the Gemm documentation after the changes made to introduce _gemm_strided_batched.
1 parent 48a5361 commit 1cad09d

File tree

3 files changed

+41
-26
lines changed

3 files changed

+41
-26
lines changed

doc/Gemm.md

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -191,34 +191,41 @@ namespace blas {
191191
/*!
192192
* @brief Wrapper around Gemm. Creates the views, then makes and launches Gemm
193193
*/
194-
template <int WgSize, bool DoubleBuffer, bool ConflictA, bool ConflictB,
195-
int ClSize, typename TileT, bool TransA, bool TransB,
196-
int GemmMemoryType, int GemmAlgorithm, int GemmVectorization,
197-
bool is_beta_zero, int VectorSize, int BatchType>
194+
template <int WgSize, bool DoubleBuffer, bool ConflictA, bool ConflictB,
195+
int ClSize, typename TileT, bool TransA, bool TransB, bool SymmA,
196+
bool SymmB, int GemmMemoryType, int GemmAlgorithm,
197+
int GemmVectorization, bool is_beta_zero, int VectorSize,
198+
int BatchType, bool UseJointMatrix>
198199

199-
template <typename SB_Handle, typename container_t0, typename container_t1,
200+
template <typename sb_handle_t, typename container_t0, typename container_t1,
200201
typename container_t2, typename element_t, typename index_t>
201202

202-
typename SB_Handle::event_t Gemm_Launcher<
203-
WgSize, DoubleBuffer, ConflictA, ConflictB, ClSize, TileT, TransA, TransB,
204-
GemmMemoryType, GemmAlgorithm, GemmVectorization, is_beta_zero, VectorSize,
205-
BatchType>::_select_gemm(SB_Handle& sb_handle, index_t _M, index_t _N, index_t _K,
206-
element_t _alpha, container_t0 a_, index_t _lda,
207-
container_t1 b_, index_t _ldb, element_t _beta,
208-
container_t2 _C, index_t _ldc,
209-
index_t batch_size) {
203+
typename sb_handle_t::event_t
204+
Gemm_Launcher<WgSize, DoubleBuffer, ConflictA, ConflictB, ClSize, TileT, TransA,
205+
TransB, SymmA, SymmB, GemmMemoryType, GemmAlgorithm,
206+
GemmVectorization, is_beta_zero, VectorSize, BatchType,
207+
UseJointMatrix>::_select_gemm(sb_handle_t& sb_handle, index_t _M,
208+
index_t _N, index_t _K,
209+
element_t _alpha, container_t0 a_,
210+
index_t _lda, index_t _stridea,
211+
container_t1 b_, index_t _ldb,
212+
index_t _strideb, element_t _beta,
213+
container_t2 _C, index_t _ldc,
214+
index_t _stridec,
215+
index_t batch_size) {
210216

211217
//Helper functions used to make matrix views
212-
auto buffer_a = make_matrix_view<col_major>(a_, _M, _K, _lda);
213-
auto buffer_b = make_matrix_view<col_major>(b_, _K, _N, _ldb);
218+
auto buffer_a = make_matrix_view<col_major>(a_, _M, _K, _lda);
219+
auto buffer_b = make_matrix_view<col_major>(b_, _K, _N, _ldb);
214220
auto buffer_c = make_matrix_view<col_major>(_C, _M, _N, _ldc);
215221

216222
//Helper function to construct the Gemm object
217-
auto gemm = make_gemm<DoubleBuffer, ConflictA, ConflictB, ClSize, TileT,
218-
TransA, TransB, GemmMemoryType, GemmAlgorithm,
219-
GemmVectorization, is_beta_zero, VectorSize, BatchType>(
223+
auto gemm = make_gemm<DoubleBuffer, ConflictA, ConflictB, ClSize, TileT,
224+
TransA, TransB, SymmA, SymmB, GemmMemoryType,
225+
GemmAlgorithm, GemmVectorization, is_beta_zero,
226+
VectorSize, BatchType, UseJointMatrix>(
220227
buffer_a, buffer_b, buffer_c, element_t(_alpha), element_t(_beta),
221-
batch_size);
228+
batch_size, index_t(_stridea), index_t(_strideb), index_t(_stridec));
222229

223230
//Execute the gemm and return the associated event
224231
return sb_handle.execute(gemm);
@@ -259,6 +266,14 @@ template typename SB_Handle::event_t _gemm_batched(
259266
${INDEX_TYPE} _lda, ${INDEX_TYPE} _stridea, ${container_t1} b_, ${INDEX_TYPE} _ldb,
260267
${INDEX_TYPE} _strideb, ${DATA_TYPE} _beta, ${container_t2} _C, ${INDEX_TYPE} _ldc,
261268
${INDEX_TYPE} _stridec, ${INDEX_TYPE} batch_size, gemm_batch_type_t batch_type);
269+
// strided batched gemm
270+
template typename SB_Handle::event_t _gemm_strided_batched(
271+
SB_Handle& sb_handle, char _TransA, char _TransB, ${INDEX_TYPE} _M,
272+
${INDEX_TYPE} _N, ${INDEX_TYPE} _K, ${DATA_TYPE} _alpha, ${container_t0} a_,
273+
${INDEX_TYPE} _lda, ${INDEX_TYPE} _stridea, ${container_t1} b_,
274+
${INDEX_TYPE} _ldb, ${INDEX_TYPE} _strideb, ${DATA_TYPE} _beta,
275+
${container_t2} _C, ${INDEX_TYPE} _ldc, ${INDEX_TYPE} _stridec,
276+
${INDEX_TYPE} batch_size);
262277
} // namespace internal
263278
} // namespace blas
264279
```
@@ -300,9 +315,10 @@ template <bool _t_a, bool _t_b, bool is_beta_zero, typename sb_handle_t,
300315
typename container_2_t, typename element_t, typename index_t>
301316

302317
typename sb_handle_t::event_t _gemm(
303-
sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha,
304-
container_0_t _a, index_t _lda, container_1_t _b, index_t _ldb,
305-
element_t _beta, container_2_t _c, index_t _ldc, index_t batch_size,
318+
sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
319+
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
320+
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
321+
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
306322
gemm_batch_type_t batch_type) {
307323
if (batch_type == gemm_batch_type_t::interleaved) {
308324
return blas::Gemm_Launcher<

include/operations/blas3_trees.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,8 @@ inline Gemm<input_t, output_t, DoubleBuffer, ConflictA, ConflictB, ClSize,
267267
GemmMemoryType, GemmAlgorithm, GemmVectorization, VectorSize,
268268
BatchType, UseJointMatrix>
269269
make_gemm(input_t buffer_a, input_t buffer_b, output_t buffer_c,
270-
element_t alpha, element_t beta, index_t batch_size,
271-
element_t _stridea, element_t _strideb, element_t _stridec) {
270+
element_t alpha, element_t beta, index_t batch_size, index_t _stridea,
271+
index_t _strideb, index_t _stridec) {
272272
return Gemm<input_t, output_t, DoubleBuffer, ConflictA, ConflictB, ClSize,
273273
TileType, TransA, TransB, SymmA, SymmB, element_t, is_beta_zero,
274274
GemmMemoryType, GemmAlgorithm, GemmVectorization, VectorSize,

src/interface/gemm_launcher.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ Gemm_Launcher<WgSize, DoubleBuffer, ConflictA, ConflictB, ClSize, TileT, TransA,
6363
GemmAlgorithm, GemmVectorization, is_beta_zero,
6464
VectorSize, BatchType, UseJointMatrix>(
6565
buffer_a, buffer_b, buffer_c, element_t(_alpha), element_t(_beta),
66-
batch_size, element_t(_stridea), element_t(_strideb),
67-
element_t(_stridec));
66+
batch_size, index_t(_stridea), index_t(_strideb), index_t(_stridec));
6867
return sb_handle.execute(gemm);
6968
}
7069

0 commit comments

Comments
 (0)