Skip to content

Multi-gpu KNN build for UMAP using all-neighbors API #6654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
39353cf
using all-neighbors from umap
jinsolp May 8, 2025
f7f580c
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 8, 2025
79cb4f3
pytest
jinsolp May 8, 2025
888dd7a
Merge branch 'umap-use-all-neighbors' of https://github.com/jinsolp/c…
jinsolp May 8, 2025
79960e1
addressing review
jinsolp May 12, 2025
34ea7e5
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 12, 2025
c627328
doc
jinsolp May 13, 2025
6a9f4ec
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 13, 2025
71a0bdf
comments
jinsolp May 14, 2025
f8ff1bb
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 14, 2025
3c9d340
Merge branch 'rapidsai:branch-25.06' into umap-use-all-neighbors
jinsolp May 14, 2025
1dae103
throwing error for n_clusters < 1
jinsolp May 14, 2025
c277f0f
comments
jinsolp May 14, 2025
9f6209c
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 14, 2025
1c5d238
tests
jinsolp May 15, 2025
1d7dccb
translate old params
jinsolp May 15, 2025
460c07a
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 15, 2025
0d1eb99
pin branch for testing
jinsolp May 15, 2025
9ffee27
deprecation warning and translation
jinsolp May 16, 2025
3d56d0b
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 16, 2025
b475e96
diff name declaration
jinsolp May 16, 2025
eb663d9
Merge branch 'umap-use-all-neighbors' of https://github.com/jinsolp/c…
jinsolp May 16, 2025
617233f
rev
jinsolp May 16, 2025
19b7e02
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 16, 2025
a002c3f
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 16, 2025
187cc73
back to non-breaking
jinsolp May 21, 2025
7417dfa
Merge branch 'umap-use-all-neighbors' of https://github.com/jinsolp/c…
jinsolp May 21, 2025
e2eb185
test kwds
jinsolp May 21, 2025
6a46378
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 21, 2025
006d8c6
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 23, 2025
e5dc732
revert cmake
jinsolp May 27, 2025
2ead5e7
revert cmake
jinsolp May 27, 2025
fcf45b2
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 27, 2025
e2b6b22
docs
jinsolp May 27, 2025
b8282b3
Merge branch 'branch-25.06' into umap-use-all-neighbors
jinsolp May 27, 2025
160708f
empty commit
jinsolp May 27, 2025
e05ffc1
Merge branch 'umap-use-all-neighbors' of https://github.com/jinsolp/c…
jinsolp May 27, 2025
ba2e4d9
change nearest_clusters->overlap_factor
jinsolp May 28, 2025
8195cab
deprecation putting data on host
jinsolp May 28, 2025
6d8662f
newline
jinsolp May 28, 2025
839e0ad
fix wording
jinsolp May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/cmake/thirdparty/get_cuvs.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ function(find_and_configure_cuvs)
BUILD_EXPORT_SET cuml-exports
INSTALL_EXPORT_SET cuml-exports
CPM_ARGS
GIT_REPOSITORY https://github.com/${PKG_FORK}/cuvs.git
GIT_TAG ${PKG_PINNED_TAG}
GIT_REPOSITORY https://github.com/jinsolp/cuvs.git
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert theses

