33
33
#include < raft/util/cudart_utils.hpp>
34
34
35
35
#include < cuvs/distance/distance.hpp>
36
+ #include < cuvs/neighbors/all_neighbors.hpp>
36
37
#include < cuvs/neighbors/brute_force.hpp>
37
- #include < cuvs/neighbors/nn_descent.hpp>
38
38
#include < stdint.h>
39
39
40
40
#include < iostream>
@@ -56,23 +56,6 @@ void launcher(const raft::handle_t& handle,
56
56
const ML::UMAPParams* params,
57
57
cudaStream_t stream);
58
58
59
- auto get_graph_nnd (const raft::handle_t & handle,
60
- const ML::manifold_dense_inputs_t <float >& inputs,
61
- const ML::UMAPParams* params)
62
- {
63
- cudaPointerAttributes attr;
64
- RAFT_CUDA_TRY (cudaPointerGetAttributes (&attr, inputs.X ));
65
- float * ptr = reinterpret_cast <float *>(attr.devicePointer );
66
- if (ptr != nullptr ) {
67
- auto dataset =
68
- raft::make_device_matrix_view<const float , int64_t >(inputs.X , inputs.n , inputs.d );
69
- return cuvs::neighbors::nn_descent::build (handle, params->nn_descent_params , dataset);
70
- } else {
71
- auto dataset = raft::make_host_matrix_view<const float , int64_t >(inputs.X , inputs.n , inputs.d );
72
- return cuvs::neighbors::nn_descent::build (handle, params->nn_descent_params , dataset);
73
- }
74
- }
75
-
76
59
// Instantiation for dense inputs, int64_t indices
77
60
template <>
78
61
inline void launcher (const raft::handle_t & handle,
@@ -83,12 +66,14 @@ inline void launcher(const raft::handle_t& handle,
83
66
const ML::UMAPParams* params,
84
67
cudaStream_t stream)
85
68
{
69
+ cudaPointerAttributes attr;
70
+ RAFT_CUDA_TRY (cudaPointerGetAttributes (&attr, inputsA.X ));
71
+ float * ptr = reinterpret_cast <float *>(attr.devicePointer );
72
+ bool data_on_device = ptr != nullptr ;
73
+
86
74
if (params->build_algo == ML::UMAPParams::graph_build_algo::BRUTE_FORCE_KNN) {
87
- cudaPointerAttributes attr;
88
- RAFT_CUDA_TRY (cudaPointerGetAttributes (&attr, inputsA.X ));
89
- float * ptr = reinterpret_cast <float *>(attr.devicePointer );
90
- auto idx = [&]() {
91
- if (ptr != nullptr ) { // inputsA on device
75
+ auto idx = [&]() {
76
+ if (data_on_device) { // inputsA on device
92
77
return cuvs::neighbors::brute_force::build (
93
78
handle,
94
79
{params->metric , params->p },
@@ -107,45 +92,50 @@ inline void launcher(const raft::handle_t& handle,
107
92
raft::make_device_matrix_view<int64_t , int64_t >(out.knn_indices , inputsB.n , n_neighbors),
108
93
raft::make_device_matrix_view<float , int64_t >(out.knn_dists , inputsB.n , n_neighbors));
109
94
} else { // nn_descent
110
- // TODO: use nndescent from cuvs
111
- RAFT_EXPECTS (static_cast <size_t >(n_neighbors) <= params->nn_descent_params .graph_degree ,
112
- " n_neighbors should be smaller than the graph degree computed by nn descent" );
113
- RAFT_EXPECTS (params->nn_descent_params .return_distances ,
114
- " return_distances for nn descent should be set to true to be used for UMAP" );
115
-
116
- auto graph = get_graph_nnd (handle, inputsA, params);
117
-
118
- // `graph.graph()` is a host array (n x graph_degree).
119
- // Slice and copy to a temporary host array (n x n_neighbors), then copy
120
- // that to the output device array `out.knn_indices` (n x n_neighbors).
121
- // TODO: force graph_degree = n_neighbors so the temporary host array and
122
- // slice isn't necessary.
123
- auto temp_indices_h = raft::make_host_matrix<int64_t , int64_t >(inputsA.n , n_neighbors);
124
- size_t graph_degree = params->nn_descent_params .graph_degree ;
125
- #pragma omp parallel for
126
- for (size_t i = 0 ; i < static_cast <size_t >(inputsA.n ); i++) {
127
- for (int j = 0 ; j < n_neighbors; j++) {
128
- auto target = temp_indices_h.data_handle ();
129
- auto source = graph.graph ().data_handle ();
130
- target[i * n_neighbors + j] = source[i * graph_degree + j];
131
- }
95
+ RAFT_EXPECTS (
96
+ static_cast <size_t >(n_neighbors) <= params->build_params .nn_descent_params .graph_degree ,
97
+ " n_neighbors should be smaller than the graph degree computed by nn descent" );
98
+ RAFT_EXPECTS (
99
+ params->build_params .nn_descent_params .graph_degree <=
100
+ params->build_params .nn_descent_params .intermediate_graph_degree ,
101
+ " graph_degree should be smaller than intermediate_graph_degree computed by nn descent" );
102
+
103
+ auto all_neighbors_params = cuvs::neighbors::all_neighbors::all_neighbors_params{};
104
+ all_neighbors_params.overlap_factor = params->build_params .overlap_factor ;
105
+ all_neighbors_params.n_clusters = params->build_params .n_clusters ;
106
+ all_neighbors_params.metric = params->metric ;
107
+
108
+ auto nn_descent_params =
109
+ cuvs::neighbors::all_neighbors::graph_build_params::nn_descent_params{};
110
+ nn_descent_params.graph_degree = params->build_params .nn_descent_params .graph_degree ;
111
+ nn_descent_params.intermediate_graph_degree =
112
+ params->build_params .nn_descent_params .intermediate_graph_degree ;
113
+ nn_descent_params.max_iterations = params->build_params .nn_descent_params .max_iterations ;
114
+ nn_descent_params.termination_threshold =
115
+ params->build_params .nn_descent_params .termination_threshold ;
116
+ nn_descent_params.metric = params->metric ;
117
+ all_neighbors_params.graph_build_params = nn_descent_params;
118
+
119
+ auto indices_view =
120
+ raft::make_device_matrix_view<int64_t , int64_t >(out.knn_indices , inputsB.n , n_neighbors);
121
+ auto distances_view =
122
+ raft::make_device_matrix_view<float , int64_t >(out.knn_dists , inputsB.n , n_neighbors);
123
+
124
+ if (data_on_device) { // inputsA on device
125
+ cuvs::neighbors::all_neighbors::build (
126
+ handle,
127
+ all_neighbors_params,
128
+ raft::make_device_matrix_view<const float , int64_t >(inputsA.X , inputsA.n , inputsA.d ),
129
+ indices_view,
130
+ distances_view);
131
+ } else { // inputsA on host
132
+ cuvs::neighbors::all_neighbors::build (
133
+ handle,
134
+ all_neighbors_params,
135
+ raft::make_host_matrix_view<const float , int64_t >(inputsA.X , inputsA.n , inputsA.d ),
136
+ indices_view,
137
+ distances_view);
132
138
}
133
- raft::copy (handle,
134
- raft::make_device_matrix_view (out.knn_indices , inputsA.n , n_neighbors),
135
- temp_indices_h.view ());
136
-
137
- // `graph.distances()` is a device array (n x graph_degree).
138
- // Slice and copy to the output device array `out.knn_dists` (n x n_neighbors).
139
- // TODO: force graph_degree = n_neighbors so this slice isn't necessary.
140
- raft::matrix::slice_coordinates coords{static_cast <int64_t >(0 ),
141
- static_cast <int64_t >(0 ),
142
- static_cast <int64_t >(inputsA.n ),
143
- static_cast <int64_t >(n_neighbors)};
144
- raft::matrix::slice<float , int64_t , raft::row_major>(
145
- handle,
146
- raft::make_const_mdspan (graph.distances ().value ()),
147
- raft::make_device_matrix_view (out.knn_dists , inputsA.n , n_neighbors),
148
- coords);
149
139
}
150
140
}
151
141
0 commit comments