7575if TYPE_CHECKING :
7676 from numpy .typing import ArrayLike , DTypeLike , NDArray
7777
78- from .core import MemoryView
79-
8078
8179# Save format options
8280NUMPY_FORMAT = 1
@@ -683,16 +681,17 @@ def _load_stream(
683681 descr = filter_descr (header ["dtype" ], keep_prefixes = prefixes , keep_fields = fields )
684682 field_names = {field [0 ] for field in descr }
685683
686- # Calling addrows separately to minimizes column-based
687- # allocations, improves performance by ~20%
684+ # Calling addrows separately to minimize column-based allocations,
685+ # improves performance by ~20%
688686 dset = cls .allocate (0 , descr )
689- if header ["length" ] == 0 :
690- return dset # no more data to load
691-
692687 data = dset ._data
693688 data .addrows (header ["length" ])
689+
690+ # If a dataset is empty, it won't have anything in its data section.
691+ # Just the string heap at the end.
692+ dtype = [] if header ["length" ] == 0 else header ["dtype" ]
694693 loader = Stream (data )
695- for field in header [ " dtype" ] :
694+ for field in dtype :
696695 colsize = u32intle (f .read (4 ))
697696 if field [0 ] not in field_names :
698697 # try to seek instead of read to reduce memory usage
@@ -701,8 +700,10 @@ def _load_stream(
701700 buffer = f .read (colsize )
702701 if field [0 ] in header ["compressed_fields" ]:
703702 loader .decompress_col (field [0 ], buffer )
704- else :
705- data .getbuf (field [0 ])[:] = buffer
703+ continue
704+ mem = data .getbuf (field [0 ])
705+ assert mem is not None , f"Could not load stream (missing { field [0 ]} buffer)"
706+ mem [:] = buffer
706707
707708 # Read in the string heap (rest of stream)
708709 # NOTE: There will be a bug here for long column keys that are
@@ -726,16 +727,22 @@ async def from_async_stream(cls, stream: AsyncBinaryIO):
726727 dset = cls .allocate (0 , header ["dtype" ])
727728 data = dset ._data
728729 data .addrows (header ["length" ])
730+
731+ # If a dataset is empty, it won't have anything in its data section.
732+ # Just the string heap at the end.
733+ dtype = [] if header ["length" ] == 0 else header ["dtype" ]
729734 loader = Stream (data )
730- for field in header [ " dtype" ] :
735+ for field in dtype :
731736 colsize = u32intle (await stream .read (4 ))
732737 buffer = await stream .read (colsize )
733738 if field [0 ] in header ["compressed_fields" ]:
734739 loader .decompress_col (field [0 ], buffer )
735- else :
736- data .getbuf (field [0 ])[:] = buffer
740+ continue
741+ mem = data .getbuf (field [0 ])
742+ assert mem is not None , f"Could not load stream (missing { field [0 ]} buffer)"
743+ mem [:] = buffer
737744
738- heap = stream .read ()
745+ heap = await stream .read ()
739746 data .setstrheap (heap )
740747
741748 # Convert C strings to Python strings
@@ -803,16 +810,14 @@ def stream(self, compression: Literal["lz4", None] = None) -> Generator[bytes, N
803810 yield u32bytesle (len (header ))
804811 yield header
805812
806- if len (self ) == 0 :
807- return # empty dataset, don't yield anything
808-
809- for f in self :
810- fielddata : "MemoryView"
813+ fields = [] if len (self ) == 0 else self .fields ()
814+ for f in fields :
811815 if f in compressed_fields :
812816 # obj columns added to strheap and loaded as indexes
813817 fielddata = stream .compress_col (f )
814818 else :
815819 fielddata = stream .stralloc_col (f ) or data .getbuf (f )
820+ assert fielddata is not None , f"Could not stream dataset (missing { f } buffer)"
816821 yield u32bytesle (len (fielddata ))
817822 yield bytes (fielddata .memview )
818823
@@ -1231,7 +1236,7 @@ def filter_prefix(self, keep_prefix: str, *, rename: Optional[str] = None, copy:
12311236 if rename and rename != keep_prefix :
12321237 new_fields = [f"{ rename } /{ f .split ('/' , 1 )[1 ]} " for f in keep_fields ]
12331238
1234- result = type (self )([("uid" , self ["uid" ])] + [(nf , self [f ]) for f , nf in zip (keep_fields , new_fields )])
1239+ result = type (self )([("uid" , self ["uid" ])] + [(nf , self [f ]) for f , nf in zip (keep_fields , new_fields )]) # type: ignore
12351240 return result if copy else self ._reset (result ._data )
12361241
12371242 def drop_fields (self , names : Union [Collection [str ], Callable [[str ], bool ]], * , copy : bool = False ):
0 commit comments