Skip to content

Commit 54319e0

Browse files
committed
🚚 port ma.arg{min,max} and MaskedArray.arg{min,max}
1 parent b37e58b commit 54319e0

File tree

3 files changed

+231
-19
lines changed

3 files changed

+231
-19
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import math
2+
from typing import Any, TypeAlias, TypeVar
3+
from typing_extensions import assert_type
4+
5+
import numpy as np
6+
from numpy._typing import _Shape
7+
8+
_ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True)
9+
MaskedNDArray: TypeAlias = np.ma.MaskedArray[_Shape, np.dtype[_ScalarType_co]]
10+
11+
class MaskedNDArraySubclass(MaskedNDArray[np.complex128]): ...
12+
13+
MAR_b: MaskedNDArray[np.bool]
14+
MAR_f4: MaskedNDArray[np.float32]
15+
MAR_i8: MaskedNDArray[np.int64]
16+
MAR_subclass: MaskedNDArraySubclass
17+
MAR_1d: np.ma.MaskedArray[tuple[int], np.dtype[Any]]
18+
19+
assert_type(MAR_b.argmin(), np.intp)
20+
assert_type(MAR_f4.argmin(), np.intp)
21+
assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp)
22+
assert_type(MAR_b.argmin(axis=0), Any)
23+
assert_type(MAR_f4.argmin(axis=0), Any)
24+
assert_type(MAR_b.argmin(keepdims=True), Any)
25+
assert_type(MAR_f4.argmin(out=MAR_subclass), MaskedNDArraySubclass)
26+
assert_type(MAR_f4.argmin(None, None, out=MAR_subclass), MaskedNDArraySubclass)
27+
28+
assert_type(np.ma.argmin(MAR_b), np.intp)
29+
assert_type(np.ma.argmin(MAR_f4), np.intp)
30+
assert_type(np.ma.argmin(MAR_f4, fill_value=math.tau, keepdims=False), np.intp)
31+
assert_type(np.ma.argmin(MAR_b, axis=0), Any)
32+
assert_type(np.ma.argmin(MAR_f4, axis=0), Any)
33+
assert_type(np.ma.argmin(MAR_b, keepdims=True), Any)
34+
35+
assert_type(MAR_b.argmax(), np.intp)
36+
assert_type(MAR_f4.argmax(), np.intp)
37+
assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp)
38+
assert_type(MAR_b.argmax(axis=0), Any)
39+
assert_type(MAR_f4.argmax(axis=0), Any)
40+
assert_type(MAR_b.argmax(keepdims=True), Any)
41+
assert_type(MAR_f4.argmax(out=MAR_subclass), MaskedNDArraySubclass)
42+
assert_type(MAR_f4.argmax(None, None, out=MAR_subclass), MaskedNDArraySubclass)
43+
44+
assert_type(np.ma.argmax(MAR_b), np.intp)
45+
assert_type(np.ma.argmax(MAR_f4), np.intp)
46+
assert_type(np.ma.argmax(MAR_f4, fill_value=math.tau, keepdims=False), np.intp)
47+
assert_type(np.ma.argmax(MAR_b, axis=0), Any)
48+
assert_type(np.ma.argmax(MAR_f4, axis=0), Any)
49+
assert_type(np.ma.argmax(MAR_b, keepdims=True), Any)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
3+
m: np.ma.MaskedArray[tuple[int], np.dtype[np.float64]]
4+
5+
m.argmin(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
6+
m.argmin(keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
7+
m.argmin(out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
8+
m.argmin(fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]
9+
10+
np.ma.argmin(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
11+
np.ma.argmin(m, axis=(1,)) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
12+
np.ma.argmin(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
13+
np.ma.argmin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
14+
np.ma.argmin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]
15+
16+
m.argmax(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
17+
m.argmax(keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
18+
m.argmax(out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
19+
m.argmax(fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]
20+
21+
np.ma.argmax(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
22+
np.ma.argmax(m, axis=(0,)) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
23+
np.ma.argmax(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
24+
np.ma.argmax(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
25+
np.ma.argmax(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]

‎src/numpy-stubs/ma/core.pyi

Lines changed: 157 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ from typing_extensions import Never, Self, TypeVar, deprecated, overload, overri
44

55
import numpy as np
66
from _numtype import Array, ToGeneric_0d, ToGeneric_1nd, ToGeneric_nd
7-
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims # noqa: ICN003
8-
from numpy._typing import _BoolCodes
7+
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims, intp # noqa: ICN003
8+
from numpy._globals import _NoValueType
9+
from numpy._typing import _BoolCodes, _ScalarLike_co
910

1011
__all__ = [
1112
"MAError",
@@ -188,6 +189,12 @@ __all__ = [
188189
"zeros_like",
189190
]
190191

192+
###
193+
194+
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
195+
196+
###
197+
191198
_UFuncT_co = TypeVar("_UFuncT_co", bound=np.ufunc, default=np.ufunc, covariant=True)
192199
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
193200
_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], default=tuple[int, ...], covariant=True)
@@ -650,15 +657,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
650657
fill_value: Incomplete = ...,
651658
keepdims: Incomplete = ...,
652659
) -> Incomplete: ...
653-
@override
654-
def argmin( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
660+
661+
# Keep in-sync with np.ma.argmin
662+
@overload # type: ignore[override]
663+
def argmin(
655664
self,
656-
axis: Incomplete = ...,
657-
fill_value: Incomplete = ...,
658-
out: Incomplete = ...,
665+
axis: None = None,
666+
fill_value: _ScalarLike_co | None = None,
667+
out: None = None,
659668
*,
660-
keepdims: Incomplete = ...,
661-
) -> Incomplete: ...
669+
keepdims: L[False] | _NoValueType = ...,
670+
) -> intp: ...
671+
@overload
672+
def argmin(
673+
self,
674+
axis: CanIndex | None = None,
675+
fill_value: _ScalarLike_co | None = None,
676+
out: None = None,
677+
*,
678+
keepdims: bool | _NoValueType = ...,
679+
) -> Any: ...
680+
@overload
681+
def argmin(
682+
self,
683+
axis: CanIndex | None = None,
684+
fill_value: _ScalarLike_co | None = None,
685+
*,
686+
out: _ArrayT,
687+
keepdims: bool | _NoValueType = ...,
688+
) -> _ArrayT: ...
689+
@overload
690+
def argmin( # pyright: ignore[reportIncompatibleMethodOverride]
691+
self,
692+
axis: CanIndex | None,
693+
fill_value: _ScalarLike_co | None,
694+
out: _ArrayT,
695+
*,
696+
keepdims: bool | _NoValueType = ...,
697+
) -> _ArrayT: ...
662698

663699
#
664700
@override
@@ -669,15 +705,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
669705
fill_value: Incomplete = ...,
670706
keepdims: Incomplete = ...,
671707
) -> Incomplete: ...
672-
@override
673-
def argmax( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
708+
709+
# Keep in-sync with np.ma.argmax
710+
@overload # type: ignore[override]
711+
def argmax(
674712
self,
675-
axis: Incomplete = ...,
676-
fill_value: Incomplete = ...,
677-
out: Incomplete = ...,
713+
axis: None = None,
714+
fill_value: _ScalarLike_co | None = None,
715+
out: None = None,
678716
*,
679-
keepdims: Incomplete = ...,
680-
) -> Incomplete: ...
717+
keepdims: L[False] | _NoValueType = ...,
718+
) -> intp: ...
719+
@overload
720+
def argmax(
721+
self,
722+
axis: CanIndex | None = None,
723+
fill_value: _ScalarLike_co | None = None,
724+
out: None = None,
725+
*,
726+
keepdims: bool | _NoValueType = ...,
727+
) -> Any: ...
728+
@overload
729+
def argmax(
730+
self,
731+
axis: CanIndex | None = None,
732+
fill_value: _ScalarLike_co | None = None,
733+
*,
734+
out: _ArrayT,
735+
keepdims: bool | _NoValueType = ...,
736+
) -> _ArrayT: ...
737+
@overload
738+
def argmax( # pyright: ignore[reportIncompatibleMethodOverride]
739+
self,
740+
axis: CanIndex | None,
741+
fill_value: _ScalarLike_co | None,
742+
out: _ArrayT,
743+
*,
744+
keepdims: bool | _NoValueType = ...,
745+
) -> _ArrayT: ...
681746

682747
#
683748
@override
@@ -1066,8 +1131,81 @@ swapaxes: _frommethod
10661131
trace: _frommethod
10671132
var: _frommethod
10681133
count: _frommethod
1069-
argmin: _frommethod
1070-
argmax: _frommethod
1071-
10721134
minimum: _extrema_operation
10731135
maximum: _extrema_operation
1136+
1137+
#
1138+
@overload
1139+
def argmin(
1140+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1141+
axis: None = None,
1142+
fill_value: _ScalarLike_co | None = None,
1143+
out: None = None,
1144+
*,
1145+
keepdims: L[False] | _NoValueType = ...,
1146+
) -> intp: ...
1147+
@overload
1148+
def argmin(
1149+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1150+
axis: CanIndex | None = None,
1151+
fill_value: _ScalarLike_co | None = None,
1152+
out: None = None,
1153+
*,
1154+
keepdims: bool | _NoValueType = ...,
1155+
) -> Any: ...
1156+
@overload
1157+
def argmin(
1158+
a: _ArrayT,
1159+
axis: CanIndex | None = None,
1160+
fill_value: _ScalarLike_co | None = None,
1161+
*,
1162+
out: _ArrayT,
1163+
keepdims: bool | _NoValueType = ...,
1164+
) -> _ArrayT: ...
1165+
@overload
1166+
def argmin(
1167+
a: _ArrayT,
1168+
axis: CanIndex | None,
1169+
fill_value: _ScalarLike_co | None,
1170+
out: _ArrayT,
1171+
*,
1172+
keepdims: bool | _NoValueType = ...,
1173+
) -> _ArrayT: ...
1174+
1175+
#
1176+
@overload
1177+
def argmax(
1178+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1179+
axis: None = None,
1180+
fill_value: _ScalarLike_co | None = None,
1181+
out: None = None,
1182+
*,
1183+
keepdims: L[False] | _NoValueType = ...,
1184+
) -> intp: ...
1185+
@overload
1186+
def argmax(
1187+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1188+
axis: CanIndex | None = None,
1189+
fill_value: _ScalarLike_co | None = None,
1190+
out: None = None,
1191+
*,
1192+
keepdims: bool | _NoValueType = ...,
1193+
) -> Any: ...
1194+
@overload
1195+
def argmax(
1196+
a: _ArrayT,
1197+
axis: CanIndex | None = None,
1198+
fill_value: _ScalarLike_co | None = None,
1199+
*,
1200+
out: _ArrayT,
1201+
keepdims: bool | _NoValueType = ...,
1202+
) -> _ArrayT: ...
1203+
@overload
1204+
def argmax(
1205+
a: _ArrayT,
1206+
axis: CanIndex | None,
1207+
fill_value: _ScalarLike_co | None,
1208+
out: _ArrayT,
1209+
*,
1210+
keepdims: bool | _NoValueType = ...,
1211+
) -> _ArrayT: ...

0 commit comments

Comments
 (0)