@@ -141,50 +141,34 @@ void distance_impl(raft::resources const& handle,
141141 // perhaps the use of stridedSummationKernel could be causing this,
142142 // need to investigate and fix.
143143 if (x == y && is_row_major) {
144- raft::linalg::reduce (x_norm,
145- x,
146- k,
147- std::max (m, n),
148- (AccT)0 ,
149- is_row_major,
150- true ,
151- stream,
152- false ,
153- raft::identity_op (),
154- raft::add_op ());
144+ raft::linalg::reduce<true , true >(
145+ x_norm, x, k, std::max (m, n), (AccT)0 , stream, false , raft::identity_op (), raft::add_op ());
155146 sq_x_norm += std::max (m, n);
156147 sq_y_norm = sq_x_norm;
157- raft::linalg::rowNorm (
158- sq_x_norm, x, k, std::max (m, n), raft::linalg::L2Norm, is_row_major, stream);
148+ raft::linalg::rowNorm<raft::linalg::L2Norm, true >(sq_x_norm, x, k, std::max (m, n), stream);
159149 } else {
160150 y_norm += m;
161- raft::linalg::reduce (x_norm,
162- x,
163- k,
164- m,
165- (AccT)0 ,
166- is_row_major,
167- true ,
168- stream,
169- false ,
170- raft::identity_op (),
171- raft::add_op ());
172- raft::linalg::reduce (y_norm,
173- y,
174- k,
175- n,
176- (AccT)0 ,
177- is_row_major,
178- true ,
179- stream,
180- false ,
181- raft::identity_op (),
182- raft::add_op ());
151+ if (is_row_major) {
152+ raft::linalg::reduce<true , true >(
153+ x_norm, x, k, m, (AccT)0 , stream, false , raft::identity_op (), raft::add_op ());
154+ raft::linalg::reduce<true , true >(
155+ y_norm, y, k, n, (AccT)0 , stream, false , raft::identity_op (), raft::add_op ());
156+ } else {
157+ raft::linalg::reduce<false , true >(
158+ x_norm, x, k, m, (AccT)0 , stream, false , raft::identity_op (), raft::add_op ());
159+ raft::linalg::reduce<false , true >(
160+ y_norm, y, k, n, (AccT)0 , stream, false , raft::identity_op (), raft::add_op ());
161+ }
183162
184163 sq_x_norm += (m + n);
185164 sq_y_norm = sq_x_norm + m;
186- raft::linalg::rowNorm (sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream);
187- raft::linalg::rowNorm (sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream);
165+ if (is_row_major) {
166+ raft::linalg::rowNorm<raft::linalg::L2Norm, true >(sq_x_norm, x, k, m, stream);
167+ raft::linalg::rowNorm<raft::linalg::L2Norm, true >(sq_y_norm, y, k, n, stream);
168+ } else {
169+ raft::linalg::rowNorm<raft::linalg::L2Norm, false >(sq_x_norm, x, k, m, stream);
170+ raft::linalg::rowNorm<raft::linalg::L2Norm, false >(sq_y_norm, y, k, n, stream);
171+ }
188172 }
189173
190174 using OpT = ops::correlation_distance_op<DataT, AccT, IdxT>;
@@ -224,14 +208,17 @@ void distance_impl(raft::resources const& handle,
224208 // perhaps the use of stridedSummationKernel could be causing this,
225209 // need to investigate and fix.
226210 if (x == y && is_row_major) {
227- raft::linalg::rowNorm (
228- x_norm, x, k, std::max (m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
211+ raft::linalg::rowNorm<raft::linalg::L2Norm, true > (
212+ x_norm, x, k, std::max (m, n), stream, raft::sqrt_op{});
229213 } else {
230214 y_norm += m;
231- raft::linalg::rowNorm (
232- x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
233- raft::linalg::rowNorm (
234- y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
215+ if (is_row_major) {
216+ raft::linalg::rowNorm<raft::linalg::L2Norm, true >(x_norm, x, k, m, stream, raft::sqrt_op{});
217+ raft::linalg::rowNorm<raft::linalg::L2Norm, true >(y_norm, y, k, n, stream, raft::sqrt_op{});
218+ } else {
219+ raft::linalg::rowNorm<raft::linalg::L2Norm, false >(x_norm, x, k, m, stream, raft::sqrt_op{});
220+ raft::linalg::rowNorm<raft::linalg::L2Norm, false >(y_norm, y, k, n, stream, raft::sqrt_op{});
221+ }
235222 }
236223
237224 ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
@@ -482,20 +469,21 @@ void distance_impl_l2_expanded( // NOTE: different name
482469 // perhaps the use of stridedSummationKernel could be causing this,
483470 // need to investigate and fix.
484471 if ((x == y) && is_row_major) {
485- raft::linalg::rowNorm (x_norm,
486- x,
487- k,
488- std::max (m, n),
489- raft::linalg::L2Norm,
490- is_row_major,
491- stream,
492- raft::identity_op{});
472+ raft::linalg::rowNorm<raft::linalg::L2Norm, true >(
473+ x_norm, x, k, std::max (m, n), stream, raft::identity_op{});
493474 } else {
494475 y_norm += m;
495- raft::linalg::rowNorm (
496- x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
497- raft::linalg::rowNorm (
498- y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
476+ if (is_row_major) {
477+ raft::linalg::rowNorm<raft::linalg::L2Norm, true >(
478+ x_norm, x, k, m, stream, raft::identity_op{});
479+ raft::linalg::rowNorm<raft::linalg::L2Norm, true >(
480+ y_norm, y, k, n, stream, raft::identity_op{});
481+ } else {
482+ raft::linalg::rowNorm<raft::linalg::L2Norm, false >(
483+ x_norm, x, k, m, stream, raft::identity_op{});
484+ raft::linalg::rowNorm<raft::linalg::L2Norm, false >(
485+ y_norm, y, k, n, stream, raft::identity_op{});
486+ }
499487 }
500488
501489 ops::l2_exp_distance_op<DataT, AccT, IdxT> distance_op{perform_sqrt};
0 commit comments