Skip to content

Commit 8bf7057

Browse files
authored
Lanczos Solver which=SA,SM,LA,LM argument (#2628)
Resolves #2624
Resolves #2483

Authors:
- Anupam (https://github.com/aamijar)
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:
- Micka (https://github.com/lowener)
- Divye Gala (https://github.com/divyegala)

URL: #2628
1 parent b25a7b0 commit 8bf7057

File tree

6 files changed

+588
-50
lines changed

6 files changed

+588
-50
lines changed

cpp/include/raft/sparse/solver/detail/lanczos.cuh

Lines changed: 103 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include <raft/linalg/transpose.cuh>
5151
#include <raft/linalg/unary_op.cuh>
5252
#include <raft/matrix/diagonal.cuh>
53+
#include <raft/matrix/gather.cuh>
5354
#include <raft/matrix/matrix.cuh>
5455
#include <raft/matrix/slice.cuh>
5556
#include <raft/matrix/triangular.cuh>
@@ -63,6 +64,7 @@
6364
#include <raft/util/cudart_utils.hpp>
6465

6566
#include <cuda.h>
67+
#include <thrust/sort.h>
6668

6769
#include <cublasLt.h>
6870
#include <curand.h>
@@ -1507,10 +1509,15 @@ void lanczos_solve_ritz(
15071509
raft::device_matrix_view<ValueTypeT, uint32_t, raft::row_major> beta,
15081510
std::optional<raft::device_vector_view<ValueTypeT, uint32_t>> beta_k,
15091511
IndexTypeT k,
1510-
int which,
1512+
LANCZOS_WHICH which,
15111513
int ncv,
15121514
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors,
1513-
raft::device_vector_view<ValueTypeT> eigenvalues)
1515+
raft::device_vector_view<ValueTypeT> eigenvalues,
1516+
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major>& eigenvectors_k,
1517+
raft::device_vector_view<ValueTypeT, uint32_t>& eigenvalues_k,
1518+
raft::device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>& eigenvectors_k_slice,
1519+
raft::device_vector_view<ValueTypeT> sm_eigenvalues,
1520+
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> sm_eigenvectors)
15141521
{
15151522
auto stream = resource::get_cuda_stream(handle);
15161523

@@ -1543,6 +1550,75 @@ void lanczos_solve_ritz(
15431550
triangular_matrix.data_handle(), ncv, ncv);
15441551

15451552
raft::linalg::eig_dc(handle, triangular_matrix_view, eigenvectors, eigenvalues);
1553+
1554+
IndexTypeT nEigVecs = k;
1555+
1556+
auto indices = raft::make_device_vector<int>(handle, ncv);
1557+
auto selected_indices = raft::make_device_vector<int>(handle, nEigVecs);
1558+
1559+
if (which == LANCZOS_WHICH::SA) {
1560+
eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
1561+
eigenvectors.data_handle(), ncv, nEigVecs);
1562+
eigenvalues_k =
1563+
raft::make_device_vector_view<ValueTypeT, uint32_t>(eigenvalues.data_handle(), nEigVecs);
1564+
eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1565+
eigenvectors.data_handle(), ncv, nEigVecs);
1566+
} else if (which == LANCZOS_WHICH::LA) {
1567+
eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
1568+
eigenvectors.data_handle() + (ncv - nEigVecs) * ncv, ncv, nEigVecs);
1569+
eigenvalues_k = raft::make_device_vector_view<ValueTypeT, uint32_t>(
1570+
eigenvalues.data_handle() + (ncv - nEigVecs), nEigVecs);
1571+
eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1572+
eigenvectors.data_handle() + (ncv - nEigVecs) * ncv, ncv, nEigVecs);
1573+
} else if (which == LANCZOS_WHICH::SM || which == LANCZOS_WHICH::LM) {
1574+
thrust::sequence(thrust::device, indices.data_handle(), indices.data_handle() + ncv, 0);
1575+
1576+
// Sort indices by absolute eigenvalues (magnitude) using a custom comparator
1577+
thrust::sort(thrust::device,
1578+
indices.data_handle(),
1579+
indices.data_handle() + ncv,
1580+
[eigenvalues = eigenvalues.data_handle()] __device__(int a, int b) {
1581+
return fabsf(eigenvalues[a]) < fabsf(eigenvalues[b]);
1582+
});
1583+
1584+
if (which == LANCZOS_WHICH::SM) {
1585+
// Take the first nEigVecs indices (smallest magnitude)
1586+
raft::copy(selected_indices.data_handle(), indices.data_handle(), nEigVecs, stream);
1587+
} else if (which == LANCZOS_WHICH::LM) {
1588+
// Take the last nEigVecs indices (largest magnitude)
1589+
raft::copy(
1590+
selected_indices.data_handle(), indices.data_handle() + (ncv - nEigVecs), nEigVecs, stream);
1591+
}
1592+
1593+
// Re-sort these indices by algebraic value to maintain algebraic ordering
1594+
thrust::sort(thrust::device,
1595+
selected_indices.data_handle(),
1596+
selected_indices.data_handle() + nEigVecs,
1597+
[eigenvalues = eigenvalues.data_handle()] __device__(int a, int b) {
1598+
return eigenvalues[a] < eigenvalues[b];
1599+
});
1600+
raft::matrix::gather(
1601+
handle,
1602+
raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::row_major>(
1603+
eigenvalues.data_handle(), ncv, 1),
1604+
raft::make_device_vector_view<const int, uint32_t>(selected_indices.data_handle(), nEigVecs),
1605+
raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::row_major>(
1606+
sm_eigenvalues.data_handle(), nEigVecs, 1));
1607+
raft::matrix::gather(
1608+
handle,
1609+
raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::row_major>(
1610+
eigenvectors.data_handle(), ncv, ncv),
1611+
raft::make_device_vector_view<const int, uint32_t>(selected_indices.data_handle(), nEigVecs),
1612+
raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::row_major>(
1613+
sm_eigenvectors.data_handle(), nEigVecs, ncv));
1614+
1615+
eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
1616+
sm_eigenvectors.data_handle(), ncv, nEigVecs);
1617+
eigenvalues_k =
1618+
raft::make_device_vector_view<ValueTypeT, uint32_t>(sm_eigenvalues.data_handle(), nEigVecs);
1619+
eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1620+
sm_eigenvectors.data_handle(), ncv, nEigVecs);
1621+
}
15461622
}
15471623