GIT_TAG snmg-batching
SOURCE_SUBDIR cpp
EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL}
OPTIONS
Expand All @@ -73,8 +73,8 @@ endfunction()
# To use a different CUVS locally, set the CMake variable
# CPM_cuvs_SOURCE=/path/to/local/cuvs
find_and_configure_cuvs(VERSION ${CUML_MIN_VERSION_cuvs}
FORK rapidsai
PINNED_TAG branch-${CUML_BRANCH_VERSION_cuvs}
FORK jinsolp
PINNED_TAG snmg-batching
EXCLUDE_FROM_ALL ${CUML_EXCLUDE_CUVS_FROM_ALL}
# When PINNED_TAG above doesn't match cuml,
# force local cuvs clone in build directory
Expand Down
53 changes: 51 additions & 2 deletions cpp/include/cuml/manifold/umapparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,56 @@

namespace ML {

using nn_index_params = cuvs::neighbors::nn_descent::index_params;
namespace graph_build_params {

/**
 * Arguments for using NN Descent as the knn build algorithm.
 * graph_degree must be larger than or equal to n_neighbors.
 * Increasing graph_degree and max_iterations may result in better accuracy at
 * the cost of build time and memory.
 */
struct nn_descent_params_umap {
// Deliberately not an alias of cuvs::neighbors::nn_descent::index_params, so that
// UMAP controls exactly which NN Descent parameters it exposes publicly.
size_t graph_degree = 64;  // per-row degree of the knn graph produced by NN Descent
size_t max_iterations = 20;  // number of NN Descent refinement iterations

// These are deprecated in version 25.06 and will no longer be exposed starting 25.08
// related issue: https://github.com/rapidsai/cuml/issues/6742
size_t intermediate_graph_degree = 128;  // working graph degree during construction
float termination_threshold = 0.0001;  // NOTE(review): presumably an early-stop tolerance — see cuvs nn_descent::index_params
};

/**
 * Parameters for knn graph building in UMAP.
 * [Hint1]: the ratio of n_nearest_clusters / n_clusters determines device memory usage.
 * Approximately (n_nearest_clusters / n_clusters) * num_rows_in_entire_data number of rows will be
 * put on device memory at once. E.g. between (n_nearest_clusters / n_clusters) = 2/10 and 2/20, the
 * latter will use less device memory.
 * [Hint2]: larger n_nearest_clusters results in better accuracy
 * of the final all-neighbors knn graph. E.g. While using similar amount of device memory,
 * (n_nearest_clusters / n_clusters) = 4/20 will have better accuracy than 2/10 at the cost of
 * performance.
 * [Hint3]: for n_nearest_clusters, start with 2, and gradually increase (2->3->4 ...)
 * for better accuracy
 * [Hint4]: for n_clusters, start with 4, and gradually increase (4->8->16 ...)
 * for less GPU memory usage. This is independent from n_nearest_clusters as long as
 * n_nearest_clusters < n_clusters
 */
struct graph_build_params {
/**
 * Number of clusters each data point is assigned to. Only valid when n_clusters > 1.
 */
size_t n_nearest_clusters = 2;
/**
 * Number of clusters to split the data into when building the knn graph. Increasing this will use
 * less device memory at the cost of accuracy. When using n_clusters > 1, it is required that the
 * data is put on host (refer to data_on_host argument for fit_transform). The default value
 * (n_clusters=1) will place the entire data on device memory.
 */
size_t n_clusters = 1;
// NN Descent hyperparameters; consumed when build_algo is NN_DESCENT.
nn_descent_params_umap nn_descent_params;
};
} // namespace graph_build_params

class UMAPParams {
public:
Expand Down Expand Up @@ -149,7 +198,7 @@ class UMAPParams {
*/
graph_build_algo build_algo = graph_build_algo::BRUTE_FORCE_KNN;

nn_index_params nn_descent_params = {};
graph_build_params::graph_build_params build_params;

/**
* The number of nearest neighbors to use to construct the target simplicial
Expand Down
113 changes: 52 additions & 61 deletions cpp/src/umap/knn_graph/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
#include <raft/util/cudart_utils.hpp>

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/all_neighbors.hpp>
#include <cuvs/neighbors/brute_force.hpp>
#include <cuvs/neighbors/nn_descent.hpp>
#include <stdint.h>

#include <iostream>
Expand All @@ -56,23 +56,6 @@ void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream);

auto get_graph_nnd(const raft::handle_t& handle,
const ML::manifold_dense_inputs_t<float>& inputs,
const ML::UMAPParams* params)
{
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputs.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
if (ptr != nullptr) {
auto dataset =
raft::make_device_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
return cuvs::neighbors::nn_descent::build(handle, params->nn_descent_params, dataset);
} else {
auto dataset = raft::make_host_matrix_view<const float, int64_t>(inputs.X, inputs.n, inputs.d);
return cuvs::neighbors::nn_descent::build(handle, params->nn_descent_params, dataset);
}
}

// Instantiation for dense inputs, int64_t indices
template <>
inline void launcher(const raft::handle_t& handle,
Expand All @@ -83,12 +66,14 @@ inline void launcher(const raft::handle_t& handle,
const ML::UMAPParams* params,
cudaStream_t stream)
{
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputsA.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
bool data_on_device = ptr != nullptr;

if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) {
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, inputsA.X));
float* ptr = reinterpret_cast<float*>(attr.devicePointer);
auto idx = [&]() {
if (ptr != nullptr) { // inputsA on device
auto idx = [&]() {
if (data_on_device) { // inputsA on device
return cuvs::neighbors::brute_force::build(
handle,
{params->metric, params->p},
Expand All @@ -107,45 +92,51 @@ inline void launcher(const raft::handle_t& handle,
raft::make_device_matrix_view<int64_t, int64_t>(out.knn_indices, inputsB.n, n_neighbors),
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsB.n, n_neighbors));
} else { // nn_descent
// TODO: use nndescent from cuvs
RAFT_EXPECTS(static_cast<size_t>(n_neighbors) <= params->nn_descent_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");
RAFT_EXPECTS(params->nn_descent_params.return_distances,
"return_distances for nn descent should be set to true to be used for UMAP");

auto graph = get_graph_nnd(handle, inputsA, params);

// `graph.graph()` is a host array (n x graph_degree).
// Slice and copy to a temporary host array (n x n_neighbors), then copy
// that to the output device array `out.knn_indices` (n x n_neighbors).
// TODO: force graph_degree = n_neighbors so the temporary host array and
// slice isn't necessary.
auto temp_indices_h = raft::make_host_matrix<int64_t, int64_t>(inputsA.n, n_neighbors);
size_t graph_degree = params->nn_descent_params.graph_degree;
#pragma omp parallel for
for (size_t i = 0; i < static_cast<size_t>(inputsA.n); i++) {
for (int j = 0; j < n_neighbors; j++) {
auto target = temp_indices_h.data_handle();
auto source = graph.graph().data_handle();
target[i * n_neighbors + j] = source[i * graph_degree + j];
}
RAFT_EXPECTS(
static_cast<size_t>(n_neighbors) <= params->build_params.nn_descent_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");

auto all_neighbors_params = cuvs::neighbors::all_neighbors::all_neighbors_params{};
all_neighbors_params.n_nearest_clusters = params->build_params.n_nearest_clusters;
all_neighbors_params.n_clusters = params->build_params.n_clusters;
all_neighbors_params.metric = params->metric;

auto nn_descent_params =
cuvs::neighbors::all_neighbors::graph_build_params::nn_descent_params{};
nn_descent_params.graph_degree = params->build_params.nn_descent_params.graph_degree;
nn_descent_params.max_iterations = params->build_params.nn_descent_params.max_iterations;
nn_descent_params.metric = params->metric;

// nn_descent_params.intermediate_graph_degree = nn_descent_params.graph_degree * 1.5;
// TODO: These are deprecated in version 25.06 and will no longer be exposed starting 25.08
// related issue: github.com/rapidsai/cuml/issues/6742
nn_descent_params.intermediate_graph_degree =
params->build_params.nn_descent_params.intermediate_graph_degree;
nn_descent_params.termination_threshold =
params->build_params.nn_descent_params.termination_threshold;

all_neighbors_params.graph_build_params = nn_descent_params;

auto indices_view =
raft::make_device_matrix_view<int64_t, int64_t>(out.knn_indices, inputsB.n, n_neighbors);
auto distances_view =
raft::make_device_matrix_view<float, int64_t>(out.knn_dists, inputsB.n, n_neighbors);

if (data_on_device) { // inputsA on device
cuvs::neighbors::all_neighbors::build(
handle,
all_neighbors_params,
raft::make_device_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d),
indices_view,
distances_view);
} else { // inputsA on host
cuvs::neighbors::all_neighbors::build(
handle,
all_neighbors_params,
raft::make_host_matrix_view<const float, int64_t>(inputsA.X, inputsA.n, inputsA.d),
indices_view,
distances_view);
}
raft::copy(handle,
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors),
temp_indices_h.view());

// `graph.distances()` is a device array (n x graph_degree).
// Slice and copy to the output device array `out.knn_dists` (n x n_neighbors).
// TODO: force graph_degree = n_neighbors so this slice isn't necessary.
raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
static_cast<int64_t>(0),
static_cast<int64_t>(inputsA.n),
static_cast<int64_t>(n_neighbors)};
raft::matrix::slice<float, int64_t, raft::row_major>(
handle,
raft::make_const_mdspan(graph.distances().value()),
raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors),
coords);
}
}

Expand Down
87 changes: 72 additions & 15 deletions python/cuml/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,20 @@ class UMAP(UniversalBase,
'nn_descent']. 'auto' chooses to run with brute force knn if number of data rows is
smaller than or equal to 50K. Otherwise, runs with nn descent.
build_kwds: dict (optional, default=None)
Build algorithm argument {'nnd_graph_degree': 64, 'nnd_intermediate_graph_degree': 128,
'nnd_max_iterations': 20, 'nnd_termination_threshold': 0.0001, 'nnd_return_distances': True,
'nnd_n_clusters': 1}
Note that nnd_n_clusters > 1 will result in batch-building with NN Descent.
Build algorithm argument. Default values are: {'n_clusters': 1, 'n_nearest_clusters': 2, 'nn_descent': {'graph_degree': n_neighbors, 'max_iterations': 20}}.
"n_clusters": int (default=1). Number of clusters to split the data into when building the knn graph. Increasing this will use less device memory at the cost of accuracy. When using n_clusters > 1, it is required that the data is put on host (refer to data_on_host argument for fit_transform). The default value (n_clusters=1) will place the entire data on device memory.
"n_nearest_clusters": int (default=2). Number of clusters each data point is assigned to. Only valid when n_clusters > 1.
"nn_descent": dict (default={"graph_degree": n_neighbors, "max_iterations": 20}). Arguments for when build_algo="nn_descent". graph_degree should be larger than or equal to n_neighbors. Increasing graph_degree and max_iterations may result in better accuracy.
[Hint1]: the ratio of n_nearest_clusters / n_clusters determines device memory usage. Approximately (n_nearest_clusters / n_clusters) * num_rows_in_entire_data number of rows will be put on device memory at once.
E.g. between (n_nearest_clusters / n_clusters) = 2/10 and 2/20, the latter will use less device memory.
[Hint2]: larger n_nearest_clusters results in better accuracy of the final all-neighbors knn graph.
E.g. While using similar amount of device memory, (n_nearest_clusters / n_clusters) = 4/20 will have better accuracy than 2/10 at the cost of performance.
[Hint3]: for n_nearest_clusters, start with 2, and gradually increase (2->3->4 ...) for better accuracy
[Hint4]: for n_clusters, start with 4, and gradually increase(4->8->16 ...) for less GPU memory usage. This is independent from n_nearest_clusters as long as n_nearest_clusters < n_clusters

.. deprecated:: 25.06
nnd_* arguments in build_kwds were deprecated in version 25.06 and will be
removed in 25.08. Please refer to the documentation for updated configuration guidance.

Notes
-----
Expand Down Expand Up @@ -430,6 +440,42 @@ class UMAP(UniversalBase,

self.build_kwds = build_kwds

self._handle_deprecated_build_kwds()

# TODO: remove deprecation handling logic in 25.08
# related issue: https://github.com/rapidsai/cuml/issues/6742
def _handle_deprecated_build_kwds(self):
    """Translate deprecated ``nnd_*`` entries of ``self.build_kwds`` in place.

    Deprecated in 25.06, removed in 25.08: ``nnd_graph_degree``,
    ``nnd_max_iterations`` and ``nnd_n_clusters`` are mapped onto the new
    layout ``{"n_clusters": ..., "nn_descent": {"graph_degree": ...,
    "max_iterations": ...}}``. ``nnd_return_distances`` is dropped entirely
    (it must always be True for UMAP and should never have been exposed).
    ``nnd_intermediate_graph_degree`` and ``nnd_termination_threshold`` are
    intentionally left in ``build_kwds``; they are consumed directly by the
    parameter-building code until 25.08.

    Raises
    ------
    ValueError
        If deprecated ``nnd_*`` keys and the new ``nn_descent`` dict are
        both present in ``build_kwds``.
    """
    # Nothing to do unless the user passed at least one deprecated key.
    if self.build_kwds is None or not any(
        key.startswith("nnd_") for key in self.build_kwds
    ):
        return

    warnings.warn(
        (
            "nnd_* arguments in build_kwds were deprecated in version 25.06 "
            "and will be removed in 25.08. Please refer to the documentation "
            "for updated configuration guidance."
        ),
        FutureWarning,
    )

    if self.build_kwds.get("nn_descent", None) is not None:
        # Mixing deprecated nnd_* keys with the new "nn_descent" dict is
        # ambiguous; refuse rather than silently preferring one of them.
        raise ValueError(
            "Please remove nnd_* arguments in build_kwds and use the "
            "nn_descent dict parameters."
        )

    # Translate old nnd_* params into the new build_kwds layout.
    graph_degree = self.build_kwds.pop("nnd_graph_degree", None)
    max_iterations = self.build_kwds.pop("nnd_max_iterations", None)
    n_clusters = self.build_kwds.pop("nnd_n_clusters", None)

    # nnd related params
    nn_descent_config = {}
    if graph_degree is not None:
        nn_descent_config["graph_degree"] = graph_degree
    if max_iterations is not None:
        nn_descent_config["max_iterations"] = max_iterations

    if nn_descent_config:
        self.build_kwds["nn_descent"] = nn_descent_config
    if n_clusters is not None:
        self.build_kwds["n_clusters"] = n_clusters

    # nnd_return_distances must be True for NN Descent to work with UMAP and
    # should never have been exposed; drop it silently.
    self.build_kwds.pop("nnd_return_distances", None)

def validate_hyperparams(self):

if self.min_dist > self.spread:
Expand Down Expand Up @@ -478,18 +524,29 @@ class UMAP(UniversalBase,

if self.build_algo == "brute_force_knn":
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
else:
umap_params.build_algo = graph_build_algo.NN_DESCENT
elif self.build_algo == "nn_descent":
build_kwds = self.build_kwds or {}
umap_params.nn_descent_params.graph_degree = <uint64_t> build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> build_kwds.get("nnd_return_distances", True)
umap_params.nn_descent_params.n_clusters = <uint64_t> build_kwds.get("nnd_n_clusters", 1)
# Forward metric & metric_kwds to nn_descent
umap_params.nn_descent_params.metric = <DistanceType> umap_params.metric
umap_params.nn_descent_params.metric_arg = umap_params.p
umap_params.build_params.n_clusters = <uint64_t> build_kwds.get("n_clusters", 1)
umap_params.build_params.n_nearest_clusters = <uint64_t> build_kwds.get("n_nearest_clusters", 2)
if umap_params.build_params.n_clusters > 1 and umap_params.build_params.n_nearest_clusters >= umap_params.build_params.n_clusters:
raise ValueError("If n_clusters > 1, then n_nearest_clusters must be strictly smaller than n_clusters.")
if umap_params.build_params.n_clusters < 1:
raise ValueError("n_clusters must be >= 1")
umap_params.build_algo = graph_build_algo.NN_DESCENT

nnd_build_kwds = build_kwds.get("nn_descent", {})
umap_params.build_params.nn_descent_params.graph_degree = <uint64_t> nnd_build_kwds.get("graph_degree", self.n_neighbors)
umap_params.build_params.nn_descent_params.max_iterations = <uint64_t> nnd_build_kwds.get("max_iterations", 20)
if umap_params.build_params.nn_descent_params.graph_degree < self.n_neighbors:
logger.warn("to use nn descent as the build algo, graph_degree should be larger than or equal to n_neigbors. setting graph_degree to n_neighbors.")
umap_params.build_params.nn_descent_params.graph_degree = self.n_neighbors

# TODO: remove deprecation handling logic in 25.08
# related issue: https://github.com/rapidsai/cuml/issues/6742
umap_params.build_params.nn_descent_params.intermediate_graph_degree = <uint64_t> build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.build_params.nn_descent_params.termination_threshold = <float> build_kwds.get("nnd_termination_threshold", 0.0001)
else:
raise ValueError(f"Unsupported value for `build_algo`: {self.build_algo}")

cdef uintptr_t callback_ptr = 0
if self.callback:
Expand Down
25 changes: 14 additions & 11 deletions python/cuml/cuml/manifold/umap_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,20 @@ cdef extern from "cuml/common/callback.hpp" namespace "ML::Internals":

cdef cppclass GraphBasedDimRedCallback

cdef extern from "cuml/manifold/umapparams.h" namespace "graph_build_params" nogil:
cdef cppclass nn_descent_params_umap:
size_t graph_degree
size_t max_iterations

cdef extern from "cuvs/neighbors/nn_descent.hpp" namespace "cuvs::neighbors::nn_descent" nogil:
cdef struct index_params:
uint64_t graph_degree,
uint64_t intermediate_graph_degree,
uint64_t max_iterations,
float termination_threshold,
bool return_distances,
uint64_t n_clusters,
DistanceType metric,
float metric_arg
# TODO: remove deprecation handling logic in 25.08
# related issue: https://github.com/rapidsai/cuml/issues/6742
size_t intermediate_graph_degree
float termination_threshold

cdef cppclass graph_build_params:
size_t n_nearest_clusters
size_t n_clusters
nn_descent_params_umap nn_descent_params

cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:

Expand All @@ -71,6 +74,7 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:
float initial_alpha,
int init,
graph_build_algo build_algo,
graph_build_params build_params,
int target_n_neighbors,
MetricType target_metric,
float target_weight,
Expand All @@ -79,7 +83,6 @@ cdef extern from "cuml/manifold/umapparams.h" namespace "ML" nogil:
DistanceType metric,
float p,
GraphBasedDimRedCallback * callback,
index_params nn_descent_params

cdef extern from "raft/sparse/coo.hpp" nogil:
cdef cppclass COO "raft::sparse::COO<float, int>":
Expand Down
Loading
Loading