Skip to content

Commit a8bcf78

Browse files
authored
feat: Add mmap support for reading sparse vectors to avoid OOM error in CI (qdrant#129)
* fix: Manual benchmarks
* fix: Remove gcs secrets
* feat: Use mmap to read sparse vectors
* fix: Format
* fix: Make unused var private
1 parent 04bbb7c commit a8bcf78

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

dataset_reader/sparse_reader.py

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,28 @@ def read_sparse_matrix_fields(
2323
return values, columns, index_pointer
2424

2525

26+
def mmap_sparse_matrix_fields(fname: Union[Path, str]):
    """Memory-map the fields of a CSR matrix without instantiating it.

    The file layout read here is: a header of three int64 values
    (row count, column count, number of non-zero entries), followed by the
    int64 index pointer array (row count + 1 entries), the int32 column
    index array and the float32 value array (one entry per non-zero each).

    Args:
        fname: Path of the file to map.

    Returns:
        A ``(values, columns, index_pointer)`` tuple of read-only
        ``np.memmap`` arrays backed by the file.

    Raises:
        ValueError: If the header is truncated, holds negative counts, or
            the file is too short for the sizes the header declares.
    """
    with open(fname, "rb") as f:
        f.seek(0, 2)  # whence=2 seeks to the end, so tell() is the file size
        file_size = f.tell()
        f.seek(0)
        sizes = np.fromfile(f, dtype="int64", count=3)

    # np.fromfile returns fewer items instead of failing on a short file.
    if sizes.size != 3:
        raise ValueError(
            f"{fname}: truncated header, expected 3 int64 values, got {sizes.size}"
        )

    # Plain Python ints keep the offset arithmetic below free of int64 overflow.
    n_row, _n_col, n_non_zero = (int(size) for size in sizes)
    if n_row < 0 or n_non_zero < 0:
        raise ValueError(
            f"{fname}: invalid header, n_row={n_row}, n_non_zero={n_non_zero}"
        )

    index_pointer_offset = sizes.nbytes
    columns_offset = index_pointer_offset + (n_row + 1) * np.dtype("int64").itemsize
    values_offset = columns_offset + n_non_zero * np.dtype("int32").itemsize
    expected_size = values_offset + n_non_zero * np.dtype("float32").itemsize

    # Fail early with a clear message rather than deep inside mmap.
    if file_size < expected_size:
        raise ValueError(
            f"{fname}: file has {file_size} bytes but the header requires "
            f"at least {expected_size}"
        )

    index_pointer = np.memmap(
        fname,
        dtype="int64",
        mode="r",
        offset=index_pointer_offset,
        shape=(n_row + 1,),
    )
    columns = np.memmap(
        fname, dtype="int32", mode="r", offset=columns_offset, shape=(n_non_zero,)
    )
    values = np.memmap(
        fname, dtype="float32", mode="r", offset=values_offset, shape=(n_non_zero,)
    )
    return values, columns, index_pointer
42+
43+
2644
def csr_to_sparse_vectors(
2745
values: List[float], columns: List[int], index_pointer: List[int]
2846
) -> Iterator[SparseVector]:
47+
"""Convert a CSR matrix to a list of SparseVectors"""
2948
num_rows = len(index_pointer) - 1
3049

3150
for i in range(num_rows):
@@ -38,9 +57,12 @@ def csr_to_sparse_vectors(
3857
yield SparseVector(indices=row_indices, values=row_values)
3958

4059

41-
def read_csr_matrix(filename: Union[Path, str]) -> Iterator[SparseVector]:
60+
def read_csr_matrix(filename: Union[Path, str], do_mmap=True) -> Iterator[SparseVector]:
4261
"""Read a CSR matrix in spmat format"""
43-
values, columns, index_pointer = read_sparse_matrix_fields(filename)
62+
if do_mmap:
63+
values, columns, index_pointer = mmap_sparse_matrix_fields(filename)
64+
else:
65+
values, columns, index_pointer = read_sparse_matrix_fields(filename)
4466
values = values.tolist()
4567
columns = columns.tolist()
4668
index_pointer = index_pointer.tolist()

0 commit comments

Comments
 (0)