15481624
template <typename IndexTypeT, typename ValueTypeT>
@@ -1695,6 +1771,7 @@ auto lanczos_smallest(
16951771
int maxIter,
16961772
int restartIter,
16971773
ValueTypeT tol,
1774+
LANCZOS_WHICH which,
16981775
ValueTypeT* eigVals_dev,
16991776
ValueTypeT* eigVecs_dev,
17001777
ValueTypeT* v0,
@@ -1752,20 +1829,28 @@ auto lanczos_smallest(
17521829
raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(handle, ncv, ncv);
17531830
auto eigenvalues = raft::make_device_vector<ValueTypeT, uint32_t>(handle, ncv);
17541831

1832+
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors_k;
1833+
raft::device_vector_view<ValueTypeT, uint32_t> eigenvalues_k;
1834+
raft::device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major> eigenvectors_k_slice;
1835+
1836+
auto sm_eigenvalues = raft::make_device_vector<ValueTypeT>(handle, nEigVecs);
1837+
auto sm_eigenvectors =
1838+
raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(handle, ncv, nEigVecs);
1839+
17551840
lanczos_solve_ritz<IndexTypeT, ValueTypeT>(handle,
17561841
alpha.view(),
17571842
beta.view(),
17581843
std::nullopt,
17591844
nEigVecs,
1760-
0,
1845+
which,
17611846
ncv,
17621847
eigenvectors.view(),
1763-
eigenvalues.view());
1764-
1765-
auto eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
1766-
eigenvectors.data_handle(), ncv, nEigVecs);
1767-
auto eigenvalues_k =
1768-
raft::make_device_vector_view<ValueTypeT, uint32_t>(eigenvalues.data_handle(), nEigVecs);
1848+
eigenvalues.view(),
1849+
eigenvectors_k,
1850+
eigenvalues_k,
1851+
eigenvectors_k_slice,
1852+
sm_eigenvalues.view(),
1853+
sm_eigenvectors.view());
17691854

17701855
auto ritz_eigenvectors =
17711856
raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(eigVecs_dev, n, nEigVecs);
@@ -1777,9 +1862,6 @@ auto lanczos_smallest(
17771862

17781863
auto s = raft::make_device_vector<ValueTypeT>(handle, nEigVecs);
17791864

1780-
auto eigenvectors_k_slice =
1781-
raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1782-
eigenvectors.data_handle(), ncv, nEigVecs);
17831865
auto S_matrix = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
17841866
s.data_handle(), 1, nEigVecs);
17851867

@@ -2002,12 +2084,15 @@ auto lanczos_smallest(
20022084
beta.view(),
20032085
beta_k.view(),
20042086
nEigVecs,
2005-
0,
2087+
which,
20062088
ncv,
20072089
eigenvectors.view(),
2008-
eigenvalues.view());
2009-
auto eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
2010-
eigenvectors.data_handle(), ncv, nEigVecs);
2090+
eigenvalues.view(),
2091+
eigenvectors_k,
2092+
eigenvalues_k,
2093+
eigenvectors_k_slice,
2094+
sm_eigenvalues.view(),
2095+
sm_eigenvectors.view());
20112096

20122097
auto ritz_eigenvectors = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
20132098
eigVecs_dev, n, nEigVecs);
@@ -2017,9 +2102,6 @@ auto lanczos_smallest(
20172102
raft::linalg::gemm<ValueTypeT, uint32_t, raft::col_major, raft::col_major, raft::col_major>(
20182103
handle, V_T, eigenvectors_k, ritz_eigenvectors);
20192104

2020-
auto eigenvectors_k_slice =
2021-
raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
2022-
eigenvectors.data_handle(), ncv, nEigVecs);
20232105
auto S_matrix = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
20242106
s.data_handle(), 1, nEigVecs);
20252107

