@@ -23,9 +23,28 @@ def read_sparse_matrix_fields(
23
23
return values , columns , index_pointer
24
24
25
25
26
+ def mmap_sparse_matrix_fields (fname ):
27
+ """mmap the fields of a CSR matrix without instantiating it"""
28
+ with open (fname , "rb" ) as f :
29
+ sizes = np .fromfile (f , dtype = "int64" , count = 3 )
30
+ n_row , _n_col , n_non_zero = sizes
31
+ offset = sizes .nbytes
32
+ index_pointer = np .memmap (
33
+ fname , dtype = "int64" , mode = "r" , offset = offset , shape = n_row + 1
34
+ )
35
+ offset += index_pointer .nbytes
36
+ columns = np .memmap (fname , dtype = "int32" , mode = "r" , offset = offset , shape = n_non_zero )
37
+ offset += columns .nbytes
38
+ values = np .memmap (
39
+ fname , dtype = "float32" , mode = "r" , offset = offset , shape = n_non_zero
40
+ )
41
+ return values , columns , index_pointer
42
+
43
+
26
44
def csr_to_sparse_vectors (
27
45
values : List [float ], columns : List [int ], index_pointer : List [int ]
28
46
) -> Iterator [SparseVector ]:
47
+ """Convert a CSR matrix to a list of SparseVectors"""
29
48
num_rows = len (index_pointer ) - 1
30
49
31
50
for i in range (num_rows ):
@@ -38,9 +57,12 @@ def csr_to_sparse_vectors(
38
57
yield SparseVector (indices = row_indices , values = row_values )
39
58
40
59
41
- def read_csr_matrix (filename : Union [Path , str ]) -> Iterator [SparseVector ]:
60
+ def read_csr_matrix (filename : Union [Path , str ], do_mmap = True ) -> Iterator [SparseVector ]:
42
61
"""Read a CSR matrix in spmat format"""
43
- values , columns , index_pointer = read_sparse_matrix_fields (filename )
62
+ if do_mmap :
63
+ values , columns , index_pointer = mmap_sparse_matrix_fields (filename )
64
+ else :
65
+ values , columns , index_pointer = read_sparse_matrix_fields (filename )
44
66
values = values .tolist ()
45
67
columns = columns .tolist ()
46
68
index_pointer = index_pointer .tolist ()
0 commit comments