Skip to content

Commit 305e5a4

Browse files
committed
review feedback
1 parent 11c9c22 commit 305e5a4

2 files changed

Lines changed: 109 additions & 79 deletions

File tree

python/pyarrow-stubs/pyarrow/_stubs_typing.pyi

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import datetime as dt
1919

20-
from collections.abc import Collection, Iterator, Sequence
20+
from collections.abc import Collection, Container, Iterator, Sequence, Sized
2121
from decimal import Decimal
2222
from typing import Any, Literal, Protocol, TypeAlias, TypeVar
2323

@@ -90,49 +90,65 @@ _V = TypeVar("_V", covariant=True)
9090

9191
SingleOrList: TypeAlias = list[_T] | _T
9292

93-
class SupportEq(Protocol):
93+
94+
class SupportsEq(Protocol):
9495
def __eq__(self, other: object, /) -> bool: ...
9596

96-
class SupportLt(Protocol):
97+
98+
class SupportsLt(Protocol):
9799
def __lt__(self, other: object, /) -> bool: ...
98100

99-
class SupportGt(Protocol):
101+
102+
class SupportsGt(Protocol):
100103
def __gt__(self, other: object, /) -> bool: ...
101104

102-
class SupportLe(Protocol):
105+
106+
class SupportsLe(Protocol):
103107
def __le__(self, other: object, /) -> bool: ...
104108

105-
class SupportGe(Protocol):
109+
110+
class SupportsGe(Protocol):
106111
def __ge__(self, other: object, /) -> bool: ...
107112

113+
108114
FilterTuple: TypeAlias = (
109-
tuple[str, Literal["=", "==", "!="], SupportEq]
110-
| tuple[str, Literal["<"], SupportLt]
111-
| tuple[str, Literal[">"], SupportGt]
112-
| tuple[str, Literal["<="], SupportLe]
113-
| tuple[str, Literal[">="], SupportGe]
115+
tuple[str, Literal["=", "==", "!="], SupportsEq]
116+
| tuple[str, Literal["<"], SupportsLt]
117+
| tuple[str, Literal[">"], SupportsGt]
118+
| tuple[str, Literal["<="], SupportsLe]
119+
| tuple[str, Literal[">="], SupportsGe]
114120
| tuple[str, Literal["in", "not in"], Collection]
115121
| tuple[str, str, Any] # Allow general str for operator to avoid type errors
116122
)
117123

124+
118125
class Buffer(Protocol): ...
119-
class SupportPyBuffer(Protocol): ...
120126

121-
class SupportArrowStream(Protocol):
127+
128+
class SupportsPyBuffer(Protocol): ...
129+
130+
131+
class SupportsArrowStream(Protocol):
122132
def __arrow_c_stream__(self, requested_schema=None, /) -> Any: ...
123133

124-
class SupportPyArrowArray(Protocol):
134+
135+
class SupportsPyArrowArray(Protocol):
125136
def __arrow_array__(self, type=None, /) -> Any: ...
126137

127-
class SupportArrowArray(Protocol):
138+
139+
class SupportsArrowArray(Protocol):
128140
def __arrow_c_array__(self, requested_schema=None, /) -> Any: ...
129141

130-
class SupportArrowDeviceArray(Protocol):
142+
143+
class SupportsArrowDeviceArray(Protocol):
131144
def __arrow_c_device_array__(self, requested_schema=None, /, **kwargs) -> Any: ...
132145

133-
class SupportArrowSchema(Protocol):
146+
147+
class SupportsArrowSchema(Protocol):
134148
def __arrow_c_schema__(self) -> Any: ...
135149

136-
from collections.abc import Container, Sized
137-
class NullableCollection(Sized, Container[_V], Protocol[_V]):
138-
def __iter__(self) -> Iterator[_V] | Iterator[_V | None]: ...
150+
151+
class NullableCollection(Sized, Container[_V], Protocol[_V]):
152+
def __iter__(self) -> Iterator[_V] | Iterator[_V | None]: ...
153+
def __len__(self) -> int: ...
154+
def __contains__(self, item: Any, /) -> bool: ...

