Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Commit f067e58

Browse files
authored
Update BLAS2 GER operator (#505)
This patch introduces a new implementation for the BLAS2 GER operator.
1 parent e5f9738 commit f067e58

File tree

5 files changed

+334
-66
lines changed

5 files changed

+334
-66
lines changed

include/interface/blas2_interface.h

Lines changed: 55 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -176,34 +176,38 @@ typename sb_handle_t::event_t _symv(
176176
);
177177

178178
/*!
179-
@brief Generalised vector product followed by a sum with a rectangular
180-
non-symmetric matrix.
181-
182-
Generalised vector product followed by a sum with a rectangular non-symmetric
183-
matrix, i.e. computing the mathematical operation:
184-
185-
A = alpha*x*yT + A
186-
187-
See the netlib blas interface documentation for more details of the high level
188-
interface: http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
189-
179+
* @brief Generalised vector product followed by a sum with a rectangular
180+
* non-symmetric matrix.
181+
*
182+
* Generalised vector product followed by a sum with a rectangular non-symmetric
183+
* matrix, i.e. computing the mathematical operation:
184+
*
185+
* A = alpha*x*yT + A
186+
*
187+
* See the netlib blas interface documentation for more details of the high
188+
* level interface:
189+
* http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
190+
*
191+
* @param sb_handle SB_handle
192+
* @param _M Number of rows in matrix A
193+
* @param _N Number of columns in matrix A
194+
* @param _alpha Scalar alpha
195+
* @param _vx Input vector having (1 + (_M-1)*abs(_incx)) elements
196+
* @param _incx Increment for vector X
197+
* @param _vy Input vector having (1 + (_N-1)*abs(_incy)) elements
198+
* @param _incy Increment for vector Y
199+
* @param _mA Input/output matrix A(_lda, _N)
200+
* @param _lda Leading dimension of A
201+
* @param _dependencies Vector of events
190202
*/
191203
template <typename sb_handle_t, typename index_t, typename element_t,
192204
typename container_0_t, typename increment_t, typename container_1_t,
193205
typename container_2_t>
194206
typename sb_handle_t::event_t _ger(
195-
sb_handle_t& sb_handle, // sb_handle_t (sycl, parallel, serial, etc)
196-
index_t _M, // The rows in matrix A
197-
index_t _N, // The cols of matrix A
198-
element_t _alpha, // Scalar alpha
199-
container_0_t _vx, // >(1 + (_M-1)*abs(_incx)), input vector X
200-
increment_t _incx, // Increment for vector X
201-
container_1_t _vy, // >(1 + (_N-1)*abs(_incy)), input vector Y
202-
increment_t _incy, // Increment for vector Y
203-
container_2_t _mA, // (_lda, n) array containing A, the output
204-
index_t _lda, // >max(1, m), Leading dimension of A
205-
const typename sb_handle_t::event_t& _dependencies // Vector of events
206-
);
207+
sb_handle_t& sb_handle, index_t _M, index_t _N, element_t _alpha,
208+
container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy,
209+
container_2_t _mA, index_t _lda,
210+
const typename sb_handle_t::event_t& _dependencies);
207211

208212
/*!
209213
@brief Generalised vector squaring followed by a sum with a symmetric matrix.
@@ -746,35 +750,39 @@ typename sb_handle_t::event_t inline _symv(
746750
}
747751

748752
/*!
749-
@brief Generalised vector product followed by a sum with a rectangular
750-
non-symmetric matrix.
751-
752-
Generalised vector product followed by a sum with a rectangular non-symmetric
753-
matrix, i.e.
754-
computing the mathematical operation:
755-
756-
A = alpha*x*yT + A
757-
758-
See the netlib blas interface documentation for more details of the high level
759-
interface: http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
760-
753+
* @brief Generalised vector product followed by a sum with a rectangular
754+
* non-symmetric matrix.
755+
*
756+
* Generalised vector product followed by a sum with a rectangular non-symmetric
757+
* matrix, i.e.
758+
* computing the mathematical operation:
759+
*
760+
* A = alpha*x*yT + A
761+
*
762+
* See the netlib blas interface documentation for more details of the high
763+
* level interface:
764+
* http://www.netlib.org/lapack/explore-html/db/d5c/sger_8f.html
765+
*
766+
* @param sb_handle SB_handle
767+
* @param _M Number of rows in matrix A
768+
* @param _N Number of columns in matrix A
769+
* @param _alpha Scalar alpha
770+
* @param _vx Input vector having (1 + (_M-1)*abs(_incx)) elements
771+
* @param _incx Increment for vector X
772+
* @param _vy, Input vector having having (1 + (_N-1)*abs(_incy)) elements
773+
* @param _incy Increment for vector Y
774+
* @param _mA Input/output matrix A(_lda, n)
775+
* @param _lda Leading dimension of A
776+
* @param _dependencies Vector of events
761777
*/
762778
template <typename sb_handle_t, typename index_t, typename element_t,
763779
typename container_0_t, typename increment_t, typename container_1_t,
764780
typename container_2_t>
765781
typename sb_handle_t::event_t inline _ger(
766-
sb_handle_t& sb_handle, // sb_handle_t (sycl, parallel, serial, etc)
767-
index_t _M, // The rows in matrix M
768-
index_t _N, // The rows of matrix N
769-
element_t _alpha, // Scalar alpha
770-
container_0_t _vx, // >(1 + (_M-1)*abs(_incx)), input vector X
771-
increment_t _incx, // Increment for vector X
772-
container_1_t _vy, // >(1 + (_N-1)*abs(_incy)), input vector Y
773-
increment_t _incy, // Increment for vector Y
774-
container_2_t _mA, // (_lda, n) array containing A, the output
775-
index_t _lda, // >max(1, m), Leading dimension of A
776-
const typename sb_handle_t::event_t& _dependencies = {} // Vector of events
777-
) {
782+
sb_handle_t& sb_handle, index_t _M, index_t _N, element_t _alpha,
783+
container_0_t _vx, increment_t _incx, container_1_t _vy, increment_t _incy,
784+
container_2_t _mA, index_t _lda,
785+
const typename sb_handle_t::event_t& _dependencies = {}) {
778786
return internal::_ger(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA,
779787
_lda, _dependencies);
780788
}

