Multi-gpu KNN build for UMAP using all-neighbors API #6654

Merged: 41 commits merged from umap-use-all-neighbors into branch-25.06 on May 29, 2025
Commits (41 total; changes shown from 38 commits)
39353cf
using all-neighbors from umap
jinsolp May 8, 2025
f7f580c
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 8, 2025
79cb4f3
pytest
jinsolp May 8, 2025
888dd7a
Merge branch 'umap-use-all-neighbors' of https://github.com/jinsolp/c…
jinsolp May 8, 2025
79960e1
addressing review
jinsolp May 12, 2025
34ea7e5
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 12, 2025
c627328
doc
jinsolp May 13, 2025
6a9f4ec
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 13, 2025
71a0bdf
comments
jinsolp May 14, 2025
f8ff1bb
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 14, 2025
3c9d340
Merge branch 'rapidsai:branch-25.06' into umap-use-all-neighbors
jinsolp May 14, 2025
1dae103
throwing error for n_clusters < 1
jinsolp May 14, 2025
c277f0f
comments
jinsolp May 14, 2025
9f6209c
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 14, 2025
1c5d238
tests
jinsolp May 15, 2025
1d7dccb
translate old params
jinsolp May 15, 2025
460c07a
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 15, 2025
0d1eb99
pin branch for testing
jinsolp May 15, 2025
9ffee27
deprecation warning and translation
jinsolp May 16, 2025
3d56d0b
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 16, 2025
b475e96
diff name declaration
jinsolp May 16, 2025
eb663d9
Merge branch 'umap-use-all-neighbors' of https://github.com/jinsolp/c…
jinsolp May 16, 2025
617233f
rev
jinsolp May 16, 2025
19b7e02
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 16, 2025
a002c3f
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 16, 2025
187cc73
back to non-breaking
jinsolp May 21, 2025
7417dfa
Merge branch 'umap-use-all-neighbors' of https://github.com/jinsolp/c…
jinsolp May 21, 2025
e2eb185
test kwds
jinsolp May 21, 2025
6a46378
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 21, 2025
006d8c6
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 23, 2025
e5dc732
revert cmake
jinsolp May 27, 2025
2ead5e7
revert cmake
jinsolp May 27, 2025
fcf45b2
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 27, 2025
e2b6b22
docs
jinsolp May 27, 2025
b8282b3
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 27, 2025
160708f
empty commit
jinsolp May 27, 2025
e05ffc1
Merge branch 'umap-use-all-neighbors' of https://github.com/jinsolp/c…
jinsolp May 27, 2025
ba2e4d9
change nearest_clusters->overlap_factor
jinsolp May 28, 2025
8195cab
deprecation putting data on host
jinsolp May 28, 2025
6d8662f
newline
jinsolp May 28, 2025
839e0ad
fix wording
jinsolp May 28, 2025
52 changes: 50 additions & 2 deletions cpp/include/cuml/manifold/umapparams.h
@@ -24,7 +24,55 @@

namespace ML {

using nn_index_params = cuvs::neighbors::nn_descent::index_params;
namespace graph_build_params {

/**
* Arguments for using nn descent as the knn build algorithm.
* graph_degree must be larger than or equal to n_neighbors.
* Increasing graph_degree and max_iterations may result in better accuracy.
* A smaller termination_threshold means a stricter convergence criterion for nn descent and may
* take longer to converge.
*/
struct nn_descent_params_umap {
// not directly using cuvs::neighbors::nn_descent::index_params to distinguish UMAP-exposed NN
// Descent parameters
size_t graph_degree = 64;
size_t intermediate_graph_degree = 128;
size_t max_iterations = 20;
float termination_threshold = 0.0001;
};

/**
* Parameters for knn graph building in UMAP.
* [Hint1]: the ratio overlap_factor / n_clusters determines device memory usage.
* Approximately (overlap_factor / n_clusters) * num_rows_in_entire_data rows will be
* put on device memory at once. E.g. between (overlap_factor / n_clusters) = 2/10 and 2/20, the
* latter will use less device memory.
* [Hint2]: a larger overlap_factor results in better accuracy
* of the final all-neighbors knn graph. E.g. while using a similar amount of device memory,
* (overlap_factor / n_clusters) = 4/20 will have better accuracy than 2/10 at the cost of
* performance.
* [Hint3]: for overlap_factor, start with 2 and gradually increase (2->3->4 ...)
* for better accuracy.
* [Hint4]: for n_clusters, start with 4 and gradually increase (4->8->16 ...)
* for less GPU memory usage. This is independent of overlap_factor as long as
* overlap_factor < n_clusters.
*/
struct graph_build_params {
/**
* Number of clusters each data point is assigned to. Only valid when n_clusters > 1.
*/
size_t overlap_factor = 2;
/**
* Number of clusters to split the data into when building the knn graph. Increasing this will use
* less device memory at the cost of accuracy. When using n_clusters > 1, it is required that the
* data is put on host memory (refer to the data_on_host argument of fit_transform). The default
* value (n_clusters=1) will place the entire data on device memory.
*/
size_t n_clusters = 1;
nn_descent_params_umap nn_descent_params;
};
} // namespace graph_build_params

class UMAPParams {
public:
@@ -149,7 +197,7 @@ class UMAPParams {
*/
graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN;

nn_index_params nn_descent_params = {};
graph_build_params::graph_build_params build_params;

/**
* The number of nearest neighbors to use to construct the target simplicial
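The [Hint1]-[Hint4] guidance above amounts to a simple back-of-the-envelope estimate. A minimal Python sketch of that arithmetic (illustrative only; the dataset size below is an assumption, not something taken from this PR):

def approx_rows_on_device(total_rows, overlap_factor=2, n_clusters=1):
    # Per the hints in umapparams.h: roughly (overlap_factor / n_clusters) * total_rows
    # rows are resident on device at once; n_clusters == 1 keeps the whole dataset on device.
    if n_clusters <= 1:
        return total_rows
    assert overlap_factor < n_clusters, "overlap_factor must be < n_clusters"
    return int(total_rows * overlap_factor / n_clusters)

total = 10_000_000  # hypothetical number of rows
print(approx_rows_on_device(total, 2, 10))  # ~2,000,000 rows on device
print(approx_rows_on_device(total, 2, 20))  # ~1,000,000 rows: less device memory
print(approx_rows_on_device(total, 4, 20))  # ~2,000,000 rows: better accuracy than 2/10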
112 changes: 51 additions & 61 deletions cpp/src/umap/knn_graph/algo.cuh
@@ -33,8 +33,8 @@
#include <raft/util/cudart_utils.hpp>

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/all_neighbors.hpp>
#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/neighbors/nn_descent.hpp>
#include <stdint.h>

#include <iostream>
@@ -56,23 +56,6 @@ void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream);

auto get_graph_nnd(const raft::handle_t& handle,
const ML::manifold_dense_inputs_t<float>& inputs,
const ML::UMAPParams* params)
{
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputs.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
if (ptr != nullptr) {
auto dataset =
raft::make_device_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
return cuvs::neighbors::nn_descent::build(handle, params->nn_descent_params, dataset);
} else {
auto dataset = raft::make_host_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
return cuvs::neighbors::nn_descent::build(handle, params->nn_descent_params, dataset);
}
}

// Instantiation for dense inputs, int64_t indices
template <>
inline void launcher(const raft::handle_t& handle,
@@ -83,12 +66,14 @@ inline void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream)
{
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputsA.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
bool data_on_device = ptr != nullptr;

if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) {
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputsA.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
auto idx = [&]() {
if (ptr != nullptr) { // inputsA on device
auto idx = [&]() {
if (data_on_device) { // inputsA on device
return cuvs::neighbors::brute_force::build(
handle,
{params->metric, params->p},
@@ -107,45 +92,50 @@ inline void launcher(const raft::handle_t& handle,
raft::make_device_matrix_view<int64_t, int64_t>(out.knn_indices, inputsB.n, n_neighbors),
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsB.n, n_neighbors));
} else { // nn_descent
// TODO: use nndescent from cuvs
RAFT_EXPECTS(static_cast<size_t>(n_neighbors) <= params->nn_descent_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");
RAFT_EXPECTS(params->nn_descent_params.return_distances,
"return_distances for nn descent should be set to true to be used for UMAP");

auto graph = get_graph_nnd(handle, inputsA, params);

// `graph.graph()` is a host array (n x graph_degree).
// Slice and copy to a temporary host array (n x n_neighbors), then copy
// that to the output device array `out.knn_indices` (n x n_neighbors).
// TODO: force graph_degree = n_neighbors so the temporary host array and
// slice isn't necessary.
auto temp_indices_h = raft::make_host_matrix<int64_t, int64_t>(inputsA.n, n_neighbors);
size_t graph_degree = params->nn_descent_params.graph_degree;
#pragma omp parallel for
for (size_t i = 0; i < static_cast<size_t>(inputsA.n); i++) {
for (int j = 0; j < n_neighbors; j++) {
auto target = temp_indices_h.data_handle();
auto source = graph.graph().data_handle();
target[i * n_neighbors + j] = source[i * graph_degree + j];
}
RAFT_EXPECTS(
static_cast<size_t>(n_neighbors) <= params->build_params.nn_descent_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");
RAFT_EXPECTS(
params->build_params.nn_descent_params.graph_degree <=
params->build_params.nn_descent_params.intermediate_graph_degree,
"graph_degree should be smaller than intermediate_graph_degree computed by nn descent");

auto all_neighbors_params = cuvs::neighbors::all_neighbors::all_neighbors_params{};
all_neighbors_params.overlap_factor = params->build_params.overlap_factor;
all_neighbors_params.n_clusters = params->build_params.n_clusters;
all_neighbors_params.metric = params->metric;

auto nn_descent_params =
cuvs::neighbors::all_neighbors::graph_build_params::nn_descent_params{};
nn_descent_params.graph_degree = params->build_params.nn_descent_params.graph_degree;
nn_descent_params.intermediate_graph_degree =
params->build_params.nn_descent_params.intermediate_graph_degree;
nn_descent_params.max_iterations = params->build_params.nn_descent_params.max_iterations;
nn_descent_params.termination_threshold =
params->build_params.nn_descent_params.termination_threshold;
nn_descent_params.metric = params->metric;
all_neighbors_params.graph_build_params = nn_descent_params;

auto indices_view =
raft::make_device_matrix_view<int64_t, int64_t>(out.knn_indices, inputsB.n, n_neighbors);
auto distances_view =
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsB.n, n_neighbors);

if (data_on_device) { // inputsA on device
cuvs::neighbors::all_neighbors::build(
handle,
all_neighbors_params,
raft::make_device_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d),
indices_view,
distances_view);
} else { // inputsA on host
cuvs::neighbors::all_neighbors::build(
handle,
all_neighbors_params,
raft::make_host_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d),
indices_view,
distances_view);
}
raft::copy(handle,
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors),
temp_indices_h.view());

// `graph.distances()` is a device array (n x graph_degree).
// Slice and copy to the output device array `out.knn_dists` (n x n_neighbors).
// TODO: force graph_degree = n_neighbors so this slice isn't necessary.
raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
static_cast<int64_t>(0),
static_cast<int64_t>(inputsA.n),
static_cast<int64_t>(n_neighbors)};
raft::matrix::slice<float, int64_t, raft::row_major>(
handle,
raft::make_const_mdspan(graph.distances().value()),
raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors),
coords);
}
}

74 changes: 59 additions & 15 deletions python/cuml/cuml/manifold/umap.pyx
@@ -273,10 +273,42 @@ class UMAP(UniversalBase,
'nn_descent']. 'auto' chooses to run with brute force knn if the number of data rows is
smaller than or equal to 50K. Otherwise, runs with nn descent.
build_kwds: dict (optional, default=None)
Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128,
'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True,
'nnd_n_clusters': 1}
Note that nnd_n_clusters > 1 will result in batch-building with NN Descent.
Dictionary of parameters to configure the build algorithm. Default values:

- `nnd_graph_degree` (int, default=64): Graph degree used for NN Descent.
Must be ≥ `n_neighbors`.

- `nnd_intermediate_graph_degree` (int, default=128): Intermediate graph degree for NN Descent.
Must be >= `nnd_graph_degree`.

- `nnd_max_iterations` (int, default=20): Max NN Descent iterations.

- `nnd_termination_threshold` (float, default=0.0001): Stricter threshold leads to better convergence
but longer runtime.

- `nnd_n_clusters` (int, default=1): Number of clusters for data partitioning.
Higher values reduce memory usage at the cost of accuracy. When `nnd_n_clusters > 1`, data must be on host memory.
Refer to data_on_host argument for fit_transform function.

- `nnd_overlap_factor` (int, default=2): Number of clusters each data point belongs to.
Valid only when `nnd_n_clusters > 1`. Must be < `nnd_n_clusters`.

Hints:

- Increasing `nnd_graph_degree` and `nnd_max_iterations` may improve accuracy.

- The ratio `nnd_overlap_factor / nnd_n_clusters` impacts memory usage.
Approximately `(nnd_overlap_factor / nnd_n_clusters) * num_rows_in_entire_data` rows
will be loaded onto device memory at once. E.g., 2/20 uses less device memory than 2/10.

- A larger `nnd_overlap_factor` results in better accuracy of the final knn graph.
E.g., while using a similar amount of device memory, `(nnd_overlap_factor / nnd_n_clusters)` = 4/20 will have better accuracy
than 2/10 at the cost of performance.

- Start with `nnd_overlap_factor = 2` and gradually increase (2->3->4 ...) for better accuracy.

- Start with `nnd_n_clusters = 4` and increase (4 -> 8 -> 16 ...) for less GPU memory usage.
This is independent of `nnd_overlap_factor` as long as `nnd_overlap_factor` < `nnd_n_clusters`.

Notes
-----
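For reference, a minimal usage sketch of the build_kwds documented above (the dataset and parameter values are placeholders chosen for illustration; data_on_host=True is used because nnd_n_clusters > 1):

import numpy as np
from cuml.manifold import UMAP

X = np.random.random((100_000, 64)).astype(np.float32)  # placeholder data

reducer = UMAP(
    n_neighbors=16,
    build_algo="nn_descent",
    build_kwds={
        "nnd_graph_degree": 64,              # must be >= n_neighbors
        "nnd_intermediate_graph_degree": 128,
        "nnd_max_iterations": 20,
        "nnd_termination_threshold": 0.0001,
        "nnd_n_clusters": 4,                 # batch the knn build to reduce device memory
        "nnd_overlap_factor": 2,             # must be < nnd_n_clusters
    },
)

# With nnd_n_clusters > 1 the input must stay in host memory.
embedding = reducer.fit_transform(X, data_on_host=True)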
@@ -478,18 +510,30 @@ class UMAP(UniversalBase,

if self.build_algo == "brute_force_knn":
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
else:
umap_params.build_algo = graph_build_algo.NN_DESCENT
elif self.build_algo == "nn_descent":
build_kwds = self.build_kwds or {}
umap_params.nn_descent_params.graph_degree = <uint64_t> build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> build_kwds.get("nnd_return_distances", True)
umap_params.nn_descent_params.n_clusters = <uint64_t> build_kwds.get("nnd_n_clusters", 1)
# Forward metric & metric_kwds to nn_descent
umap_params.nn_descent_params.metric = <DistanceType> umap_params.metric
umap_params.nn_descent_params.metric_arg = umap_params.p
umap_params.build_params.n_clusters = <uint64_t> build_kwds.get("nnd_n_clusters", 1)
umap_params.build_params.overlap_factor = <uint64_t> build_kwds.get("nnd_overlap_factor", 2)
if umap_params.build_params.n_clusters > 1 and umap_params.build_params.overlap_factor >= umap_params.build_params.n_clusters:
raise ValueError("If nnd_n_clusters > 1, then nnd_overlap_factor must be strictly smaller than nnd_n_clusters.")
if umap_params.build_params.n_clusters < 1:
raise ValueError("nnd_n_clusters must be >= 1")
umap_params.build_algo = graph_build_algo.NN_DESCENT

umap_params.build_params.nn_descent_params.graph_degree = <uint64_t> build_kwds.get("nnd_graph_degree", 64)
umap_params.build_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.build_params.nn_descent_params.max_iterations = <uint64_t> build_kwds.get("nnd_max_iterations", 20)
umap_params.build_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001)

if umap_params.build_params.nn_descent_params.graph_degree < self.n_neighbors:
logger.warn("to use nn descent as the build algo, nnd_graph_degree should be larger than or equal to n_neighbors. setting nnd_graph_degree to n_neighbors.")
umap_params.build_params.nn_descent_params.graph_degree = self.n_neighbors
if umap_params.build_params.nn_descent_params.intermediate_graph_degree < umap_params.build_params.nn_descent_params.graph_degree:
logger.warn("to use nn descent as the build algo, nnd_intermediate_graph_degree should be larger than or equal to nnd_graph_degree. \
setting nnd_intermediate_graph_degree to nnd_graph_degree")
umap_params.build_params.nn_descent_params.intermediate_graph_degree = umap_params.build_params.nn_descent_params.graph_degree
else:
raise ValueError(f"Unsupported value for `build_algo`: {self.build_algo}")

cdef uintptr_t callback_ptr = 0
if self.callback:
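The parameter translation and validation above imply the following user-visible behavior. A sketch of the expected outcomes, reusing the hypothetical X from the earlier example (exact error and warning wording, and the point at which they surface, may differ):

from cuml.manifold import UMAP

# Rejected: nnd_overlap_factor must be strictly smaller than nnd_n_clusters.
UMAP(build_algo="nn_descent",
     build_kwds={"nnd_n_clusters": 4, "nnd_overlap_factor": 4}).fit(X)  # raises ValueError

# Rejected: nnd_n_clusters must be >= 1.
UMAP(build_algo="nn_descent",
     build_kwds={"nnd_n_clusters": 0}).fit(X)                           # raises ValueError

# Accepted with a warning: nnd_graph_degree is raised to n_neighbors.
UMAP(n_neighbors=32, build_algo="nn_descent",
     build_kwds={"nnd_graph_degree": 16}).fit(X)                        # warns, then uses graph_degree=32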
24 changes: 12 additions & 12 deletions python/cuml/cuml/manifold/umap_utils.pxd
@@ -39,17 +39,17 @@ cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals":

cdef cppclass GraphBasedDimRedCallback


cdef extern from "cuvs/neighbors/nn_descent.hpp" namespace "cuvs::neighbors::nn_descent" nogil:
cdef struct index_params:
uint64_t graph_degree,
uint64_t intermediate_graph_degree,
uint64_t max_iterations,
float termination_threshold,
bool return_distances,
uint64_t n_clusters,
DistanceType metric,
float metric_arg
cdef extern from "cuml/manifold/umapparams.h" namespace "ML::graph_build_params" nogil:
cdef cppclass nn_descent_params_umap:
size_t graph_degree
size_t intermediate_graph_degree
size_t max_iterations
float termination_threshold

cdef cppclass graph_build_params:
size_t overlap_factor
size_t n_clusters
nn_descent_params_umap nn_descent_params

cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:

Expand All @@ -71,6 +71,7 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:
float initial_alpha,
int init,
graph_build_algo build_algo,
graph_build_params build_params,
int target_n_neighbors,
MetricType target_metric,
float target_weight,
@@ -79,7 +80,6 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:
DistanceType metric,
float p,
GraphBasedDimRedCallback * callback,
index_params nn_descent_params

cdef extern from "raft/sparse/coo.hpp" nogil:
cdef cppclass COO "raft::sparse::COO<float, int>":