Skip to content

Commit c1a572d

Browse files
authored
Multi-gpu KNN build for UMAP using all-neighbors API (#6654)
Allows multi-gpu knn graph building in UMAP using the all-neighbors API. ## PRs that need to be merged before this one - rapidsai/cuvs#785 - rapidsai/raft#2666 ## Changes in cuML UMAP usage ``` from pylibraft.common import DeviceResourcesSNMG # if want to use multi GPU multigpu_handle = DeviceResourcesSNMG() umap_nnd = UMAP(handle = multigpu_handle, build_algo="nn_descent", build_kwds={"nnd_n_nearest_clusters": 2, "nnd_n_clusters": 8, "nnd_graph_degree": 32, "nnd_max_iterations": 20 }) ``` Closes #6729 Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Simon Adorf (https://github.com/csadorf) - Divye Gala (https://github.com/divyegala) - Victor Lafargue (https://github.com/viclafargue) - Jim Crist-Harif (https://github.com/jcrist) URL: #6654
1 parent a38c3d5 commit c1a572d

File tree

5 files changed

+222
-101
lines changed

5 files changed

+222
-101
lines changed

cpp/include/cuml/manifold/umapparams.h

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,55 @@
2424

2525
namespace ML {
2626

27-
using nn_index_params = cuvs::neighbors::nn_descent::index_params;
27+
namespace graph_build_params {
28+
29+
/**
30+
* Arguments for using nn descent as the knn build algorithm.
31+
* graph_degree must be larger than or equal to n_neighbors.
32+
* Increasing graph_degree and max_iterations may result in better accuracy.
33+
* Smaller termination threshold means stricter convergence criteria for nn descent and may take
34+
* longer to converge.
35+
*/
36+
struct nn_descent_params_umap {
37+
// not directly using cuvs::neighbors::nn_descent::index_params to distinguish UMAP-exposed NN
38+
// Descent parameters
39+
size_t graph_degree = 64;
40+
size_t intermediate_graph_degree = 128;
41+
size_t max_iterations = 20;
42+
float termination_threshold = 0.0001;
43+
};
44+
45+
/**
46+
* Parameters for knn graph building in UMAP.
47+
* [Hint1]: the ratio of overlap_factor / n_clusters determines device memory usage.
48+
* Approximately (overlap_factor / n_clusters) * num_rows_in_entire_data number of rows will be
49+
* put on device memory at once. E.g. between (overlap_factor / n_clusters) = 2/10 and 2/20, the
50+
* latter will use less device memory.
51+
* [Hint2]: larger overlap_factor results in better accuracy
52+
* of the final all-neighbors knn graph. E.g. While using similar amount of device memory,
53+
* (overlap_factor / n_clusters) = 4/20 will have better accuracy than 2/10 at the cost of
54+
* performance.
55+
* [Hint3]: for overlap_factor, start with 2, and gradually increase (2->3->4 ...)
56+
* for better accuracy
57+
* [Hint4]: for n_clusters, start with 4, and gradually increase(4->8->16 ...)
58+
* for less GPU memory usage. This is independent from overlap_factor as long as
59+
* overlap_factor < n_clusters
60+
*/
61+
struct graph_build_params {
62+
/**
63+
* Number of clusters each data point is assigned to. Only valid when n_clusters > 1.
64+
*/
65+
size_t overlap_factor = 2;
66+
/**
67+
* Number of clusters to split the data into when building the knn graph. Increasing this will use
68+
* less device memory at the cost of accuracy. When using n_clusters > 1, is is required that the
69+
* data is put on host (refer to data_on_host argument for fit_transform). The default value
70+
* (n_clusters=1) will place the entire data on device memory.
71+
*/
72+
size_t n_clusters = 1;
73+
nn_descent_params_umap nn_descent_params;
74+
};
75+
} // namespace graph_build_params
2876

2977
class UMAPParams {
3078
public:
@@ -149,7 +197,7 @@ class UMAPParams {
149197
*/
150198
graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN;
151199

152-
nn_index_params nn_descent_params = {};
200+
graph_build_params::graph_build_params build_params;
153201

154202
/**
155203
* The number of nearest neighbors to use to construct the target simplicial

cpp/src/umap/knn_graph/algo.cuh

Lines changed: 51 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
#include <raft/util/cudart_utils.hpp>
3434

3535
#include <cuvs/distance/distance.hpp>
36+
#include <cuvs/neighbors/all_neighbors.hpp>
3637
#include <cuvs/neighbors/brute_force.hpp>
37-
#include <cuvs/neighbors/nn_descent.hpp>
3838
#include <stdint.h>
3939

4040
#include <iostream>
@@ -56,23 +56,6 @@ void launcher(const raft::handle_t& handle,
5656
const ML::UMAPParams* params,
5757
cudaStream_t stream);
5858

59-
auto get_graph_nnd(const raft::handle_t& handle,
60-
const ML::manifold_dense_inputs_t<float>& inputs,
61-
const ML::UMAPParams* params)
62-
{
63-
cudaPointerAttributes attr;
64-
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputs.X));
65-
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
66-
if (ptr != nullptr) {
67-
auto dataset =
68-
raft::make_device_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
69-
return cuvs::neighbors::nn_descent::build(handle, params->nn_descent_params, dataset);
70-
} else {
71-
auto dataset = raft::make_host_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
72-
return cuvs::neighbors::nn_descent::build(handle, params->nn_descent_params, dataset);
73-
}
74-
}
75-
7659
// Instantiation for dense inputs, int64_t indices
7760
template <>
7861
inline void launcher(const raft::handle_t& handle,
@@ -83,12 +66,14 @@ inline void launcher(const raft::handle_t& handle,
8366
const ML::UMAPParams* params,
8467
cudaStream_t stream)
8568
{
69+
cudaPointerAttributes attr;
70+
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputsA.X));
71+
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
72+
bool data_on_device = ptr != nullptr;
73+
8674
if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) {
87-
cudaPointerAttributes attr;
88-
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputsA.X));
89-
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
90-
auto idx = [&]() {
91-
if (ptr != nullptr) { // inputsA on device
75+
auto idx = [&]() {
76+
if (data_on_device) { // inputsA on device
9277
return cuvs::neighbors::brute_force::build(
9378
handle,
9479
{params->metric, params->p},
@@ -107,45 +92,50 @@ inline void launcher(const raft::handle_t& handle,
10792
raft::make_device_matrix_view<int64_t, int64_t>(out.knn_indices, inputsB.n, n_neighbors),
10893
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsB.n, n_neighbors));
10994
} else { // nn_descent
110-
// TODO: use nndescent from cuvs
111-
RAFT_EXPECTS(static_cast<size_t>(n_neighbors) <= params->nn_descent_params.graph_degree,
112-
"n_neighbors should be smaller than the graph degree computed by nn descent");
113-
RAFT_EXPECTS(params->nn_descent_params.return_distances,
114-
"return_distances for nn descent should be set to true to be used for UMAP");
115-
116-
auto graph = get_graph_nnd(handle, inputsA, params);
117-
118-
// `graph.graph()` is a host array (n x graph_degree).
119-
// Slice and copy to a temporary host array (n x n_neighbors), then copy
120-
// that to the output device array `out.knn_indices` (n x n_neighbors).
121-
// TODO: force graph_degree = n_neighbors so the temporary host array and
122-
// slice isn't necessary.
123-
auto temp_indices_h = raft::make_host_matrix<int64_t, int64_t>(inputsA.n, n_neighbors);
124-
size_t graph_degree = params->nn_descent_params.graph_degree;
125-
#pragma omp parallel for
126-
for (size_t i = 0; i < static_cast<size_t>(inputsA.n); i++) {
127-
for (int j = 0; j < n_neighbors; j++) {
128-
auto target = temp_indices_h.data_handle();
129-
auto source = graph.graph().data_handle();
130-
target[i * n_neighbors + j] = source[i * graph_degree + j];
131-
}
95+
RAFT_EXPECTS(
96+
static_cast<size_t>(n_neighbors) <= params->build_params.nn_descent_params.graph_degree,
97+
"n_neighbors should be smaller than the graph degree computed by nn descent");
98+
RAFT_EXPECTS(
99+
params->build_params.nn_descent_params.graph_degree <=
100+
params->build_params.nn_descent_params.intermediate_graph_degree,
101+
"graph_degree should be smaller than intermediate_graph_degree computed by nn descent");
102+
103+
auto all_neighbors_params = cuvs::neighbors::all_neighbors::all_neighbors_params{};
104+
all_neighbors_params.overlap_factor = params->build_params.overlap_factor;
105+
all_neighbors_params.n_clusters = params->build_params.n_clusters;
106+
all_neighbors_params.metric = params->metric;
107+
108+
auto nn_descent_params =
109+
cuvs::neighbors::all_neighbors::graph_build_params::nn_descent_params{};
110+
nn_descent_params.graph_degree = params->build_params.nn_descent_params.graph_degree;
111+
nn_descent_params.intermediate_graph_degree =
112+
params->build_params.nn_descent_params.intermediate_graph_degree;
113+
nn_descent_params.max_iterations = params->build_params.nn_descent_params.max_iterations;
114+
nn_descent_params.termination_threshold =
115+
params->build_params.nn_descent_params.termination_threshold;
116+
nn_descent_params.metric = params->metric;
117+
all_neighbors_params.graph_build_params = nn_descent_params;
118+
119+
auto indices_view =
120+
raft::make_device_matrix_view<int64_t, int64_t>(out.knn_indices, inputsB.n, n_neighbors);
121+
auto distances_view =
122+
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsB.n, n_neighbors);
123+
124+
if (data_on_device) { // inputsA on device
125+
cuvs::neighbors::all_neighbors::build(
126+
handle,
127+
all_neighbors_params,
128+
raft::make_device_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d),
129+
indices_view,
130+
distances_view);
131+
} else { // inputsA on host
132+
cuvs::neighbors::all_neighbors::build(
133+
handle,
134+
all_neighbors_params,
135+
raft::make_host_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d),
136+
indices_view,
137+
distances_view);
132138
}
133-
raft::copy(handle,
134-
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors),
135-
temp_indices_h.view());
136-
137-
// `graph.distances()` is a device array (n x graph_degree).
138-
// Slice and copy to the output device array `out.knn_dists` (n x n_neighbors).
139-
// TODO: force graph_degree = n_neighbors so this slice isn't necessary.
140-
raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
141-
static_cast<int64_t>(0),
142-
static_cast<int64_t>(inputsA.n),
143-
static_cast<int64_t>(n_neighbors)};
144-
raft::matrix::slice<float, int64_t, raft::row_major>(
145-
handle,
146-
raft::make_const_mdspan(graph.distances().value()),
147-
raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors),
148-
coords);
149139
}
150140
}
151141

python/cuml/cuml/manifold/umap.pyx

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,42 @@ class UMAP(UniversalBase,
273273
'nn_descent']. 'auto' chooses to run with brute force knn if number of data rows is
274274
smaller than or equal to 50K. Otherwise, runs with nn descent.
275275
build_kwds: dict (optional, default=None)
276-
Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128,
277-
'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True,
278-
'nnd_n_clusters': 1}
279-
Note that nnd_n_clusters > 1 will result in batch-building with NN Descent.
276+
Dictionary of parameters to configure the build algorithm. Default values:
277+
278+
- `nnd_graph_degree` (int, default=64): Graph degree used for NN Descent.
279+
Must be ≥ `n_neighbors`.
280+
281+
- `nnd_intermediate_graph_degree` (int, default=128): Intermediate graph degree for NN Descent.
282+
Must be > `nnd_graph_degree`.
283+
284+
- `nnd_max_iterations` (int, default=20): Max NN Descent iterations.
285+
286+
- `nnd_termination_threshold` (float, default=0.0001): Stricter threshold leads to better convergence
287+
but longer runtime.
288+
289+
- `nnd_n_clusters` (int, default=1): Number of clusters for data partitioning.
290+
Higher values reduce memory usage at the cost of accuracy. When `nnd_n_clusters > 1`, data must be on host memory.
291+
Refer to data_on_host argument for fit_transform function.
292+
293+
- `nnd_overlap_factor` (int, default=2): Number of clusters each data point belongs to.
294+
Valid only when `nnd_n_clusters > 1`. Must be < 'nnd_n_clusters'.
295+
296+
Hints:
297+
298+
- Increasing `nnd_graph_degree` and `nnd_max_iterations` may improve accuracy.
299+
300+
- The ratio `nnd_overlap_factor / nnd_n_clusters` impacts memory usage.
301+
Approximately `(nnd_overlap_factor / nnd_n_clusters) * num_rows_in_entire_data` rows
302+
will be loaded onto device memory at once. E.g., 2/20 uses less device memory than 2/10.
303+
304+
- Larger `nnd_overlap_factor` results in better accuracy of the final knn graph.
305+
E.g. While using similar amount of device memory, `(nnd_overlap_factor / nnd_n_clusters)` = 4/20 will have better accuracy
306+
than 2/10 at the cost of performance.
307+
308+
- Start with `nnd_overlap_factor = 2` and gradually increase (2->3->4 ...) for better accuracy.
309+
310+
- Start with `nnd_n_clusters = 4` and increase (4 → 8 → 16...) for less GPU memory usage.
311+
This is independent from nnd_overlap_factor as long as 'nnd_overlap_factor' < 'nnd_n_clusters'.
280312
281313
Notes
282314
-----
@@ -478,18 +510,30 @@ class UMAP(UniversalBase,
478510

479511
if self.build_algo == "brute_force_knn":
480512
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
481-
else:
482-
umap_params.build_algo = graph_build_algo.NN_DESCENT
513+
elif self.build_algo == "nn_descent":
483514
build_kwds = self.build_kwds or {}
484-
umap_params.nn_descent_params.graph_degree = <uint64_t> build_kwds.get("nnd_graph_degree", 64)
485-
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128)
486-
umap_params.nn_descent_params.max_iterations = <uint64_t> build_kwds.get("nnd_max_iterations", 20)
487-
umap_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001)
488-
umap_params.nn_descent_params.return_distances = <bool> build_kwds.get("nnd_return_distances", True)
489-
umap_params.nn_descent_params.n_clusters = <uint64_t> build_kwds.get("nnd_n_clusters", 1)
490-
# Forward metric & metric_kwds to nn_descent
491-
umap_params.nn_descent_params.metric = <DistanceType> umap_params.metric
492-
umap_params.nn_descent_params.metric_arg = umap_params.p
515+
umap_params.build_params.n_clusters = <uint64_t> build_kwds.get("nnd_n_clusters", 1)
516+
umap_params.build_params.overlap_factor = <uint64_t> build_kwds.get("nnd_overlap_factor", 2)
517+
if umap_params.build_params.n_clusters > 1 and umap_params.build_params.overlap_factor >= umap_params.build_params.n_clusters:
518+
raise ValueError("If nnd_n_clusters > 1, then nnd_overlap_factor must be strictly smaller than n_clusters.")
519+
if umap_params.build_params.n_clusters < 1:
520+
raise ValueError("nnd_n_clusters must be >= 1")
521+
umap_params.build_algo = graph_build_algo.NN_DESCENT
522+
523+
umap_params.build_params.nn_descent_params.graph_degree = <uint64_t> build_kwds.get("nnd_graph_degree", 64)
524+
umap_params.build_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128)
525+
umap_params.build_params.nn_descent_params.max_iterations = <uint64_t> build_kwds.get("nnd_max_iterations", 20)
526+
umap_params.build_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001)
527+
528+
if umap_params.build_params.nn_descent_params.graph_degree < self.n_neighbors:
529+
logger.warn("to use nn descent as the build algo, nnd_graph_degree should be larger than or equal to n_neigbors. setting nnd_graph_degree to n_neighbors.")
530+
umap_params.build_params.nn_descent_params.graph_degree = self.n_neighbors
531+
if umap_params.build_params.nn_descent_params.intermediate_graph_degree < umap_params.build_params.nn_descent_params.graph_degree:
532+
logger.warn("to use nn descent as the build algo, nnd_intermediate_graph_degree should be larger than or equal to nnd_graph_degree. \
533+
setting nnd_intermediate_graph_degree to nnd_graph_degree")
534+
umap_params.build_params.nn_descent_params.intermediate_graph_degree = umap_params.build_params.nn_descent_params.graph_degree
535+
else:
536+
raise ValueError(f"Unsupported value for `build_algo`: {self.build_algo}")
493537

494538
cdef uintptr_t callback_ptr = 0
495539
if self.callback:
@@ -528,6 +572,10 @@ class UMAP(UniversalBase,
528572
and also allows the use of a custom distance function. This function
529573
should match the metric used to train the UMAP embeedings.
530574
Takes precedence over the precomputed_knn parameter.
575+
576+
.. deprecated:: 25.06
577+
Using `nnd_n_clusters>1` with data on device is deprecated in version 25.06
578+
and will be removed in 25.08. Set `data_on_host=True` when `nnd_n_clusters>1`."
531579
"""
532580
if len(X.shape) != 2:
533581
raise ValueError("data should be two dimensional")
@@ -554,7 +602,16 @@ class UMAP(UniversalBase,
554602
if data_on_host:
555603
convert_to_mem_type = MemoryType.host
556604
else:
557-
convert_to_mem_type = MemoryType.device
605+
build_kwds = self.build_kwds or {}
606+
if build_kwds.get("nnd_n_clusters", 1) > 1:
607+
warnings.warn(
608+
("Using nnd_n_clusters>1 with data on device is deprecated in version 25.06"
609+
" and will be removed in 25.08. Set data_on_host=True when nnd_n_clusters>1."),
610+
FutureWarning,
611+
)
612+
convert_to_mem_type = MemoryType.host
613+
else:
614+
convert_to_mem_type = MemoryType.device
558615

559616
self._raw_data, self.n_rows, self.n_dims, _ = \
560617
input_to_cuml_array(X, order='C', check_dtype=np.float32,
@@ -707,6 +764,9 @@ class UMAP(UniversalBase,
707764
Acceptable formats: sparse SciPy ndarray, CuPy device ndarray,
708765
CSR/COO preferred other formats will go through conversion to CSR
709766

767+
.. deprecated:: 25.06
768+
Using `nnd_n_clusters>1` with data on device is deprecated in version 25.06
769+
and will be removed in 25.08. Set `data_on_host=True` when `nnd_n_clusters>1`."
710770
"""
711771
self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph, data_on_host=data_on_host)
712772

python/cuml/cuml/manifold/umap_utils.pxd

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,17 @@ cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals":
3939

4040
cdef cppclass GraphBasedDimRedCallback
4141

42-
43-
cdef extern from "cuvs/neighbors/nn_descent.hpp" namespace "cuvs::neighbors::nn_descent" nogil:
44-
cdef struct index_params:
45-
uint64_t graph_degree,
46-
uint64_t intermediate_graph_degree,
47-
uint64_t max_iterations,
48-
float termination_threshold,
49-
bool return_distances,
50-
uint64_t n_clusters,
51-
DistanceType metric,
52-
float metric_arg
42+
cdef extern from "cuml/manifold/umapparams.h" namespace "graph_build_params" nogil:
43+
cdef cppclass nn_descent_params_umap:
44+
size_t graph_degree
45+
size_t intermediate_graph_degree
46+
size_t max_iterations
47+
float termination_threshold
48+
49+
cdef cppclass graph_build_params:
50+
size_t overlap_factor
51+
size_t n_clusters
52+
nn_descent_params_umap nn_descent_params
5353

5454
cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:
5555

@@ -71,6 +71,7 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:
7171
float initial_alpha,
7272
int init,
7373
graph_build_algo build_algo,
74+
graph_build_params build_params,
7475
int target_n_neighbors,
7576
MetricType target_metric,
7677
float target_weight,
@@ -79,7 +80,6 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:
7980
DistanceType metric,
8081
float p,
8182
GraphBasedDimRedCallback * callback,
82-
index_params nn_descent_params
8383

8484
cdef extern from "raft/sparse/coo.hpp" nogil:
8585
cdef cppclass COO "raft::sparse::COO<float, int>":

0 commit comments

Comments
 (0)