include/operations/blas2_trees.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,64 @@ make_trsv(vector_t &lhs_, matrix_t &matrix_, sync_t &sync_) {
502502
subgroups, is_upper, is_transposed, is_unit>(lhs_, matrix_, k_,
503503
sync_);
504504
}
505+
/**
506+
* @struct Ger
507+
* @brief Tree node representing the sum of scalar-vector-vector product with a
508+
* matrix, i.e., it computes lhs_ such that
509+
*
510+
* lhs_ = scalar_ * ( rhs_1_ * rhs_2_^t ) + lhs_
511+
*
512+
* @param lhs_ input/output matrix
513+
* @param scalar_ value for scaling vector product
514+
* @param rhs_1_ first input vector
515+
* @param rhs_2_ second input vector
516+
* @param nRowsWG_ rows of the workgroup tile
517+
* @param nColsWG_ cols of the workgroup tile
518+
* @param nWG_row_ number of tiles per global size row
519+
* @param nWG_col_ number of tiles per global size column
520+
*
521+
*/
522+
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
523+
struct Ger {
524+
using value_t = typename rhs_2_t::value_t;
525+
using index_t = typename rhs_2_t::index_t;
526+
527+
lhs_t lhs_;
528+
value_t scalar_;
529+
rhs_1_t rhs_1_;
530+
rhs_2_t rhs_2_;
531+
index_t nRowsWG_;
532+
index_t nColsWG_;
533+
index_t nWG_row_;
534+
index_t nWG_col_;
535+
536+
Ger(lhs_t &_l, value_t _scl, rhs_1_t &_r1, rhs_2_t &_r2, index_t &_nRowsWG,
537+
index_t &_nColsWG, index_t &_nWG_row, index_t &_nWG_col);
538+
539+
index_t get_size() const;
540+
bool valid_thread(cl::sycl::nd_item<1> ndItem) const;
541+
value_t eval(index_t i);
542+
value_t eval(cl::sycl::nd_item<1> ndItem);
543+
template <typename sharedT>
544+
value_t eval(sharedT shrMem, cl::sycl::nd_item<1> ndItem);
545+
void bind(cl::sycl::handler &h);
546+
void adjust_access_displacement();
547+
};
548+
549+
/*!
550+
@brief Generator/factory for GER trees.
551+
*/
552+
template <typename lhs_t, typename rhs_1_t, typename rhs_2_t>
553+
Ger<lhs_t, rhs_1_t, rhs_2_t> make_ger(lhs_t &lhs_,
554+
typename lhs_t::value_t scalar_,
555+
rhs_1_t &rhs_1_, rhs_2_t &rhs_2_,
556+
typename rhs_2_t::index_t nRowsWG_,
557+
typename rhs_2_t::index_t nColsWG_,
558+
typename rhs_2_t::index_t nWG_row_,
559+
typename rhs_2_t::index_t nWG_col_) {
560+
return Ger<lhs_t, rhs_1_t, rhs_2_t>(lhs_, scalar_, rhs_1_, rhs_2_, nRowsWG_,
561+
nColsWG_, nWG_row_, nWG_col_);
562+
}
505563

506564
/**** GER BY ROWS M ROWS x N BLOCK USING PROPERLY THE SHARED MEMORY ****/
507565
// template <typename lhs_t,typename rhs_1_t,typename rhs_2_t>