python/pyarrow-stubs/pyarrow/_types.pyi

Lines changed: 73 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@
1717

1818
import datetime as dt # noqa: F401
1919

20-
from collections.abc import Mapping, Sequence, Iterable, Iterator
20+
from collections.abc import Iterable, Iterator, Mapping, Sequence
2121
from decimal import Decimal # noqa: F401
22-
from typing import Any, Generic, Literal
22+
from typing import Any, Generic, Literal, Protocol, TypeAlias
2323

2424
import numpy as np
2525
import pandas as pd
2626

2727
from typing_extensions import Self, TypeVar, deprecated
2828

29-
from pyarrow._stubs_typing import SupportArrowSchema, TimeUnit
29+
from pyarrow._stubs_typing import SupportsArrowSchema, TimeUnit
3030
from pyarrow.io import Buffer
3131
from pyarrow.lib import ( # noqa: F401
3232
Array,
@@ -72,8 +72,8 @@ class DataType(_Weakrefable):
7272
def _import_from_c_capsule(cls, schema) -> Self: ...
7373

7474
_AsPyType = TypeVar("_AsPyType")
75-
_DataTypeT = TypeVar("_DataTypeT", bound=DataType)
76-
_DataTypeT_co = TypeVar("_DataTypeT", bound=DataType, covariant=True)
75+
_DataTypeT = TypeVar("_DataTypeT", bound=DataType)
76+
_DataTypeT_co = TypeVar("_DataTypeT_co", bound=DataType, covariant=True)
7777

7878
class _BasicDataType(DataType, Generic[_AsPyType]): ...
7979
class NullType(_BasicDataType[None]): ...
@@ -128,54 +128,60 @@ _FixedSizeBinaryAsPyType = TypeVar("_FixedSizeBinaryAsPyType", default=bytes)
128128

129129
class FixedSizeBinaryType(_BasicDataType[_FixedSizeBinaryAsPyType]): ...
130130

131-
from typing import Protocol
132-
133-
_Precision = TypeVar("_Precision", default=Any , covariant=True)
134-
_Scale = TypeVar("_Scale", default=Any , covariant=True)
135-
136-
class _HasPrecisionScale(Protocol[_Precision, _Scale]):
137-
@property
138-
def precision(self) -> _Precision: ...
139-
@property
140-
def scale(self) -> _Scale: ...
141-
142-
class Decimal32Type(FixedSizeBinaryType[Decimal], _HasPrecisionScale[_Precision, _Scale]): ...
143-
144-
class Decimal64Type(FixedSizeBinaryType[Decimal], _HasPrecisionScale[_Precision, _Scale]): ...
145-
146-
class Decimal128Type(FixedSizeBinaryType[Decimal], _HasPrecisionScale[_Precision, _Scale]): ...
147-
148-
class Decimal256Type(FixedSizeBinaryType[Decimal], _HasPrecisionScale[_Precision, _Scale]): ...
131+
_Precision = TypeVar("_Precision", default=Any, covariant=True)
132+
_Scale = TypeVar("_Scale", default=Any, covariant=True)
149133

150-
class ListType(DataType, Generic[_DataTypeT]):
134+
class _HasPrecisionScale(Protocol[_Precision, _Scale]):
151135
@property
152-
def value_field(self) -> Field[_DataTypeT]: ...
136+
def precision(self) -> _Precision: ...
153137
@property
154-
def value_type(self) -> _DataTypeT: ...
138+
def scale(self) -> _Scale: ...
155139

156-
class LargeListType(DataType, Generic[_DataTypeT]):
140+
class Decimal32Type(
141+
FixedSizeBinaryType[Decimal], _HasPrecisionScale[_Precision, _Scale]
142+
): ...
143+
144+
class Decimal64Type(
145+
FixedSizeBinaryType[Decimal], _HasPrecisionScale[_Precision, _Scale]
146+
): ...
147+
148+
class Decimal128Type(
149+
FixedSizeBinaryType[Decimal], _HasPrecisionScale[_Precision, _Scale]
150+
): ...
151+
152+
class Decimal256Type(
153+
FixedSizeBinaryType[Decimal], _HasPrecisionScale[_Precision, _Scale]
154+
): ...
155+
156+
class ListType(DataType, Generic[_DataTypeT_co]):
157157
@property
158-
def value_field(self) -> Field[_DataTypeT]: ...
158+
def value_field(self) -> Field[_DataTypeT_co]: ...
159159
@property
160-
def value_type(self) -> _DataTypeT: ...
160+
def value_type(self) -> _DataTypeT_co: ...
161161

162-
class ListViewType(DataType, Generic[_DataTypeT]):
162+
class LargeListType(DataType, Generic[_DataTypeT_co]):
163163
@property
164-
def value_field(self) -> Field[_DataTypeT]: ...
164+
def value_field(self) -> Field[_DataTypeT_co]: ...
165165
@property
166-
def value_type(self) -> _DataTypeT: ...
166+
def value_type(self) -> _DataTypeT_co: ...
167167

168-
class LargeListViewType(DataType, Generic[_DataTypeT]):
168+
class ListViewType(DataType, Generic[_DataTypeT_co]):
169169
@property
170-
def value_field(self) -> Field[_DataTypeT]: ...
170+
def value_field(self) -> Field[_DataTypeT_co]: ...
171171
@property
172-
def value_type(self) -> _DataTypeT: ...
172+
def value_type(self) -> _DataTypeT_co: ...
173173

174-
class FixedSizeListType(DataType, Generic[_DataTypeT, _Size]):
174+
class LargeListViewType(DataType, Generic[_DataTypeT_co]):
175175
@property
176-
def value_field(self) -> Field[_DataTypeT]: ...
176+
def value_field(self) -> Field[_DataTypeT_co]: ...
177177
@property
178-
def value_type(self) -> _DataTypeT: ...
178+
def value_type(self) -> _DataTypeT_co: ...
179+
180+
class FixedSizeListType(DataType, Generic[_DataTypeT_co, _Size]):
181+
@property
182+
def value_field(self) -> Field[_DataTypeT_co]: ...
183+
@property
184+
def value_type(self) -> _DataTypeT_co: ...
179185
@property
180186
def list_size(self) -> int: ...
181187

@@ -305,13 +311,22 @@ class UnknownExtensionType(ExtensionType):
305311
def register_extension_type(ext_type: ExtensionType) -> None: ...
306312
def unregister_extension_type(type_name: str) -> None: ...
307313

314+
_StrOrBytes: TypeAlias = str | bytes
315+
_MetadataMapping: TypeAlias = Mapping[_StrOrBytes, _StrOrBytes]
316+
_MetadataIterable: TypeAlias = Iterable[tuple[_StrOrBytes, _StrOrBytes]]
317+
_KeyValueMetadataInput: TypeAlias = _MetadataMapping | _MetadataIterable | None
318+
_FieldTypeInput: TypeAlias = DataType | str | None
319+
_SchemaMetadataInput: TypeAlias = (
320+
Mapping[bytes, bytes]
321+
| Mapping[str, str]
322+
| Mapping[bytes, str]
323+
| Mapping[str, bytes]
324+
)
325+
308326
class KeyValueMetadata(_Metadata, Mapping[bytes, bytes]):
309327
def __init__(
310328
self,
311-
__arg0__: Mapping[str | bytes, str | bytes]
312-
| Iterable[tuple[str | bytes, str | bytes]]
313-
| KeyValueMetadata
314-
| None = None,
329+
__arg0__: _KeyValueMetadataInput | KeyValueMetadata = None,
315330
**kwargs: str,
316331
) -> None: ...
317332
def equals(self, other: KeyValueMetadata) -> bool: ...
@@ -322,7 +337,7 @@ class KeyValueMetadata(_Metadata, Mapping[bytes, bytes]):
322337
def get_all(self, key: str) -> list[bytes]: ...
323338
def to_dict(self) -> dict[bytes, bytes]: ...
324339

325-
class Field(_Weakrefable, Generic[_DataTypeT]):
340+
class Field(_Weakrefable, Generic[_DataTypeT_co]):
326341
def equals(self, other: Field, check_metadata: bool = False) -> bool: ...
327342
def __hash__(self) -> int: ...
328343
@property
@@ -332,17 +347,15 @@ class Field(_Weakrefable, Generic[_DataTypeT]):
332347
@property
333348
def metadata(self) -> dict[bytes, bytes] | None: ...
334349
@property
335-
def type(self) -> _DataTypeT: ...
350+
def type(self) -> _DataTypeT_co: ...
336351
def with_metadata(
337352
self,
338-
metadata: dict[bytes | str, bytes | str]
339-
| Mapping[bytes | str, bytes | str]
340-
| Any,
353+
metadata: _MetadataMapping | Any,
341354
) -> Self: ...
342355
def remove_metadata(self) -> Self: ...
343356
def with_type(self, new_type: DataType) -> Field: ...
344357
def with_name(self, name: str) -> Self: ...
345-
def with_nullable(self, nullable: bool) -> Field[_DataTypeT]: ...
358+
def with_nullable(self, nullable: bool) -> Field[_DataTypeT_co]: ...
346359
def flatten(self) -> list[Field]: ...
347360
def _export_to_c(self, out_ptr: int) -> None: ...
348361
@classmethod
@@ -351,6 +364,14 @@ class Field(_Weakrefable, Generic[_DataTypeT]):
351364
@classmethod
352365
def _import_from_c_capsule(cls, schema) -> Self: ...
353366

367+
_StructFieldTuple: TypeAlias = (
368+
tuple[str, Field[Any] | None] | tuple[str, _FieldTypeInput]
369+
)
370+
_StructFieldsInput: TypeAlias = (
371+
Iterable[Field[Any] | _StructFieldTuple]
372+
| Mapping[str, Field[Any] | DataType | str | None]
373+
)
374+
354375
class Schema(_Weakrefable):
355376
def __len__(self) -> int: ...
356377
def __getitem__(self, key: str | int) -> Field: ...
@@ -407,10 +428,10 @@ def unify_schemas(
407428
promote_options: Literal["default", "permissive"] = "default",
408429
) -> Schema: ...
409430
def field(
410-
name: SupportArrowSchema | str | Any,
431+
name: SupportsArrowSchema | str | bytes,
411432
type: _DataTypeT | str | None = None,
412433
nullable: bool = True,
413-
metadata: dict[Any, Any] | None = None,
434+
metadata: _MetadataMapping | None = None,
414435
) -> Field[_DataTypeT] | Field[Any]: ...
415436
def null() -> NullType: ...
416437
def bool_() -> BoolType: ...
@@ -484,10 +505,7 @@ def dictionary(
484505
ordered: _Ordered | None = None,
485506
) -> DictionaryType[_IndexT, _BasicValueT, _Ordered]: ...
486507
def struct(
487-
fields: Iterable[
488-
Field[Any] | tuple[str, Field[Any] | None] | tuple[str, DataType | None]
489-
]
490-
| Mapping[str, Field[Any] | DataType | None],
508+
fields: _StructFieldsInput,
491509
) -> StructType: ...
492510
def sparse_union(
493511
child_fields: list[Field[Any]], type_codes: list[int] | None = None
@@ -520,11 +538,7 @@ def schema(
520538
| Iterable[tuple[str, DataType | str | None]]
521539
| Mapping[Any, DataType | str | None]
522540
),
523-
metadata: Mapping[bytes, bytes]
524-
| Mapping[str, str]
525-
| Mapping[bytes, str]
526-
| Mapping[str, bytes]
527-
| None = None,
541+
metadata: _SchemaMetadataInput | None = None,
528542
) -> Schema: ...
529543
def from_numpy_dtype(dtype: np.dtype[Any] | type | str) -> DataType: ...
530544

0 commit comments

Comments
 (0)