@@ -191,34 +191,41 @@ namespace blas {
191
191
/*!
192
192
* @brief Wrapper around Gemm. Creates the views, then makes and launches Gemm
193
193
*/
194
- template <int WgSize, bool DoubleBuffer, bool ConflictA, bool ConflictB,
195
- int ClSize, typename TileT, bool TransA, bool TransB,
196
- int GemmMemoryType, int GemmAlgorithm, int GemmVectorization,
197
- bool is_beta_zero, int VectorSize, int BatchType>
194
+ template <int WgSize, bool DoubleBuffer, bool ConflictA, bool ConflictB,
195
+ int ClSize, typename TileT, bool TransA, bool TransB, bool SymmA,
196
+ bool SymmB, int GemmMemoryType, int GemmAlgorithm,
197
+ int GemmVectorization, bool is_beta_zero, int VectorSize,
198
+ int BatchType, bool UseJointMatrix>
198
199
199
- template <typename SB_Handle , typename container_t0, typename container_t1,
200
+ template <typename sb_handle_t , typename container_t0, typename container_t1,
200
201
typename container_t2, typename element_t, typename index_t>
201
202
202
- typename SB_Handle::event_t Gemm_Launcher<
203
- WgSize, DoubleBuffer, ConflictA, ConflictB, ClSize, TileT, TransA, TransB,
204
- GemmMemoryType, GemmAlgorithm, GemmVectorization, is_beta_zero, VectorSize,
205
- BatchType>::_ select_gemm(SB_Handle& sb_handle, index_t _ M, index_t _ N, index_t _ K,
206
- element_t _ alpha, container_t0 a_ , index_t _ lda,
207
- container_t1 b_ , index_t _ ldb, element_t _ beta,
208
- container_t2 _ C, index_t _ ldc,
209
- index_t batch_size) {
203
+ typename sb_handle_t::event_t
204
+ Gemm_Launcher<WgSize, DoubleBuffer, ConflictA, ConflictB, ClSize, TileT, TransA,
205
+ TransB, SymmA, SymmB, GemmMemoryType, GemmAlgorithm,
206
+ GemmVectorization, is_beta_zero, VectorSize, BatchType,
207
+ UseJointMatrix>::_ select_gemm(sb_handle_t& sb_handle, index_t _ M,
208
+ index_t _ N, index_t _ K,
209
+ element_t _ alpha, container_t0 a_ ,
210
+ index_t _ lda, index_t _ stridea,
211
+ container_t1 b_ , index_t _ ldb,
212
+ index_t _ strideb, element_t _ beta,
213
+ container_t2 _ C, index_t _ ldc,
214
+ index_t _ stridec,
215
+ index_t batch_size) {
210
216
211
217
//Helper functions used to make matrix views
212
- auto buffer_a = make_matrix_view<col_major>(a_ , _ M, _ K, _ lda);
213
- auto buffer_b = make_matrix_view<col_major>(b_ , _ K, _ N, _ ldb);
218
+ auto buffer_a = make_matrix_view<col_major>(a_ , _ M, _ K, _ lda);
219
+ auto buffer_b = make_matrix_view<col_major>(b_ , _ K, _ N, _ ldb);
214
220
auto buffer_c = make_matrix_view<col_major>(_ C, _ M, _ N, _ ldc);
215
221
216
222
//Helper function to construct the Gemm object
217
- auto gemm = make_gemm<DoubleBuffer, ConflictA, ConflictB, ClSize, TileT,
218
- TransA, TransB, GemmMemoryType, GemmAlgorithm,
219
- GemmVectorization, is_beta_zero, VectorSize, BatchType>(
223
+ auto gemm = make_gemm<DoubleBuffer, ConflictA, ConflictB, ClSize, TileT,
224
+ TransA, TransB, SymmA, SymmB, GemmMemoryType,
225
+ GemmAlgorithm, GemmVectorization, is_beta_zero,
226
+ VectorSize, BatchType, UseJointMatrix>(
220
227
buffer_a, buffer_b, buffer_c, element_t(_ alpha), element_t(_ beta),
221
- batch_size);
228
+ batch_size, index_t( _ stridea), index_t( _ strideb), index_t( _ stridec) );
222
229
223
230
//Execute the gemm and return the associated event
224
231
return sb_handle.execute(gemm);
@@ -259,6 +266,14 @@ template typename SB_Handle::event_t _gemm_batched(
259
266
${INDEX_TYPE} _lda, ${INDEX_TYPE} _stridea, ${container_t1} b_, ${INDEX_TYPE} _ldb,
260
267
${INDEX_TYPE} _strideb, ${DATA_TYPE} _beta, ${container_t2} _C, ${INDEX_TYPE} _ldc,
261
268
${INDEX_TYPE} _stridec, ${INDEX_TYPE} batch_size, gemm_batch_type_t batch_type);
269
+ // Explicit instantiation of the strided batched gemm entry point (_gemm_strided_batched): takes a stride per matrix (_stridea, _strideb, _stridec) plus batch_size, and no batch_type argument.
270
+ template typename SB_Handle::event_t _gemm_strided_batched(
271
+ SB_Handle& sb_handle, char _TransA, char _TransB, ${INDEX_TYPE} _M,
272
+ ${INDEX_TYPE} _N, ${INDEX_TYPE} _K, ${DATA_TYPE} _alpha, ${container_t0} a_,
273
+ ${INDEX_TYPE} _lda, ${INDEX_TYPE} _stridea, ${container_t1} b_,
274
+ ${INDEX_TYPE} _ldb, ${INDEX_TYPE} _strideb, ${DATA_TYPE} _beta,
275
+ ${container_t2} _C, ${INDEX_TYPE} _ldc, ${INDEX_TYPE} _stridec,
276
+ ${INDEX_TYPE} batch_size);
262
277
} // namespace internal
263
278
} // namespace blas
264
279
```
@@ -300,9 +315,10 @@ template <bool _t_a, bool _t_b, bool is_beta_zero, typename sb_handle_t,
300
315
typename container_2_t , typename element_t , typename index_t >
301
316
302
317
typename sb_handle_t ::event_t _gemm (
303
- sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha,
304
- container_0_t _a, index_t _lda, container_1_t _b, index_t _ldb,
305
- element_t _beta, container_2_t _c, index_t _ldc, index_t batch_size,
318
+ sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
319
+ element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
320
+ container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
321
+ container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
306
322
gemm_batch_type_t batch_type) {
307
323
if (batch_type == gemm_batch_type_t::interleaved) {
308
324
return blas::Gemm_Launcher<
0 commit comments