@@ -2066,6 +2148,7 @@ auto lanczos_compute_smallest_eigenvectors(
20662148
config.max_iterations,
20672149
config.ncv,
20682150
config.tolerance,
2151+
config.which,
20692152
eigenvalues.data_handle(),
20702153
eigenvectors.data_handle(),
20712154
v0->data_handle(),
@@ -2082,6 +2165,7 @@ auto lanczos_compute_smallest_eigenvectors(
20822165
config.max_iterations,
20832166
config.ncv,
20842167
config.tolerance,
2168+
config.which,
20852169
eigenvalues.data_handle(),
20862170
eigenvectors.data_handle(),
20872171
temp_v0.data_handle(),

cpp/include/raft/sparse/solver/lanczos_types.hpp

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,57 @@
2020

2121
namespace raft::sparse::solver {
2222

23+
/**
24+
* @enum LANCZOS_WHICH
25+
* @brief Enumeration specifying which eigenvalues to compute in the Lanczos algorithm
26+
*/
27+
enum LANCZOS_WHICH {
28+
/** @brief LA: Largest (algebraic) eigenvalues */
29+
LA,
30+
/** @brief LM: Largest (in magnitude) eigenvalues */
31+
LM,
32+
/** @brief SA: Smallest (algebraic) eigenvalues */
33+
SA,
34+
/** @brief SM: Smallest (in magnitude) eigenvalues */
35+
SM
36+
};
37+
38+
/**
39+
* @brief Configuration parameters for the Lanczos eigensolver
40+
*
41+
* This structure encapsulates all configuration parameters needed to run the
42+
* Lanczos algorithm for computing eigenvalues and eigenvectors of large sparse matrices.
43+
*
44+
* @tparam ValueTypeT Data type for values (float or double)
45+
*/
2346
template <typename ValueTypeT>
2447
struct lanczos_solver_config {
25-
/** The number of eigenvalues and eigenvectors to compute. Must be 1 <= k < n.*/
48+
/** @brief The number of eigenvalues and eigenvectors to compute
49+
* @note Must be 1 <= n_components < n, where n is the matrix dimension
50+
*/
2651
int n_components;
27-
/** Maximum number of iteration. */
52+
53+
/** @brief Maximum number of iterations allowed for the algorithm to converge */
2854
int max_iterations;
29-
/** The number of Lanczos vectors generated. Must be k + 1 < ncv < n. */
55+
56+
/** @brief The number of Lanczos vectors to generate
57+
* @note Must satisfy n_components + 1 < ncv < n, where n is the matrix dimension
58+
*/
3059
int ncv;
31-
/** Tolerance for residuals ``||Ax - wx||`` */
60+
61+
/** @brief Convergence tolerance for residuals
62+
* @note Used to determine when to stop iteration based on ||Ax - wx|| < tolerance
63+
*/
3264
ValueTypeT tolerance;
33-
/** random seed */
65+
66+
/** @brief Specifies which eigenvalues to compute in the Lanczos algorithm
67+
* @see LANCZOS_WHICH for possible values (SA, LA, SM, LM)
68+
*/
69+
LANCZOS_WHICH which;
70+
71+
/** @brief Random seed for initialization of the algorithm
72+
* @note Controls reproducibility of results
73+
*/
3474
uint64_t seed;
3575
};
3676

cpp/include/raft/spectral/eigen_solvers.cuh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ struct eigen_solver_config_t {
4040
1234567}; // CAVEAT: this default value is now common to all instances of using seed in
4141
// Lanczos; was not the case before: there were places where a default seed = 123456
4242
// was used; this may trigger slightly different # solver iterations
43+
44+
raft::sparse::solver::LANCZOS_WHICH which{raft::sparse::solver::LANCZOS_WHICH::SA};
4345
};
4446

4547
template <typename index_type_t, typename value_type_t, typename size_type_t = index_type_t>
@@ -79,8 +81,13 @@ struct lanczos_solver_t {
7981
RAFT_EXPECTS(eigVals != nullptr, "Null eigVals buffer.");
8082
RAFT_EXPECTS(eigVecs != nullptr, "Null eigVecs buffer.");
8183

82-
auto lanczos_config = raft::sparse::solver::lanczos_solver_config<value_type_t>{
83-
config_.n_eigVecs, config_.maxIter, config_.restartIter, config_.tol, config_.seed};
84+
auto lanczos_config =
85+
raft::sparse::solver::lanczos_solver_config<value_type_t>{config_.n_eigVecs,
86+
config_.maxIter,
87+
config_.restartIter,
88+
config_.tol,
89+
config_.which,
90+
config_.seed};
8491
auto v0_opt = std::optional<raft::device_vector_view<value_type_t, uint32_t, raft::row_major>>{
8592
std::nullopt};
8693
auto input_structure = input.structure_view();

0 commit comments

Comments
 (0)