diff --git a/src/array_api_typing/__init__.py b/src/array_api_typing/__init__.py index 3532743..f677768 100644 --- a/src/array_api_typing/__init__.py +++ b/src/array_api_typing/__init__.py @@ -1,10 +1,18 @@ """Static typing support for the array API standard.""" __all__ = ( + "Array", + "ArrayNamespace", + "DType", + "Device", "HasArrayNamespace", "__version__", "__version_tuple__", + "signature_types", ) -from ._namespace import HasArrayNamespace +from . import signature_types +from ._array import Array +from ._misc_objects import Device, DType +from ._namespace import ArrayNamespace, HasArrayNamespace from ._version import version as __version__, version_tuple as __version_tuple__ diff --git a/src/array_api_typing/_array.py b/src/array_api_typing/_array.py new file mode 100644 index 0000000..3677ae6 --- /dev/null +++ b/src/array_api_typing/_array.py @@ -0,0 +1,12 @@ +"""Static typing support for array API arrays.""" + +from typing import Protocol + +from ._namespace import HasArrayNamespace + + +class Array(HasArrayNamespace, Protocol): + """An Array API array of homogenously-typed numbers.""" + + # TODO(https://github.com/data-apis/array-api-typing/issues/23): Populate this + # protocol with methods defined by the Array API specification. diff --git a/src/array_api_typing/_misc_objects.py b/src/array_api_typing/_misc_objects.py new file mode 100644 index 0000000..061acc9 --- /dev/null +++ b/src/array_api_typing/_misc_objects.py @@ -0,0 +1,6 @@ +"""Static typing support for miscellaneous objects in the array API.""" + +from typing import TypeAlias + +Device: TypeAlias = object # The device on which an Array API array is stored. +DType: TypeAlias = object # The type of the numbers contained in an Array API array.""" diff --git a/src/array_api_typing/_namespace.py b/src/array_api_typing/_namespace.py index 98099d1..0f31e61 100644 --- a/src/array_api_typing/_namespace.py +++ b/src/array_api_typing/_namespace.py @@ -1,13 +1,54 @@ __all__ = ("HasArrayNamespace",) -from types import ModuleType -from typing import Protocol, final +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol from typing_extensions import TypeVar -T = TypeVar("T", bound=object, default=ModuleType) # PEP 696 default +if TYPE_CHECKING: + # This condition exists to prevent a circular import: _array imports _namespace for + # HasArrayNamespace. Therefore, _namespace cannot import _array except when + # type-checking. The type variable depends on Array, so we create a dummy type + # variable without the same bounds and default for this case. In Python 3.13, this + # is no longer be necessary. + from typing_extensions import Buffer + + from ._array import Array + from ._misc_objects import Device, DType + from .signature_types import NestedSequence + + A = TypeVar("A", bound=Array, default=Array) # PEP 696 default +else: + A = TypeVar("A") + + +class ArrayNamespace(Protocol[A]): + """An Array API namespace.""" + + def asarray( + self, + obj: Array | complex | NestedSequence[complex] | Buffer, + /, + *, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, + ) -> A: ... + + def astype( + self, + x: A, + dtype: DType, + /, + *, + copy: bool = True, + device: Device | None = None, + ) -> A: ... + + +T = TypeVar("T", bound=ArrayNamespace, default=ArrayNamespace) # PEP 696 default -@final class HasArrayNamespace(Protocol[T]): # type: ignore[misc] # see python/mypy#17288 """Protocol for classes that have an `__array_namespace__` method. diff --git a/src/array_api_typing/signature_types/__init__.py b/src/array_api_typing/signature_types/__init__.py new file mode 100644 index 0000000..52fff5a --- /dev/null +++ b/src/array_api_typing/signature_types/__init__.py @@ -0,0 +1,7 @@ +"""Types that appear in public function signatures.""" + +__all__ = [ + "NestedSequence", +] + +from ._signature_types import NestedSequence diff --git a/src/array_api_typing/signature_types/_signature_types.py b/src/array_api_typing/signature_types/_signature_types.py new file mode 100644 index 0000000..a6c91bc --- /dev/null +++ b/src/array_api_typing/signature_types/_signature_types.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Iterator + +_T_co = TypeVar("_T_co", covariant=True) + + +@runtime_checkable +class NestedSequence(Protocol[_T_co]): + """A protocol for representing nested sequences. + + Warning: + ------- + `NestedSequence` currently does not work in combination with type variables, + *e.g.* ``def func(a: NestedSequnce[T]) -> T: ...``. + + See Also: + -------- + collections.abc.Sequence: + ABCs for read-only and mutable :term:`sequences`. + + Examples: + -------- + .. code-block:: python + + >>> from typing import TYPE_CHECKING + >>> import numpy as np + >>> import array_api_typing as xpt + + >>> def get_dtype(seq: xpt.NestedSequence[float]) -> np.dtype[np.float64]: + ... return np.asarray(seq).dtype + + >>> a = get_dtype([1.0]) + >>> b = get_dtype([[1.0]]) + >>> c = get_dtype([[[1.0]]]) + >>> d = get_dtype([[[[1.0]]]]) + + >>> if TYPE_CHECKING: + ... reveal_locals() + ... # note: Revealed local types are: + ... # note: a: numpy.dtype[numpy.floating[numpy._typing._64Bit]] + ... # note: b: numpy.dtype[numpy.floating[numpy._typing._64Bit]] + ... # note: c: numpy.dtype[numpy.floating[numpy._typing._64Bit]] + ... # note: d: numpy.dtype[numpy.floating[numpy._typing._64Bit]] + + """ + + def __len__(self, /) -> int: + """Implement ``len(self)``.""" + raise NotImplementedError + + def __getitem__(self, index: int, /) -> _T_co | NestedSequence[_T_co]: + """Implement ``self[x]``.""" + raise NotImplementedError + + def __contains__(self, x: object, /) -> bool: + """Implement ``x in self``.""" + raise NotImplementedError + + def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: + """Implement ``iter(self)``.""" + raise NotImplementedError + + def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: + """Implement ``reversed(self)``.""" + raise NotImplementedError + + def count(self, value: object, /) -> int: + """Return the number of occurrences of `value`.""" + raise NotImplementedError + + def index(self, value: object, /) -> int: + """Return the first index of `value`.""" + raise NotImplementedError