Skip to content

Commit 859acbd

Browse files
divyegalabkarsin
authored andcommitted
Instantiate only specific RAFT reduction kernels (rapidsai#925)
Depends on rapidsai/raft#2679 with reference issue rapidsai/raft#2681. This PR reduces the wheel size from 936.9 MB to 917.1 MB. Authors: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Robert Maynard (https://github.com/robertmaynard) - Corey J. Nolet (https://github.com/cjnolet) - Mike Sarahan (https://github.com/msarahan) - https://github.com/jakirkham URL: rapidsai#925
1 parent 3c569b0 commit 859acbd

File tree

25 files changed

+237
-320
lines changed

25 files changed

+237
-320
lines changed

build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ if (( NUMARGS == 0 )) || hasArg libcuvs || hasArg docs || hasArg tests || hasArg
450450
MSG="${MSG}<br/>parallel setting: $PARALLEL_LEVEL"
451451
MSG="${MSG}<br/>parallel build time: $compile_total seconds"
452452
if [[ -f "${LIBCUVS_BUILD_DIR}/libcuvs.so" ]]; then
453-
LIBCUVS_FS=$(ls -lh ${LIBCUVS_BUILD_DIR}/libcuvs.so | awk '{print $5}')
453+
LIBCUVS_FS=$(stat -c %s ${LIBCUVS_BUILD_DIR}/libcuvs.so | awk '{printf "%.2f MB", $1/1024/1024}')
454454
MSG="${MSG}<br/>libcuvs.so size: $LIBCUVS_FS"
455455
fi
456456
BMR_DIR=${RAPIDS_ARTIFACTS_DIR:-"${LIBCUVS_BUILD_DIR}"}

cpp/cmake/patches/faiss.diff

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
diff --git a/faiss/gpu/impl/CuvsIVFFlat.cu b/faiss/gpu/impl/CuvsIVFFlat.cu
2+
index 2cccee860..c4cb13f48 100644
3+
--- a/faiss/gpu/impl/CuvsIVFFlat.cu
4+
+++ b/faiss/gpu/impl/CuvsIVFFlat.cu
5+
@@ -427,13 +427,11 @@ void CuvsIVFFlat::copyInvertedListsFrom(const InvertedLists* ivf) {
6+
// Precompute the centers vector norms for L2Expanded distance
7+
if (this->metric_ == faiss::METRIC_L2) {
8+
cuvs_index->allocate_center_norms(raft_handle);
9+
- raft::linalg::rowNorm(
10+
+ raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
11+
cuvs_index->center_norms().value().data_handle(),
12+
cuvs_index->centers().data_handle(),
13+
cuvs_index->dim(),
14+
(uint32_t)nlist,
15+
- raft::linalg::L2Norm,
16+
- true,
17+
raft_handle.get_stream());
18+
}
19+
}

cpp/cmake/patches/faiss_override.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
"faiss" : {
44
"version": "1.10.0",
55
"git_url": "https://github.com/facebookresearch/faiss.git",
6-
"git_tag": "main"
6+
"git_tag": "main",
7+
"patches" : [
8+
{
9+
"file" : "${current_json_dir}/faiss.diff",
10+
"issue" : "Apply RAFT breaking changes",
11+
"fixed_in" : ""
12+
}
13+
]
714
}
815
}
916
}

