Remove generic kernel invocations of MatrixLinewiseOp #2682

@divyegala

The runtime `alongLines` branches that pick between kernel launchers can be converted into compile-time choices in two places:

1. First place:

   ```cpp
   if (alongLines)
     return matrixLinewiseVecRows<Type, IdxType, VecBytes, BlockSize, Lambda, Vecs...>(
       out, in, lineLen, nLines, op, stream, vecs...);
   else
     return matrixLinewiseVecCols<Type, IdxType, VecBytes, BlockSize, Lambda, Vecs...>(
       out, in, lineLen, nLines, op, stream, vecs...);
   ```

2. Second place:

   ```cpp
   if (alongLines)
     return matrixLinewiseVecRowsSpan<Type,
                                      IdxType,
                                      LayoutPolicy,
                                      VecBytes,
                                      BlockSize,
                                      Lambda,
                                      Vecs...>(out, in, lineLen, nLines, op, stream, vecs...);
   else
     return matrixLinewiseVecColsSpan<Type,
                                      IdxType,
                                      LayoutPolicy,
                                      VecBytes,
                                      BlockSize,
                                      Lambda,
                                      Vecs...>(out, in, lineLen, nLines, op, stream, vecs...);
   ```
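A minimal host-side sketch of the proposed shape, assuming `alongLines` becomes a `bool` template parameter. The names `vecRowsHost`, `vecColsHost`, and `runHost` are hypothetical stand-ins for the two CUDA launchers and their caller; they are plain loops, not RAFT code. The point is that `if constexpr` discards the untaken branch inside a template, so each instantiation compiles only one launcher, whereas a plain `if` instantiates both for every type combination.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for matrixLinewiseVecRows (alongLines == true):
// vec has lineLen entries, indexed by the position within each line.
template <typename T, typename Op>
void vecRowsHost(T* out, const T* in, std::size_t lineLen, std::size_t nLines, Op op, const T* vec)
{
  for (std::size_t i = 0; i < nLines; ++i)
    for (std::size_t j = 0; j < lineLen; ++j)
      out[i * lineLen + j] = op(in[i * lineLen + j], vec[j]);
}

// Hypothetical stand-in for matrixLinewiseVecCols (alongLines == false):
// vec has nLines entries, one per line.
template <typename T, typename Op>
void vecColsHost(T* out, const T* in, std::size_t lineLen, std::size_t nLines, Op op, const T* vec)
{
  for (std::size_t i = 0; i < nLines; ++i)
    for (std::size_t j = 0; j < lineLen; ++j)
      out[i * lineLen + j] = op(in[i * lineLen + j], vec[i]);
}

// AlongLines is a template parameter, so the branch is resolved at compile
// time and only the selected launcher is instantiated.
template <bool AlongLines, typename T, typename Op>
void runHost(T* out, const T* in, std::size_t lineLen, std::size_t nLines, Op op, const T* vec)
{
  assert(vec != nullptr);
  if constexpr (AlongLines) {
    vecRowsHost(out, in, lineLen, nLines, op, vec);
  } else {
    vecColsHost(out, in, lineLen, nLines, op, vec);
  }
}
```

A call site then reads `runHost<true>(...)` or `runHost<false>(...)`; the flag has to be known where the call is written, which is why the change cascades to callers.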

`detail::MatrixLinewiseOp` has callers in:

`raft/matrix/linewise_op.cuh`

  1. `detail::MatrixLinewiseOp<16, 256>::run<m_t, idx_t>(out.data_handle(),`
  2. `detail::MatrixLinewiseOp<16, 256>::runPadded<m_t, idx_t>(out,`

`raft/matrix/matrix.cuh`

  1. `detail::MatrixLinewiseOp<16, 256>::run<m_t, idx_t, Lambda, Vecs...>(`
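If some caller must keep accepting a runtime `bool alongLines` during the transition, one option (not something this issue proposes) is to branch once at the public boundary and hand a compile-time constant down to the detail layer. The sketch below assumes a hypothetical helper `dispatchBool` and a toy compile-time core `vecIndex`; neither exists in RAFT.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical helper: turns a runtime bool into std::true_type or
// std::false_type, so the callable can treat the flag as a compile-time
// constant. Both instantiations of f are compiled, but the runtime branch
// happens only here, at the API boundary.
template <typename F>
decltype(auto) dispatchBool(bool flag, F&& f)
{
  if (flag) { return std::forward<F>(f)(std::true_type{}); }
  return std::forward<F>(f)(std::false_type{});
}

// Hypothetical compile-time core: which vector element is used for the
// matrix element at (line, pos).
template <bool AlongLines>
constexpr int vecIndex(int line, int pos)
{
  if constexpr (AlongLines) {
    return pos;   // vector has lineLen entries
  } else {
    return line;  // vector has nLines entries
  }
}

// Runtime-facing wrapper, analogous to an entry point that still takes a
// `bool alongLines` argument.
inline int vecIndexRuntime(bool alongLines, int line, int pos)
{
  return dispatchBool(alongLines, [&](auto along) {
    return vecIndex<decltype(along)::value>(line, pos);
  });
}
```

This keeps existing signatures stable but still instantiates both paths per type set; lifting the flag into the callers' own template parameters, as the issue suggests, is what actually removes the unused instantiations.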

The change will also propagate to several other primitives built on the two APIs above. For example, in `raft/linalg/matrix_vector_op.cuh`, both `rowMajor` and `bcastAlongRows` could become template parameters. Once the original API changes, the change will cascade down to other call sites as well.

```cpp
bool along_lines = rowMajor == bcastAlongRows;
if (rowMajor) {
  matrix::linewise_op<MatT, IdxType, row_major, Lambda>(
    handle,
    make_device_matrix_view<const MatT, IdxType, row_major>(matrix, N, D),
    make_device_matrix_view<MatT, IdxType, row_major>(out, N, D),
    along_lines,
    op,
    make_device_vector_view<const VecT, IdxType>(vec, bcastAlongRows ? N : D));
} else {
  matrix::linewise_op<MatT, IdxType, col_major, Lambda>(
    handle,
    make_device_matrix_view<const MatT, IdxType, col_major>(matrix, N, D),
    make_device_matrix_view<MatT, IdxType, col_major>(out, N, D),
    along_lines,
    op,
    make_device_vector_view<const VecT, IdxType>(vec, bcastAlongRows ? N : D));
}
```
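A minimal sketch of what lifting `rowMajor` and `bcastAlongRows` into template parameters could look like, assuming hypothetical tag types `row_major_tag` / `col_major_tag` in place of RAFT's `row_major` / `col_major` so the block compiles on its own. `MatrixVectorOpTraits` is also a hypothetical name. The layout is selected with `std::conditional_t`, so the two near-identical branches in the `matrix_vector_op.cuh` snippet could collapse into one call, and `along_lines` becomes a constant expression that can be passed on as a template argument.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for raft::row_major / raft::col_major.
struct row_major_tag {};
struct col_major_tag {};

// Sketch: both runtime flags lifted to template parameters.
template <bool RowMajor, bool BcastAlongRows>
struct MatrixVectorOpTraits {
  // Selected at compile time instead of an if/else on rowMajor.
  using layout = std::conditional_t<RowMajor, row_major_tag, col_major_tag>;
  // Same expression as the runtime `along_lines`, now usable as a template
  // argument further down the call chain.
  static constexpr bool along_lines = (RowMajor == BcastAlongRows);
};

// Compile-time check of the along_lines truth table.
static_assert(MatrixVectorOpTraits<true, true>::along_lines);
static_assert(!MatrixVectorOpTraits<true, false>::along_lines);
static_assert(!MatrixVectorOpTraits<false, true>::along_lines);
static_assert(MatrixVectorOpTraits<false, false>::along_lines);
```

With these traits, a single body could call `linewise_op` with `typename Traits::layout` and `Traits::along_lines`, leaving the four `(RowMajor, BcastAlongRows)` combinations to be chosen by callers at compile time.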