src/interface/blas2_interface.hpp

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ typename sb_handle_t::event_t _ger_impl(
878878
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
879879
container_t2 _mA, index_t _lda,
880880
const typename sb_handle_t::event_t& _dependencies, index_t _localSize = 0,
881-
index_t _scratchPadSize = 0, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
881+
bool _useLocalMem = true, index_t _nRowsWG = 0, index_t _nColsWG = 0) {
882882
index_t M = _M;
883883
index_t N = _N;
884884
auto mA = make_matrix_view<col_major>(_mA, M, N, _lda);
@@ -887,24 +887,39 @@ typename sb_handle_t::event_t _ger_impl(
887887
typename VectorViewType<container_t1, index_t, increment_t>::type vy =
888888
make_vector_view(_vy, _incy, N);
889889

890-
const index_t localSize =
891-
(_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
892-
const index_t nRowsWG = (_nRowsWG == 0) ? localSize : std::min(M, _nRowsWG);
890+
_localSize = (_localSize == 0) ? sb_handle.get_work_group_size() : _localSize;
891+
_nRowsWG = (_nRowsWG == 0) ? _localSize : _nRowsWG;
892+
_nColsWG = (_nColsWG == 0) ? _localSize : _nColsWG;
893893

894-
const index_t nColsWG = (_nColsWG == 0) ? localSize : std::min(N, _nColsWG);
894+
assert(_localSize % _nRowsWG == 0);
895+
assert((_nRowsWG * _nColsWG) % _localSize == 0);
896+
assert(_nColsWG % (_localSize / _nRowsWG) == 0);
895897

896-
const index_t scratchPadSize =
897-
(_localSize == 0) ? localSize : _scratchPadSize;
898+
if (_useLocalMem) {
899+
assert((_nRowsWG <= _localSize) && (_nColsWG <= _localSize));
900+
} else {
901+
std::vector<size_t> subgroup_sizes =
902+
sb_handle.get_queue()
903+
.get_device()
904+
.template get_info<cl::sycl::info::device::sub_group_sizes>();
905+
size_t min_subgroup_size = *subgroup_sizes.begin();
906+
size_t max_subgroup_size = *subgroup_sizes.rbegin();
907+
assert(((_nRowsWG * _nColsWG) / _localSize) <= min_subgroup_size);
908+
assert(_nRowsWG % max_subgroup_size == 0);
909+
}
898910

899-
const index_t nWGPerCol = (N - 1) / nColsWG + 1;
900-
const index_t nWGPerRow = (M - 1) / nRowsWG + 1;
901-
const index_t globalSize = localSize * nWGPerRow * nWGPerCol;
911+
const index_t nWGPerCol = (N - 1) / _nColsWG + 1;
912+
const index_t nWGPerRow = (M - 1) / _nRowsWG + 1;
913+
const index_t globalSize = _localSize * nWGPerRow * nWGPerCol;
902914

903915
typename sb_handle_t::event_t ret;
904916
auto assignOp =
905-
make_ger_col(mA, _alpha, vx, vy, nWGPerRow, nWGPerCol, scratchPadSize);
906-
return sb_handle.execute(assignOp, localSize, globalSize, scratchPadSize,
907-
_dependencies);
917+
make_ger(mA, _alpha, vx, vy, _nRowsWG, _nColsWG, nWGPerRow, nWGPerCol);
918+
919+
return _useLocalMem ? sb_handle.execute(assignOp, _localSize, globalSize,
920+
_nRowsWG + _nColsWG, _dependencies)
921+
: sb_handle.execute(assignOp, _localSize, globalSize,
922+
_dependencies);
908923
}
909924

910925
/*! _SYR.
@@ -1280,10 +1295,30 @@ typename sb_handle_t::event_t inline _ger(
12801295
container_t0 _vx, increment_t _incx, container_t1 _vy, increment_t _incy,
12811296
container_t2 _mA, index_t _lda,
12821297
const typename sb_handle_t::event_t& _dependencies) {
1283-
// TODO: Here we can use some heuristics to select localn global, local, and
1284-
// scratch size per device
1298+
index_t localSize = 0;
1299+
bool useLocalMem = true;
1300+
index_t nRowsWG = 0;
1301+
index_t nColsWG = 0;
1302+
1303+
#if defined(INTEL_GPU)
1304+
localSize = 32;
1305+
useLocalMem = false;
1306+
nRowsWG = 32;
1307+
nColsWG = 8;
1308+
#elif defined(NVIDIA_GPU)
1309+
localSize = 256;
1310+
useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
1311+
nRowsWG = 32;
1312+
nColsWG = 32;
1313+
#elif defined(AMD_GPU)
1314+
localSize = (_N < 8192 && _M < 8192) ? 512 : 256;
1315+
useLocalMem = (_N < 8192 && _M < 8192) ? false : true;
1316+
nRowsWG = (_N < 8192 && _M < 8192) ? 64 : 128;
1317+
nColsWG = (_N < 8192 && _M < 8192) ? 64 : 256;
1318+
#endif
1319+
12851320
return _ger_impl(sb_handle, _M, _N, _alpha, _vx, _incx, _vy, _incy, _mA, _lda,
1286-
_dependencies);
1321+
_dependencies, localSize, useLocalMem, nRowsWG, nColsWG);
12871322
}
12881323

12891324
template <typename sb_handle_t, typename index_t, typename element_t,

0 commit comments

Comments
 (0)