Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Commit 2c6b203

Browse files
authored
Update gemv tuning target parameters (#449)
This patch adjusts the gemv tuning target parameters and removes unnecessary headers for the generation of matrix-vector multiplication routines.
1 parent 0355a58 commit 2c6b203

File tree

6 files changed

+30
-68
lines changed

6 files changed

+30
-68
lines changed

src/interface/blas2/backend/amd_gpu.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N,
3737
index_t _lda, container_t1 _vx,
3838
increment_t _incx, element_t _beta,
3939
container_t2 _vy, increment_t _incy) {
40-
static constexpr uint32_t cache_line_size = 256;
40+
static constexpr uint32_t cache_line_size = 128;
4141
if (trn == transpose_type::Normal) {
4242
return blas::internal::_gemv_impl<256, cache_line_size,
4343
gemv_memory_t::local, trn>(
4444
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
4545
} else {
46-
return blas::internal::_gemv_impl<64, cache_line_size, gemv_memory_t::local,
47-
trn>(sb_handle, _M, _N, _alpha, _mA, _lda,
48-
_vx, _incx, _beta, _vy, _incy);
46+
return blas::internal::_gemv_impl<128, cache_line_size,
47+
gemv_memory_t::local, trn>(
48+
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
4949
}
5050
}
5151
} // namespace backend

src/interface/blas2/backend/intel_gpu.hpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,18 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N,
3838
increment_t _incx, element_t _beta,
3939
container_t2 _vy, increment_t _incy) {
4040
if (trn == transpose_type::Normal) {
41-
return blas::internal::_gemv_impl<256, 32, gemv_memory_t::local, trn>(
42-
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
41+
if (_N < 8192) {
42+
return blas::internal::_gemv_impl<128, 64, gemv_memory_t::local, trn>(
43+
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
44+
} else if (_N < 16384) {
45+
return blas::internal::_gemv_impl<256, 64, gemv_memory_t::local, trn>(
46+
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
47+
} else {
48+
return blas::internal::_gemv_impl<512, 64, gemv_memory_t::local, trn>(
49+
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
50+
}
4351
} else {
44-
return blas::internal::_gemv_impl<128, 32, gemv_memory_t::local, trn>(
52+
return blas::internal::_gemv_impl<128, 64, gemv_memory_t::local, trn>(
4553
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
4654
}
4755
}

src/interface/blas2/backend/nvidia_gpu.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N,
3838
increment_t _incx, element_t _beta,
3939
container_t2 _vy, increment_t _incy) {
4040
if (trn == transpose_type::Normal) {
41-
return blas::internal::_gemv_impl<256, 32, gemv_memory_t::local, trn>(
41+
return blas::internal::_gemv_impl<256, 128, gemv_memory_t::local, trn>(
4242
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
4343
} else {
44-
return blas::internal::_gemv_impl<128, 32, gemv_memory_t::local, trn>(
44+
return blas::internal::_gemv_impl<128, 128, gemv_memory_t::local, trn>(
4545
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
4646
}
4747
}

src/interface/blas2/gemv.cpp.in

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,49 +22,13 @@
2222
* @filename gemv.cpp.in
2323
*
2424
**************************************************************************/
25-
#include "container/sycl_iterator.hpp"
26-
#include "sb_handle/sycl_blas_handle.hpp"
27-
#include "sb_handle/kernel_constructor.hpp"
2825
#include "interface/blas2_interface.hpp"
29-
#include "operations/blas1_trees.hpp"
30-
#include "operations/blas2_trees.hpp"
31-
#include "operations/blas_constants.hpp"
32-
#include "views/view_sycl.hpp"
26+
#include "sb_handle/kernel_constructor.hpp"
27+
#include "sb_handle/sycl_blas_handle.hpp"
3328

3429
namespace blas {
3530
namespace internal {
3631

37-
/*!
38-
@brief Generalised matrix vector product with rectangular non-symmetric
39-
matrices.
40-
41-
Generalised matrix vector product with rectangular non-symmetric matrices, i.e.
42-
computing the mathematical operation:
43-
44-
y = alpha*A*x + beta*y
45-
46-
See the netlib blas interface documentation for more details of the high level
47-
interface: http://www.netlib.org/lapack/explore-html/db/d58/sgemv_8f.html
48-
SB_Handle& sb_handle, // SB_Handle (sycl, parallel, serial, etc)
49-
char _trans, // The transposition of the matrix ('n', 't', 'c')
50-
index_t _M, // The size of dimension M of the matrix (rows)
51-
index_t _N, // The size of dimension N of the matrix (columns)
52-
element_t _alpha, // Scalar parameter Alpha
53-
container_t0 _mA, // An array (LDA,N), with the first m*n elements
54-
index_t _lda, // Specifies the first dimension of a, max(1, m)
55-
container_t1 _vx, // An array of dimension at least:
56-
(1+(n-1)*abs(incx))
57-
// when trans = 'n' and (1+(m-1)*abs(incx) otherwise,
58-
// containing the vector "x"
59-
increment_t _incx, // The increment for elements in x (nonzero).
60-
element_t _beta, // Scalar parameter Beta
61-
container_t2 _vy, // An array of dimension at least:
62-
(1+(m-1)*abs(incy))
63-
// when trans = "n" and (1+(n-1)*abs(incy) otherwise,
64-
// containing the vector "y" (if beta is nonzero). When
65-
// finished, y is overwritten with the updated vector.
66-
increment_t _incy // The increment for elements in y (nonzero).
67-
*/
6832
template typename SB_Handle::event_t _gemv(
6933
SB_Handle& sb_handle, char _trans, ${INDEX_TYPE} _M, ${INDEX_TYPE} _N,
7034
${DATA_TYPE} _alpha, ${container_t0} _mA, ${INDEX_TYPE} _lda,

src/interface/blas2/symv.cpp.in

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,18 @@
2222
* @filename symv.cpp.in
2323
*
2424
**************************************************************************/
25-
#include "container/sycl_iterator.hpp"
26-
#include "sb_handle/sycl_blas_handle.hpp"
27-
#include "sb_handle/kernel_constructor.hpp"
2825
#include "interface/blas2_interface.hpp"
29-
#include "operations/blas1_trees.hpp"
30-
#include "operations/blas2_trees.hpp"
31-
#include "operations/blas_constants.hpp"
32-
#include "views/view_sycl.hpp"
26+
#include "sb_handle/kernel_constructor.hpp"
27+
#include "sb_handle/sycl_blas_handle.hpp"
3328

3429
namespace blas {
3530
namespace internal {
3631

3732
template typename SB_Handle::event_t _symv(
38-
SB_Handle& sb_handle, char _Uplo, ${INDEX_TYPE} _N,
39-
${DATA_TYPE} _alpha, ${container_t0} _mA, ${INDEX_TYPE} _lda,
40-
${container_t1} _vx, ${INCREMENT_TYPE} _incx, ${DATA_TYPE} _beta,
41-
${container_t2} _vy, ${INCREMENT_TYPE} _incy);
33+
SB_Handle& sb_handle, char _Uplo, ${INDEX_TYPE} _N, ${DATA_TYPE} _alpha,
34+
${container_t0} _mA, ${INDEX_TYPE} _lda, ${container_t1} _vx,
35+
${INCREMENT_TYPE} _incx, ${DATA_TYPE} _beta, ${container_t2} _vy,
36+
${INCREMENT_TYPE} _incy);
4237

4338
} // namespace internal
4439
} // namespace blas

src/interface/blas2/trmv.cpp.in

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,16 @@
2222
* @filename trmv.cpp.in
2323
*
2424
**************************************************************************/
25-
#include "container/sycl_iterator.hpp"
26-
#include "sb_handle/sycl_blas_handle.hpp"
27-
#include "sb_handle/kernel_constructor.hpp"
2825
#include "interface/blas2_interface.hpp"
29-
#include "operations/blas1_trees.hpp"
30-
#include "operations/blas2_trees.hpp"
31-
#include "operations/blas_constants.hpp"
32-
#include "views/view_sycl.hpp"
26+
#include "sb_handle/kernel_constructor.hpp"
27+
#include "sb_handle/sycl_blas_handle.hpp"
3328

3429
namespace blas {
3530
namespace internal {
3631

3732
template typename SB_Handle::event_t _trmv(
38-
SB_Handle& sb_handle, char _Uplo, char _trans, char _Diag,
39-
${INDEX_TYPE} _N, ${container_t0} _mA, ${INDEX_TYPE} _lda,
40-
${container_t1} _vx, ${INCREMENT_TYPE} _incx);
33+
SB_Handle& sb_handle, char _Uplo, char _trans, char _Diag, ${INDEX_TYPE} _N,
34+
${container_t0} _mA, ${INDEX_TYPE} _lda, ${container_t1} _vx,
35+
${INCREMENT_TYPE} _incx);
4136
} // namespace internal
4237
} // end namespace blas

0 commit comments

Comments
 (0)