-
Notifications
You must be signed in to change notification settings - Fork 225
Description
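The runtime `alongLines` dispatch can be converted to a compile-time choice in two places. For example, `alongLines` could become a template parameter, and the branch could be resolved with `if constexpr`: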
raft/cpp/include/raft/matrix/detail/linewise_op.cuh
Lines 767 to 773 in cc165d9
if (alongLines)
  return matrixLinewiseVecRows<Type, IdxType, VecBytes, BlockSize, Lambda, Vecs...>(
    out, in, lineLen, nLines, op, stream, vecs...);
else
  return matrixLinewiseVecCols<Type, IdxType, VecBytes, BlockSize, Lambda, Vecs...>(
    out, in, lineLen, nLines, op, stream, vecs...);
}
raft/cpp/include/raft/matrix/detail/linewise_op.cuh
Lines 799 to 815 in cc165d9
if (alongLines)
  return matrixLinewiseVecRowsSpan<Type, IdxType, LayoutPolicy, VecBytes, BlockSize, Lambda, Vecs...>(
    out, in, lineLen, nLines, op, stream, vecs...);
else
  return matrixLinewiseVecColsSpan<Type, IdxType, LayoutPolicy, VecBytes, BlockSize, Lambda, Vecs...>(
    out, in, lineLen, nLines, op, stream, vecs...);
}
The `detail::MatrixLinewiseOp` struct that contains this dispatch has callers in:
raft/matrix/linewise_op.cuh
detail::MatrixLinewiseOp<16, 256>::run<m_t, idx_t>(out.data_handle(),
raft/cpp/include/raft/matrix/linewise_op.cuh
Line 120 in cc165d9
detail::MatrixLinewiseOp<16, 256>::runPadded<m_t, idx_t>(out,
raft/matrix/matrix.cuh
raft/cpp/include/raft/matrix/matrix.cuh
Line 304 in cc165d9
detail::MatrixLinewiseOp<16, 256>::run<m_t, idx_t, Lambda, Vecs...>(
The change will also propagate to several other primitives that use the two functions above. For example, in raft/linalg/matrix_vector_op.cuh, both `rowMajor` and `bcastAlongRows` could be converted to template parameters. Once the original API changes, the change will cascade to other call sites as well.
raft/cpp/include/raft/linalg/detail/matrix_vector_op.cuh
Lines 39 to 55 in cc165d9
bool along_lines = rowMajor == bcastAlongRows;
if (rowMajor) {
  matrix::linewise_op<MatT, IdxType, row_major, Lambda>(
    handle,
    make_device_matrix_view<const MatT, IdxType, row_major>(matrix, N, D),
    make_device_matrix_view<MatT, IdxType, row_major>(out, N, D),
    along_lines,
    op,
    make_device_vector_view<const VecT, IdxType>(vec, bcastAlongRows ? N : D));
} else {
  matrix::linewise_op<MatT, IdxType, col_major, Lambda>(
    handle,
    make_device_matrix_view<const MatT, IdxType, col_major>(matrix, N, D),
    make_device_matrix_view<MatT, IdxType, col_major>(out, N, D),
    along_lines,
    op,
    make_device_vector_view<const VecT, IdxType>(vec, bcastAlongRows ? N : D));