cpp/src/cluster/detail/kmeans.cuh

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,8 @@ void kmeansPlusPlus(raft::resources const& handle,
143143

144144
if (metric == cuvs::distance::DistanceType::L2Expanded ||
145145
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
146-
raft::linalg::rowNorm(L2NormX.data_handle(),
147-
X.data_handle(),
148-
X.extent(1),
149-
X.extent(0),
150-
raft::linalg::L2Norm,
151-
true,
152-
stream);
146+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
147+
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
153148
}
154149

155150
raft::random::RngState rng(params.rng_state.seed, params.rng_state.type);
@@ -216,14 +211,12 @@ void kmeansPlusPlus(raft::resources const& handle,
216211

217212
// Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using
218213
// centroid candidate-i
219-
raft::linalg::reduce(costPerCandidate.data_handle(),
220-
minDistBuf.data_handle(),
221-
minDistBuf.extent(1),
222-
minDistBuf.extent(0),
223-
static_cast<DataT>(0),
224-
true,
225-
true,
226-
stream);
214+
raft::linalg::reduce<true, true>(costPerCandidate.data_handle(),
215+
minDistBuf.data_handle(),
216+
minDistBuf.extent(1),
217+
minDistBuf.extent(0),
218+
static_cast<DataT>(0),
219+
stream);
227220

228221
// Greedy Choice - Choose the candidate that has minimum cluster cost
229222
// ArgMin operation below identifies the index of minimum cost in costPerCandidate
@@ -406,13 +399,8 @@ void kmeans_fit_main(raft::resources const& handle,
406399

407400
if (metric == cuvs::distance::DistanceType::L2Expanded ||
408401
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
409-
raft::linalg::rowNorm(L2NormX.data_handle(),
410-
X.data_handle(),
411-
X.extent(1),
412-
X.extent(0),
413-
raft::linalg::L2Norm,
414-
true,
415-
stream);
402+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
403+
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
416404
}
417405

418406
RAFT_LOG_DEBUG(
@@ -634,13 +622,8 @@ void initScalableKMeansPlusPlus(raft::resources const& handle,
634622
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
635623
if (metric == cuvs::distance::DistanceType::L2Expanded ||
636624
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
637-
raft::linalg::rowNorm(L2NormX.data_handle(),
638-
X.data_handle(),
639-
X.extent(1),
640-
X.extent(0),
641-
raft::linalg::L2Norm,
642-
true,
643-
stream);
625+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
626+
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
644627
}
645628

646629
auto minClusterDistanceVec = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
@@ -1049,13 +1032,8 @@ void kmeans_predict(raft::resources const& handle,
10491032
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
10501033
if (metric == cuvs::distance::DistanceType::L2Expanded ||
10511034
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
1052-
raft::linalg::rowNorm(L2NormX.data_handle(),
1053-
X.data_handle(),
1054-
X.extent(1),
1055-
X.extent(0),
1056-
raft::linalg::L2Norm,
1057-
true,
1058-
stream);
1035+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
1036+
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
10591037
}
10601038

10611039
// computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i]

cpp/src/cluster/detail/kmeans_balanced.cuh

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
113113

114114
auto centroidsNorm =
115115
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_clusters));
116-
raft::linalg::rowNorm<MathT, IdxT>(
117-
centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream);
116+
raft::linalg::rowNorm<raft::linalg::L2Norm, true, MathT, IdxT>(
117+
centroidsNorm.data_handle(), centers, dim, n_clusters, stream);
118118

119119
cuvs::distance::fusedDistanceNNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
120120
minClusterAndDistance.data_handle(),
@@ -156,14 +156,8 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
156156

157157
auto centroidsNorm =
158158
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_clusters));
159-
raft::linalg::rowNorm<MathT, IdxT>(centroidsNorm.data_handle(),
160-
centers,
161-
dim,
162-
n_clusters,
163-
raft::linalg::L2Norm,
164-
true,
165-
stream,
166-
raft::sqrt_op{});
159+
raft::linalg::rowNorm<raft::linalg::L2Norm, true, MathT, IdxT>(
160+
centroidsNorm.data_handle(), centers, dim, n_clusters, stream, raft::sqrt_op{});
167161

168162
cuvs::distance::fusedDistanceNNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
169163
minClusterAndDistance.data_handle(),
@@ -395,8 +389,8 @@ void compute_norm(const raft::resources& handle,
395389
dataset_ptr = static_cast<const MathT*>(mapped_dataset.data());
396390
}
397391

398-
raft::linalg::rowNorm<MathT, IdxT>(
399-
dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream, norm_fin_op);
392+
raft::linalg::rowNorm<raft::linalg::L2Norm, true, MathT, IdxT>(
393+
dataset_norm, dataset_ptr, dim, n_rows, stream, norm_fin_op);
400394
}
401395

402396
/**
@@ -732,8 +726,8 @@ void balancing_em_iters(const raft::resources& handle,
732726
cluster_centers, n_clusters, dim);
733727
auto clusters_out_view = raft::make_device_matrix_view<MathT, IdxT, raft::row_major>(
734728
cluster_centers, n_clusters, dim);
735-
raft::linalg::row_normalize(
736-
handle, clusters_in_view, clusters_out_view, raft::linalg::L2Norm);
729+
raft::linalg::row_normalize<raft::linalg::L2Norm>(
730+
handle, clusters_in_view, clusters_out_view);
737731
break;
738732
}
739733
default: break;

cpp/src/cluster/detail/kmeans_common.cuh

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,11 @@ void minClusterAndDistanceCompute(
382382

383383
if (is_fused) {
384384
L2NormBuf_OR_DistBuf.resize(n_clusters, stream);
385-
raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(),
386-
centroids.data_handle(),
387-
centroids.extent(1),
388-
centroids.extent(0),
389-
raft::linalg::L2Norm,
390-
true,
391-
stream);
385+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(L2NormBuf_OR_DistBuf.data(),
386+
centroids.data_handle(),
387+
centroids.extent(1),
388+
centroids.extent(0),
389+
stream);
392390
} else {
393391
// TODO: Unless pool allocator is used, passing in a workspace for this
394392
// isn't really increasing performance because this needs to do a re-allocation
@@ -518,13 +516,11 @@ void minClusterDistanceCompute(raft::resources const& handle,
518516

519517
if (is_fused) {
520518
L2NormBuf_OR_DistBuf.resize(n_clusters, stream);
521-
raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(),
522-
centroids.data_handle(),
523-
centroids.extent(1),
524-
centroids.extent(0),
525-
raft::linalg::L2Norm,
526-
true,
527-
stream);
519+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(L2NormBuf_OR_DistBuf.data(),
520+
centroids.data_handle(),
521+
centroids.extent(1),
522+
centroids.extent(0),
523+
stream);
528524
} else {
529525
L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream);
530526
}

cpp/src/cluster/detail/kmeans_mg.cuh

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,8 @@ void initKMeansPlusPlus(const raft::resources& handle,
228228
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
229229
if (metric == cuvs::distance::DistanceType::L2Expanded ||
230230
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
231-
raft::linalg::rowNorm(L2NormX.data_handle(),
232-
X.data_handle(),
233-
X.extent(1),
234-
X.extent(0),
235-
raft::linalg::L2Norm,
236-
true,
237-
stream);
231+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
232+
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
238233
}
239234

240235
auto minClusterDistance = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
@@ -578,13 +573,8 @@ void fit(const raft::resources& handle,
578573
auto L2NormX = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
579574
if (metric == cuvs::distance::DistanceType::L2Expanded ||
580575
metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
581-
raft::linalg::rowNorm(L2NormX.data_handle(),
582-
X.data_handle(),
583-
X.extent(1),
584-
X.extent(0),
585-
raft::linalg::L2Norm,
586-
true,
587-
stream);
576+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
577+
L2NormX.data_handle(), X.data_handle(), X.extent(1), X.extent(0), stream);
588578
}
589579

590580
DataT priorClusteringCost = 0;

cpp/src/cluster/kmeans.cuh

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -481,13 +481,8 @@ void cluster_cost(raft::resources const& handle,
481481

482482
auto x_norms = raft::make_device_vector<DataT>(handle, n_samples);
483483

484-
raft::linalg::rowNorm(x_norms.data_handle(),
485-
X.data_handle(),
486-
n_features,
487-
n_samples,
488-
raft::linalg::L2Norm,
489-
true,
490-
stream);
484+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
485+
x_norms.data_handle(), X.data_handle(), n_features, n_samples, stream);
491486

492487
auto min_cluster_distance = raft::make_device_vector<DataT>(handle, n_samples);
493488
rmm::device_uvector<DataT> l2_norm_or_distance_buffer(0, stream);

cpp/src/distance/detail/distance.cuh

Lines changed: 43 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -141,50 +141,34 @@ void distance_impl(raft::resources const& handle,
141141
// perhaps the use of stridedSummationKernel could be causing this,
142142
// need to investigate and fix.
143143
if (x == y && is_row_major) {
144-
raft::linalg::reduce(x_norm,
145-
x,
146-
k,
147-
std::max(m, n),
148-
(AccT)0,
149-
is_row_major,
150-
true,
151-
stream,
152-
false,
153-
raft::identity_op(),
154-
raft::add_op());
144+
raft::linalg::reduce<true, true>(
145+
x_norm, x, k, std::max(m, n), (AccT)0, stream, false, raft::identity_op(), raft::add_op());
155146
sq_x_norm += std::max(m, n);
156147
sq_y_norm = sq_x_norm;
157-
raft::linalg::rowNorm(
158-
sq_x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream);
148+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(sq_x_norm, x, k, std::max(m, n), stream);
159149
} else {
160150
y_norm += m;
161-
raft::linalg::reduce(x_norm,
162-
x,
163-
k,
164-
m,
165-
(AccT)0,
166-
is_row_major,
167-
true,
168-
stream,
169-
false,
170-
raft::identity_op(),
171-
raft::add_op());
172-
raft::linalg::reduce(y_norm,
173-
y,
174-
k,
175-
n,
176-
(AccT)0,
177-
is_row_major,
178-
true,
179-
stream,
180-
false,
181-
raft::identity_op(),
182-
raft::add_op());
151+
if (is_row_major) {
152+
raft::linalg::reduce<true, true>(
153+
x_norm, x, k, m, (AccT)0, stream, false, raft::identity_op(), raft::add_op());
154+
raft::linalg::reduce<true, true>(
155+
y_norm, y, k, n, (AccT)0, stream, false, raft::identity_op(), raft::add_op());
156+
} else {
157+
raft::linalg::reduce<false, true>(
158+
x_norm, x, k, m, (AccT)0, stream, false, raft::identity_op(), raft::add_op());
159+
raft::linalg::reduce<false, true>(
160+
y_norm, y, k, n, (AccT)0, stream, false, raft::identity_op(), raft::add_op());
161+
}
183162

184163
sq_x_norm += (m + n);
185164
sq_y_norm = sq_x_norm + m;
186-
raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream);
187-
raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream);
165+
if (is_row_major) {
166+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(sq_x_norm, x, k, m, stream);
167+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(sq_y_norm, y, k, n, stream);
168+
} else {
169+
raft::linalg::rowNorm<raft::linalg::L2Norm, false>(sq_x_norm, x, k, m, stream);
170+
raft::linalg::rowNorm<raft::linalg::L2Norm, false>(sq_y_norm, y, k, n, stream);
171+
}
188172
}
189173

190174
using OpT = ops::correlation_distance_op<DataT, AccT, IdxT>;
@@ -224,14 +208,17 @@ void distance_impl(raft::resources const& handle,
224208
// perhaps the use of stridedSummationKernel could be causing this,
225209
// need to investigate and fix.
226210
if (x == y && is_row_major) {
227-
raft::linalg::rowNorm(
228-
x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
211+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
212+
x_norm, x, k, std::max(m, n), stream, raft::sqrt_op{});
229213
} else {
230214
y_norm += m;
231-
raft::linalg::rowNorm(
232-
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
233-
raft::linalg::rowNorm(
234-
y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
215+
if (is_row_major) {
216+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(x_norm, x, k, m, stream, raft::sqrt_op{});
217+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(y_norm, y, k, n, stream, raft::sqrt_op{});
218+
} else {
219+
raft::linalg::rowNorm<raft::linalg::L2Norm, false>(x_norm, x, k, m, stream, raft::sqrt_op{});
220+
raft::linalg::rowNorm<raft::linalg::L2Norm, false>(y_norm, y, k, n, stream, raft::sqrt_op{});
221+
}
235222
}
236223

237224
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
@@ -482,20 +469,21 @@ void distance_impl_l2_expanded( // NOTE: different name
482469
// perhaps the use of stridedSummationKernel could be causing this,
483470
// need to investigate and fix.
484471
if ((x == y) && is_row_major) {
485-
raft::linalg::rowNorm(x_norm,
486-
x,
487-
k,
488-
std::max(m, n),
489-
raft::linalg::L2Norm,
490-
is_row_major,
491-
stream,
492-
raft::identity_op{});
472+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
473+
x_norm, x, k, std::max(m, n), stream, raft::identity_op{});
493474
} else {
494475
y_norm += m;
495-
raft::linalg::rowNorm(
496-
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
497-
raft::linalg::rowNorm(
498-
y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
476+
if (is_row_major) {
477+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
478+
x_norm, x, k, m, stream, raft::identity_op{});
479+
raft::linalg::rowNorm<raft::linalg::L2Norm, true>(
480+
y_norm, y, k, n, stream, raft::identity_op{});
481+
} else {
482+
raft::linalg::rowNorm<raft::linalg::L2Norm, false>(
483+
x_norm, x, k, m, stream, raft::identity_op{});
484+
raft::linalg::rowNorm<raft::linalg::L2Norm, false>(
485+
y_norm, y, k, n, stream, raft::identity_op{});
486+
}
499487
}
500488

501489
ops::l2_exp_distance_op<DataT, AccT, IdxT> distance_op{perform_sqrt};

0 commit comments

Comments
 (0)