50
50
#include < raft/linalg/transpose.cuh>
51
51
#include < raft/linalg/unary_op.cuh>
52
52
#include < raft/matrix/diagonal.cuh>
53
+ #include < raft/matrix/gather.cuh>
53
54
#include < raft/matrix/matrix.cuh>
54
55
#include < raft/matrix/slice.cuh>
55
56
#include < raft/matrix/triangular.cuh>
63
64
#include < raft/util/cudart_utils.hpp>
64
65
65
66
#include < cuda.h>
67
+ #include < thrust/sort.h>
66
68
67
69
#include < cublasLt.h>
68
70
#include < curand.h>
@@ -1507,10 +1509,15 @@ void lanczos_solve_ritz(
1507
1509
raft::device_matrix_view<ValueTypeT, uint32_t , raft::row_major> beta,
1508
1510
std::optional<raft::device_vector_view<ValueTypeT, uint32_t >> beta_k,
1509
1511
IndexTypeT k,
1510
- int which,
1512
+ LANCZOS_WHICH which,
1511
1513
int ncv,
1512
1514
raft::device_matrix_view<ValueTypeT, uint32_t , raft::col_major> eigenvectors,
1513
- raft::device_vector_view<ValueTypeT> eigenvalues)
1515
+ raft::device_vector_view<ValueTypeT> eigenvalues,
1516
+ raft::device_matrix_view<ValueTypeT, uint32_t , raft::col_major>& eigenvectors_k,
1517
+ raft::device_vector_view<ValueTypeT, uint32_t >& eigenvalues_k,
1518
+ raft::device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>& eigenvectors_k_slice,
1519
+ raft::device_vector_view<ValueTypeT> sm_eigenvalues,
1520
+ raft::device_matrix_view<ValueTypeT, uint32_t , raft::col_major> sm_eigenvectors)
1514
1521
{
1515
1522
auto stream = resource::get_cuda_stream (handle);
1516
1523
@@ -1543,6 +1550,75 @@ void lanczos_solve_ritz(
1543
1550
triangular_matrix.data_handle (), ncv, ncv);
1544
1551
1545
1552
raft::linalg::eig_dc (handle, triangular_matrix_view, eigenvectors, eigenvalues);
1553
+
1554
+ IndexTypeT nEigVecs = k;
1555
+
1556
+ auto indices = raft::make_device_vector<int >(handle, ncv);
1557
+ auto selected_indices = raft::make_device_vector<int >(handle, nEigVecs);
1558
+
1559
+ if (which == LANCZOS_WHICH::SA) {
1560
+ eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::col_major>(
1561
+ eigenvectors.data_handle (), ncv, nEigVecs);
1562
+ eigenvalues_k =
1563
+ raft::make_device_vector_view<ValueTypeT, uint32_t >(eigenvalues.data_handle (), nEigVecs);
1564
+ eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1565
+ eigenvectors.data_handle (), ncv, nEigVecs);
1566
+ } else if (which == LANCZOS_WHICH::LA) {
1567
+ eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::col_major>(
1568
+ eigenvectors.data_handle () + (ncv - nEigVecs) * ncv, ncv, nEigVecs);
1569
+ eigenvalues_k = raft::make_device_vector_view<ValueTypeT, uint32_t >(
1570
+ eigenvalues.data_handle () + (ncv - nEigVecs), nEigVecs);
1571
+ eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1572
+ eigenvectors.data_handle () + (ncv - nEigVecs) * ncv, ncv, nEigVecs);
1573
+ } else if (which == LANCZOS_WHICH::SM || which == LANCZOS_WHICH::LM) {
1574
+ thrust::sequence (thrust::device, indices.data_handle (), indices.data_handle () + ncv, 0 );
1575
+
1576
+ // Sort indices by absolute eigenvalues (magnitude) using a custom comparator
1577
+ thrust::sort (thrust::device,
1578
+ indices.data_handle (),
1579
+ indices.data_handle () + ncv,
1580
+ [eigenvalues = eigenvalues.data_handle ()] __device__ (int a, int b) {
1581
+ return fabsf (eigenvalues[a]) < fabsf (eigenvalues[b]);
1582
+ });
1583
+
1584
+ if (which == LANCZOS_WHICH::SM) {
1585
+ // Take the first nEigVecs indices (smallest magnitude)
1586
+ raft::copy (selected_indices.data_handle (), indices.data_handle (), nEigVecs, stream);
1587
+ } else if (which == LANCZOS_WHICH::LM) {
1588
+ // Take the last nEigVecs indices (largest magnitude)
1589
+ raft::copy (
1590
+ selected_indices.data_handle (), indices.data_handle () + (ncv - nEigVecs), nEigVecs, stream);
1591
+ }
1592
+
1593
+ // Re-sort these indices by algebraic value to maintain algebraic ordering
1594
+ thrust::sort (thrust::device,
1595
+ selected_indices.data_handle (),
1596
+ selected_indices.data_handle () + nEigVecs,
1597
+ [eigenvalues = eigenvalues.data_handle ()] __device__ (int a, int b) {
1598
+ return eigenvalues[a] < eigenvalues[b];
1599
+ });
1600
+ raft::matrix::gather (
1601
+ handle,
1602
+ raft::make_device_matrix_view<const ValueTypeT, uint32_t , raft::row_major>(
1603
+ eigenvalues.data_handle (), ncv, 1 ),
1604
+ raft::make_device_vector_view<const int , uint32_t >(selected_indices.data_handle (), nEigVecs),
1605
+ raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::row_major>(
1606
+ sm_eigenvalues.data_handle (), nEigVecs, 1 ));
1607
+ raft::matrix::gather (
1608
+ handle,
1609
+ raft::make_device_matrix_view<const ValueTypeT, uint32_t , raft::row_major>(
1610
+ eigenvectors.data_handle (), ncv, ncv),
1611
+ raft::make_device_vector_view<const int , uint32_t >(selected_indices.data_handle (), nEigVecs),
1612
+ raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::row_major>(
1613
+ sm_eigenvectors.data_handle (), nEigVecs, ncv));
1614
+
1615
+ eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::col_major>(
1616
+ sm_eigenvectors.data_handle (), ncv, nEigVecs);
1617
+ eigenvalues_k =
1618
+ raft::make_device_vector_view<ValueTypeT, uint32_t >(sm_eigenvalues.data_handle (), nEigVecs);
1619
+ eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1620
+ sm_eigenvectors.data_handle (), ncv, nEigVecs);
1621
+ }
1546
1622
}
1547
1623
1548
1624
template <typename IndexTypeT, typename ValueTypeT>
@@ -1695,6 +1771,7 @@ auto lanczos_smallest(
1695
1771
int maxIter,
1696
1772
int restartIter,
1697
1773
ValueTypeT tol,
1774
+ LANCZOS_WHICH which,
1698
1775
ValueTypeT* eigVals_dev,
1699
1776
ValueTypeT* eigVecs_dev,
1700
1777
ValueTypeT* v0,
@@ -1752,20 +1829,28 @@ auto lanczos_smallest(
1752
1829
raft::make_device_matrix<ValueTypeT, uint32_t , raft::col_major>(handle, ncv, ncv);
1753
1830
auto eigenvalues = raft::make_device_vector<ValueTypeT, uint32_t >(handle, ncv);
1754
1831
1832
+ raft::device_matrix_view<ValueTypeT, uint32_t , raft::col_major> eigenvectors_k;
1833
+ raft::device_vector_view<ValueTypeT, uint32_t > eigenvalues_k;
1834
+ raft::device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major> eigenvectors_k_slice;
1835
+
1836
+ auto sm_eigenvalues = raft::make_device_vector<ValueTypeT>(handle, nEigVecs);
1837
+ auto sm_eigenvectors =
1838
+ raft::make_device_matrix<ValueTypeT, uint32_t , raft::col_major>(handle, ncv, nEigVecs);
1839
+
1755
1840
lanczos_solve_ritz<IndexTypeT, ValueTypeT>(handle,
1756
1841
alpha.view (),
1757
1842
beta.view (),
1758
1843
std::nullopt,
1759
1844
nEigVecs,
1760
- 0 ,
1845
+ which ,
1761
1846
ncv,
1762
1847
eigenvectors.view (),
1763
- eigenvalues.view ());
1764
-
1765
- auto eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::col_major>(
1766
- eigenvectors. data_handle (), ncv, nEigVecs);
1767
- auto eigenvalues_k =
1768
- raft::make_device_vector_view<ValueTypeT, uint32_t >(eigenvalues. data_handle (), nEigVecs );
1848
+ eigenvalues.view (),
1849
+ eigenvectors_k,
1850
+ eigenvalues_k,
1851
+ eigenvectors_k_slice,
1852
+ sm_eigenvalues. view (),
1853
+ sm_eigenvectors. view () );
1769
1854
1770
1855
auto ritz_eigenvectors =
1771
1856
raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::col_major>(eigVecs_dev, n, nEigVecs);
@@ -1777,9 +1862,6 @@ auto lanczos_smallest(
1777
1862
1778
1863
auto s = raft::make_device_vector<ValueTypeT>(handle, nEigVecs);
1779
1864
1780
- auto eigenvectors_k_slice =
1781
- raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1782
- eigenvectors.data_handle (), ncv, nEigVecs);
1783
1865
auto S_matrix = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
1784
1866
s.data_handle (), 1 , nEigVecs);
1785
1867
@@ -2002,12 +2084,15 @@ auto lanczos_smallest(
2002
2084
beta.view (),
2003
2085
beta_k.view (),
2004
2086
nEigVecs,
2005
- 0 ,
2087
+ which ,
2006
2088
ncv,
2007
2089
eigenvectors.view (),
2008
- eigenvalues.view ());
2009
- auto eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::col_major>(
2010
- eigenvectors.data_handle (), ncv, nEigVecs);
2090
+ eigenvalues.view (),
2091
+ eigenvectors_k,
2092
+ eigenvalues_k,
2093
+ eigenvectors_k_slice,
2094
+ sm_eigenvalues.view (),
2095
+ sm_eigenvectors.view ());
2011
2096
2012
2097
auto ritz_eigenvectors = raft::make_device_matrix_view<ValueTypeT, uint32_t , raft::col_major>(
2013
2098
eigVecs_dev, n, nEigVecs);
@@ -2017,9 +2102,6 @@ auto lanczos_smallest(
2017
2102
raft::linalg::gemm<ValueTypeT, uint32_t , raft::col_major, raft::col_major, raft::col_major>(
2018
2103
handle, V_T, eigenvectors_k, ritz_eigenvectors);
2019
2104
2020
- auto eigenvectors_k_slice =
2021
- raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
2022
- eigenvectors.data_handle (), ncv, nEigVecs);
2023
2105
auto S_matrix = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
2024
2106
s.data_handle (), 1 , nEigVecs);
2025
2107
@@ -2066,6 +2148,7 @@ auto lanczos_compute_smallest_eigenvectors(
2066
2148
config.max_iterations ,
2067
2149
config.ncv ,
2068
2150
config.tolerance ,
2151
+ config.which ,
2069
2152
eigenvalues.data_handle (),
2070
2153
eigenvectors.data_handle (),
2071
2154
v0->data_handle (),
@@ -2082,6 +2165,7 @@ auto lanczos_compute_smallest_eigenvectors(
2082
2165
config.max_iterations ,
2083
2166
config.ncv ,
2084
2167
config.tolerance ,
2168
+ config.which ,
2085
2169
eigenvalues.data_handle (),
2086
2170
eigenvectors.data_handle (),
2087
2171
temp_v0.data_handle (),
0 commit comments