
Commit 086bda8

fix(spm): improved dataset type checking (#107)
* better type checking
* additional related type fixes
* additional logic and consistency fixes
1 parent 15a653d commit 086bda8

File tree

5 files changed: +110 -67 lines changed


cryosparc/column.py

Lines changed: 3 additions & 1 deletion
@@ -52,7 +52,9 @@ def __new__(cls, field: Field, data: Data):
         dtype = n.dtype(fielddtype(field))
         nrow = data.nrow()
         shape = (nrow, *dtype.shape)
-        buffer = data.getbuf(field[0]).memview if nrow else None
+        buffer = data.getbuf(field[0])
+        if buffer is not None:
+            buffer = buffer.memview
         obj = super().__new__(cls, shape=shape, dtype=dtype.base, buffer=buffer)  # type: ignore

         # Keep a reference to the data so that it only gets cleaned up when all
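
The guard above reflects that `Data.getbuf` can now return `None` for an empty column (see the cryosparc/core.pyx change below) instead of a value that blows up on `.memview` access. A minimal standalone sketch of the pattern, with a hypothetical helper name:

    def buffer_view_or_none(data, field):
        # Data.getbuf now returns None when the column has no backing bytes
        buffer = data.getbuf(field)
        if buffer is not None:
            buffer = buffer.memview  # dereference only when a buffer exists
        return buffer

Unlike the old `... if nrow else None` one-liner, the check is tied to what `getbuf` actually returned rather than to the row count.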

cryosparc/core.pyi

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+from enum import Enum
+from typing import SupportsBytes
+
+from numpy.typing import NDArray
+
+__all__ = ["DsetType", "Stream", "Data"]
+
+class MemoryView(SupportsBytes):  # Note: Supports buffer protocol.
+    base: "Array"
+    size: int
+    itemsize: int
+    nbytes: int
+    ndim: int
+    shape: tuple[int, ...]
+    strides: tuple[int, ...]
+    suboffsets: tuple[int, ...]
+    T: "MemoryView"
+
+    def copy(self) -> "MemoryView": ...
+    def copy_fortran(self) -> "MemoryView": ...
+    def is_c_contig(self) -> bool: ...
+    def is_f_contig(self) -> bool: ...
+
+class Array:
+    memview: MemoryView
+
+    def __len__(self) -> int: ...
+    def __getitem__(self, key: int | slice) -> bytes: ...
+    def __setitem__(self, key: int | slice, item: bytes): ...
+
+class DsetType(int, Enum):
+    T_F32 = ...
+    T_F64 = ...
+    T_C32 = ...
+    T_C64 = ...
+    T_I8 = ...
+    T_I16 = ...
+    T_I32 = ...
+    T_I64 = ...
+    T_U8 = ...
+    T_U16 = ...
+    T_U32 = ...
+    T_U64 = ...
+    T_STR = ...
+    T_OBJ = ...
+
+class Data:
+    def __init__(self, other: "Data" | None = None) -> None: ...
+    def innerjoin(self, key: str, other: "Data") -> "Data": ...
+    def totalsz(self) -> int: ...
+    def ncol(self) -> int: ...
+    def nrow(self) -> int: ...
+    def key(self, index: int) -> str: ...
+    def type(self, field: str) -> int: ...
+    def has(self, field: str) -> bool: ...
+    def addrows(self, num: int) -> None: ...
+    def addcol_scalar(self, field: str, dtype: int) -> None: ...
+    def addcol_array(self, field: str, dtype: int, shape: tuple[int, ...]) -> None: ...
+    def getshp(self, colkey: str) -> tuple[int, ...]: ...
+    def getbuf(self, colkey: str) -> Array | None: ...
+    def getstr(self, col: str, index: int) -> bytes: ...
+    def tocstrs(self, col: str) -> bool: ...
+    def topystrs(self, col: str) -> bool: ...
+    def stralloc(self, val: str) -> int: ...
+    def dump(self) -> Array: ...
+    def dumpstrheap(self) -> Array: ...
+    def setstrheap(self, heap: bytes) -> None: ...
+    def defrag(self, realloc_smaller: bool) -> None: ...
+    def dumptxt(self, dump_data: bool = False) -> None: ...
+    def handle(self) -> int: ...
+
+class Stream:
+    def __init__(self, data: Data) -> None: ...
+    def cast_objs_to_strs(self) -> None: ...
+    def stralloc_col(self, col: str) -> Array | None: ...
+    def compress_col(self, col: str) -> Array: ...
+    def compress_numpy(self, arr: NDArray) -> Array: ...
+    def compress(self, arr: Array) -> Array: ...
+    def decompress_col(self, col: str, data: bytes) -> Array: ...
+    def decompress_numpy(self, data: bytes, arr: NDArray) -> Array: ...
+    def decompress(self, data: bytes, outptr: int = 0) -> Array: ...
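
These stubs describe the compiled Cython module for type checkers; notably, `Data.getbuf` is declared as returning `Array | None`, so unguarded buffer access is caught statically. An illustrative snippet (not part of the commit) of how a checker such as mypy now narrows the type:

    from cryosparc.core import Data

    data = Data()
    buf = data.getbuf("uid")  # typed as Array | None
    # buf[:] = b"\x00"        # rejected by the checker: buf may be None
    if buf is not None:
        print(len(buf), buf.memview.nbytes)  # OK after narrowing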

cryosparc/core.pyx

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ cdef class Data:
         with nogil:
             mem = dataset.dset_get(self._handle, colkey_c)
             size = dataset.dset_getsz(self._handle, colkey_c)
-        return 0 if size == 0 else <unsigned char [:size]> mem
+        return None if size == 0 else <unsigned char [:size]> mem

     def getstr(self, str col, size_t index):
         return dataset.dset_getstr(self._handle, col.encode(), index)  # returns bytes
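
Returning `None` rather than `0` gives callers an unambiguous identity check and matches the `Array | None` annotation in the new stub; the old `int`-or-memoryview union could only be probed by truthiness. A hypothetical caller showing the intended pattern:

    def write_column(data, field, payload: bytes) -> bool:
        buf = data.getbuf(field)
        if buf is None:       # empty column: explicit, type-safe check
            return False
        buf[:] = payload      # safe: buf is a real buffer here
        return True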

cryosparc/dataset.py

Lines changed: 25 additions & 20 deletions
@@ -75,8 +75,6 @@
 if TYPE_CHECKING:
     from numpy.typing import ArrayLike, DTypeLike, NDArray

-    from .core import MemoryView
-

 # Save format options
 NUMPY_FORMAT = 1
@@ -683,16 +681,17 @@ def _load_stream(
         descr = filter_descr(header["dtype"], keep_prefixes=prefixes, keep_fields=fields)
         field_names = {field[0] for field in descr}

-        # Calling addrows separately to minimizes column-based
-        # allocations, improves performance by ~20%
+        # Calling addrows separately to minimize column-based allocations,
+        # improves performance by ~20%
         dset = cls.allocate(0, descr)
-        if header["length"] == 0:
-            return dset  # no more data to load
-
         data = dset._data
         data.addrows(header["length"])
+
+        # If a dataset is empty, it won't have anything in its data section.
+        # Just the string heap at the end.
+        dtype = [] if header["length"] == 0 else header["dtype"]
         loader = Stream(data)
-        for field in header["dtype"]:
+        for field in dtype:
             colsize = u32intle(f.read(4))
             if field[0] not in field_names:
                 # try to seek instead of read to reduce memory usage
@@ -701,8 +700,10 @@
             buffer = f.read(colsize)
             if field[0] in header["compressed_fields"]:
                 loader.decompress_col(field[0], buffer)
-            else:
-                data.getbuf(field[0])[:] = buffer
+                continue
+            mem = data.getbuf(field[0])
+            assert mem is not None, f"Could not load stream (missing {field[0]} buffer)"
+            mem[:] = buffer

         # Read in the string heap (rest of stream)
         # NOTE: There will be a bug here for long column keys that are
@@ -726,16 +727,22 @@ async def from_async_stream(cls, stream: AsyncBinaryIO):
         dset = cls.allocate(0, header["dtype"])
         data = dset._data
         data.addrows(header["length"])
+
+        # If a dataset is empty, it won't have anything in its data section.
+        # Just the string heap at the end.
+        dtype = [] if header["length"] == 0 else header["dtype"]
         loader = Stream(data)
-        for field in header["dtype"]:
+        for field in dtype:
             colsize = u32intle(await stream.read(4))
             buffer = await stream.read(colsize)
             if field[0] in header["compressed_fields"]:
                 loader.decompress_col(field[0], buffer)
-            else:
-                data.getbuf(field[0])[:] = buffer
+                continue
+            mem = data.getbuf(field[0])
+            assert mem is not None, f"Could not load stream (missing {field[0]} buffer)"
+            mem[:] = buffer

-        heap = stream.read()
+        heap = await stream.read()
         data.setstrheap(heap)

         # Convert C strings to Python strings
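
Note the added `await` on the final heap read: without it, `stream.read()` returns an un-awaited coroutine, so `data.setstrheap` would receive a coroutine object instead of bytes. A standalone illustration of that failure mode (toy stream class, not from the codebase):

    import asyncio

    class ToyStream:
        async def read(self, n: int = -1) -> bytes:
            return b"heap"

    async def main():
        s = ToyStream()
        wrong = s.read()        # a coroutine object, not bytes
        right = await s.read()  # b"heap"
        assert not isinstance(wrong, bytes) and right == b"heap"
        wrong.close()           # suppress "coroutine was never awaited"

    asyncio.run(main())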
@@ -803,16 +810,14 @@ def stream(self, compression: Literal["lz4", None] = None) -> Generator[bytes, N
         yield u32bytesle(len(header))
         yield header

-        if len(self) == 0:
-            return  # empty dataset, don't yield anything
-
-        for f in self:
-            fielddata: "MemoryView"
+        fields = [] if len(self) == 0 else self.fields()
+        for f in fields:
             if f in compressed_fields:
                 # obj columns added to strheap and loaded as indexes
                 fielddata = stream.compress_col(f)
             else:
                 fielddata = stream.stralloc_col(f) or data.getbuf(f)
+            assert fielddata is not None, f"Could not stream dataset (missing {f} buffer)"
             yield u32bytesle(len(fielddata))
             yield bytes(fielddata.memview)
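
This makes the writer symmetric with the loaders above: an empty dataset now emits its header and string heap with no column blocks, which is exactly what `_load_stream` and `from_async_stream` expect. A hedged round-trip sketch, assuming `Dataset.allocate` accepts a field-descriptor list and `Dataset.load` can detect this stream format:

    from io import BytesIO
    from cryosparc.dataset import Dataset

    empty = Dataset.allocate(0, [("value", "f4")])
    payload = b"".join(empty.stream())       # header + string heap only
    loaded = Dataset.load(BytesIO(payload))  # assumes load() reads this format
    assert len(loaded) == 0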

@@ -1231,7 +1236,7 @@ def filter_prefix(self, keep_prefix: str, *, rename: Optional[str] = None, copy:
         if rename and rename != keep_prefix:
             new_fields = [f"{rename}/{f.split('/', 1)[1]}" for f in keep_fields]

-        result = type(self)([("uid", self["uid"])] + [(nf, self[f]) for f, nf in zip(keep_fields, new_fields)])
+        result = type(self)([("uid", self["uid"])] + [(nf, self[f]) for f, nf in zip(keep_fields, new_fields)])  # type: ignore
         return result if copy else self._reset(result._data)

     def drop_fields(self, names: Union[Collection[str], Callable[[str], bool]], *, copy: bool = False):
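
For context, a usage sketch of `filter_prefix` exercising the rename path touched above; the field names here are made up:

    from cryosparc.dataset import Dataset

    dset = Dataset.allocate(2, [("alignments2D/shift", "f4")])
    renamed = dset.filter_prefix("alignments2D", rename="alignments3D", copy=True)
    # expected fields afterward: ["uid", "alignments3D/shift"]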

typings/cryosparc/core.pyi

Lines changed: 0 additions & 45 deletions
This file was deleted.
