diff --git a/changes/3050.bugfix.rst b/changes/3050.bugfix.rst new file mode 100644 index 0000000000..c0b6678e78 --- /dev/null +++ b/changes/3050.bugfix.rst @@ -0,0 +1 @@ +- Fixed potential error in `AsyncGroup.create_dataset()` where `dtype` argument could be missing when calling `create_array()` diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 5c470e29ca..974044b6b6 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -47,7 +47,6 @@ NodeType, ShapeLike, ZarrFormat, - parse_shapelike, ) from zarr.core.config import config from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata @@ -441,9 +440,8 @@ class AsyncGroup: metadata: GroupMetadata store_path: StorePath - - # TODO: make this correct and work - # TODO: ensure that this can be bound properly to subclass of AsyncGroup + _sync: Any = field(default=None, init=False) + _async_group: Any = field(default=None, init=False) @classmethod async def from_store( @@ -991,6 +989,53 @@ async def require_groups(self, *names: str) -> tuple[AsyncGroup, ...]: return () return tuple(await asyncio.gather(*(self.require_group(name) for name in names))) + async def _require_array_async( + self, + name: str, + *, + shape: ShapeLike, + dtype: npt.DTypeLike = None, + exact: bool = False, + **kwargs: Any, + ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: + """Obtain an array, creating if it doesn't exist. + + Other `kwargs` are as per :func:`zarr.AsyncGroup.create_array`. + + Parameters + ---------- + name : str + Array name. + shape : int or tuple of ints + Array shape. + dtype : str or dtype, optional + NumPy dtype. If None, the dtype will be inferred from the existing array. + exact : bool, optional + If True, require `dtype` to match exactly. If False, require + `dtype` can be cast from array dtype. 
+ + Returns + ------- + a : AsyncArray + """ + try: + item = await self.getitem(name) + except KeyError: + # If it doesn't exist, create it + return await self.create_array(name, shape=shape, dtype=dtype, **kwargs) + else: + # Existing item must be an AsyncArray with matching dtype/shape + if not isinstance(item, AsyncArray): + raise TypeError(f"Incompatible object ({item.__class__.__name__}) already exists") + assert isinstance(item, AsyncArray) # mypy + if exact and dtype is not None and item.dtype != np.dtype(dtype): + raise TypeError("Incompatible dtype") + if not exact and dtype is not None and not np.can_cast(item.dtype, dtype): + raise TypeError("Incompatible dtype") + if item.shape != shape: + raise TypeError("Incompatible shape") + return item + async def create_array( self, name: str, @@ -1155,32 +1200,35 @@ async def create_dataset( # create_dataset in zarr 2.x requires shape but not dtype if data is # provided. Allow this configuration by inferring dtype from data if # necessary and passing it to create_array - if "dtype" not in kwargs and data is not None: - kwargs["dtype"] = data.dtype + if "dtype" not in kwargs: + if data is not None: + kwargs["dtype"] = data.dtype + else: + raise ValueError("dtype must be provided if data is None") array = await self.create_array(name, shape=shape, **kwargs) if data is not None: await array.setitem(slice(None), data) return array - @deprecated("Use AsyncGroup.require_array instead.") - async def require_dataset( + @deprecated("Use Group.require_array instead.") + def require_dataset( self, name: str, *, - shape: ChunkCoords, + shape: ShapeLike, dtype: npt.DTypeLike = None, exact: bool = False, **kwargs: Any, - ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: + ) -> Array: """Obtain an array, creating if it doesn't exist. .. deprecated:: 3.0.0 - The h5py compatibility methods will be removed in 3.1.0. Use `AsyncGroup.require_dataset` instead. + The h5py compatibility methods will be removed in 3.1.0. 
Use `Group.require_array` instead. Arrays are known as "datasets" in HDF5 terminology. For compatibility with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.create_dataset` method. - Other `kwargs` are as per :func:`zarr.AsyncGroup.create_dataset`. + Other `kwargs` are as per :func:`zarr.Group.create_array`. Parameters ---------- @@ -1189,18 +1237,24 @@ async def require_dataset( shape : int or tuple of ints Array shape. dtype : str or dtype, optional - NumPy dtype. + NumPy dtype. If None, the dtype will be inferred from the existing array. exact : bool, optional - If True, require `dtype` to match exactly. If false, require + If True, require `dtype` to match exactly. If False, require `dtype` can be cast from array dtype. Returns ------- - a : AsyncArray + a : Array """ - return await self.require_array(name, shape=shape, dtype=dtype, exact=exact, **kwargs) + return Array( + self._sync( + self._async_group._require_array_async( + name, shape=shape, dtype=dtype, exact=exact, **kwargs + ) + ) + ) - async def require_array( + def require_array( self, name: str, *, @@ -1208,10 +1262,10 @@ async def require_array( dtype: npt.DTypeLike = None, exact: bool = False, **kwargs: Any, - ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: + ) -> Array: """Obtain an array, creating if it doesn't exist. - Other `kwargs` are as per :func:`zarr.AsyncGroup.create_dataset`. + Other `kwargs` are as per :func:`zarr.Group.create_array`. Parameters ---------- @@ -1220,35 +1274,22 @@ async def require_array( shape : int or tuple of ints Array shape. dtype : str or dtype, optional - NumPy dtype. + NumPy dtype. If None, the dtype will be inferred from the existing array. exact : bool, optional - If True, require `dtype` to match exactly. If false, require + If True, require `dtype` to match exactly. If False, require `dtype` can be cast from array dtype. 
Returns ------- - a : AsyncArray + a : Array """ - try: - ds = await self.getitem(name) - if not isinstance(ds, AsyncArray): - raise TypeError(f"Incompatible object ({ds.__class__.__name__}) already exists") - - shape = parse_shapelike(shape) - if shape != ds.shape: - raise TypeError(f"Incompatible shape ({ds.shape} vs {shape})") - - dtype = np.dtype(dtype) - if exact: - if ds.dtype != dtype: - raise TypeError(f"Incompatible dtype ({ds.dtype} vs {dtype})") - else: - if not np.can_cast(ds.dtype, dtype): - raise TypeError(f"Incompatible dtype ({ds.dtype} vs {dtype})") - except KeyError: - ds = await self.create_array(name, shape=shape, dtype=dtype, **kwargs) - - return ds + return Array( + self._sync( + self._async_group._require_array_async( + name, shape=shape, dtype=dtype, exact=exact, **kwargs + ) + ) + ) async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup: """Update group attributes. @@ -2511,49 +2552,75 @@ def create_dataset(self, name: str, **kwargs: Any) -> Array: .. deprecated:: 3.0.0 The h5py compatibility methods will be removed in 3.1.0. Use `Group.create_array` instead. - Arrays are known as "datasets" in HDF5 terminology. For compatibility - with h5py, Zarr groups also implement the :func:`zarr.Group.require_dataset` method. + with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.require_dataset` method. Parameters ---------- name : str Array name. **kwargs : dict - Additional arguments passed to :func:`zarr.Group.create_array` + Additional arguments passed to :func:`zarr.AsyncGroup.create_array`. 
Returns ------- - a : Array + a : Array """ return Array(self._sync(self._async_group.create_dataset(name, **kwargs))) @deprecated("Use Group.require_array instead.") - def require_dataset(self, name: str, *, shape: ShapeLike, **kwargs: Any) -> Array: + def require_dataset( + self, + name: str, + *, + shape: ShapeLike, + dtype: npt.DTypeLike = None, + exact: bool = False, + **kwargs: Any, + ) -> Array: """Obtain an array, creating if it doesn't exist. .. deprecated:: 3.0.0 The h5py compatibility methods will be removed in 3.1.0. Use `Group.require_array` instead. Arrays are known as "datasets" in HDF5 terminology. For compatibility - with h5py, Zarr groups also implement the :func:`zarr.Group.create_dataset` method. + with h5py, Zarr groups also implement the :func:`zarr.Group.create_dataset` method. - Other `kwargs` are as per :func:`zarr.Group.create_dataset`. + Other `kwargs` are as per :func:`zarr.Group.create_array`. Parameters ---------- name : str Array name. - **kwargs : - See :func:`zarr.Group.create_dataset`. + shape : int or tuple of ints + Array shape. + dtype : str or dtype, optional + NumPy dtype. If None, the dtype will be inferred from the existing array. + exact : bool, optional + If True, require `dtype` to match exactly. If False, require + `dtype` can be cast from array dtype. Returns ------- a : Array """ - return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs))) + return Array( + self._sync( + self._async_group._require_array_async( + name, shape=shape, dtype=dtype, exact=exact, **kwargs + ) + ) + ) - def require_array(self, name: str, *, shape: ShapeLike, **kwargs: Any) -> Array: + def require_array( + self, + name: str, + *, + shape: ShapeLike, + dtype: npt.DTypeLike = None, + exact: bool = False, + **kwargs: Any, + ) -> Array: """Obtain an array, creating if it doesn't exist. Other `kwargs` are as per :func:`zarr.Group.create_array`. 
@@ -2562,14 +2629,25 @@ def require_array(self, name: str, *, shape: ShapeLike, **kwargs: Any) -> Array: ---------- name : str Array name. - **kwargs : - See :func:`zarr.Group.create_array`. + shape : int or tuple of ints + Array shape. + dtype : str or dtype, optional + NumPy dtype. If None, the dtype will be inferred from the existing array. + exact : bool, optional + If True, require `dtype` to match exactly. If False, require + `dtype` can be cast from array dtype. Returns ------- a : Array """ - return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs))) + return Array( + self._sync( + self._async_group._require_array_async( + name, shape=shape, dtype=dtype, exact=exact, **kwargs + ) + ) + ) @_deprecate_positional_args def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array: @@ -2918,7 +2996,7 @@ async def create_hierarchy( This function will parse its input to ensure that the hierarchy is complete. Any implicit groups will be inserted as needed. For example, an input like ```{'a/b': GroupMetadata}``` will be parsed to - ```{'': GroupMetadata, 'a': GroupMetadata, 'b': Groupmetadata}``` + ```{'': GroupMetadata, 'a': GroupMetadata, 'a/b': GroupMetadata}```. After input parsing, this function then creates all the nodes in the hierarchy concurrently.