@@ -4,8 +4,9 @@ from typing_extensions import Never, Self, TypeVar, deprecated, overload, overri
4
4
5
5
import numpy as np
6
6
from _numtype import Array , ToGeneric_0d , ToGeneric_1nd , ToGeneric_nd
7
- from numpy import _OrderACF , _OrderKACF , amax , amin , bool_ , expand_dims # noqa: ICN003
8
- from numpy ._typing import _BoolCodes
7
+ from numpy import _OrderACF , _OrderKACF , amax , amin , bool_ , expand_dims , intp # noqa: ICN003
8
+ from numpy ._globals import _NoValueType
9
+ from numpy ._typing import _BoolCodes , _ScalarLike_co
9
10
10
11
__all__ = [
11
12
"MAError" ,
@@ -188,6 +189,12 @@ __all__ = [
188
189
"zeros_like" ,
189
190
]
190
191
192
+ ###
193
+
194
+ _ArrayT = TypeVar ("_ArrayT" , bound = np .ndarray [Any , Any ])
195
+
196
+ ###
197
+
191
198
_UFuncT_co = TypeVar ("_UFuncT_co" , bound = np .ufunc , default = np .ufunc , covariant = True )
192
199
_ShapeT = TypeVar ("_ShapeT" , bound = tuple [int , ...])
193
200
_ShapeT_co = TypeVar ("_ShapeT_co" , bound = tuple [int , ...], default = tuple [int , ...], covariant = True )
@@ -650,15 +657,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
650
657
fill_value : Incomplete = ...,
651
658
keepdims : Incomplete = ...,
652
659
) -> Incomplete : ...
653
- @override
654
- def argmin ( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
660
+
661
+ # Keep in-sync with np.ma.argmin
662
+ @overload # type: ignore[override]
663
+ def argmin (
655
664
self ,
656
- axis : Incomplete = ... ,
657
- fill_value : Incomplete = ... ,
658
- out : Incomplete = ... ,
665
+ axis : None = None ,
666
+ fill_value : _ScalarLike_co | None = None ,
667
+ out : None = None ,
659
668
* ,
660
- keepdims : Incomplete = ...,
661
- ) -> Incomplete : ...
669
+ keepdims : L [False ] | _NoValueType = ...,
670
+ ) -> intp : ...
671
+ @overload
672
+ def argmin (
673
+ self ,
674
+ axis : CanIndex | None = None ,
675
+ fill_value : _ScalarLike_co | None = None ,
676
+ out : None = None ,
677
+ * ,
678
+ keepdims : bool | _NoValueType = ...,
679
+ ) -> Any : ...
680
+ @overload
681
+ def argmin (
682
+ self ,
683
+ axis : CanIndex | None = None ,
684
+ fill_value : _ScalarLike_co | None = None ,
685
+ * ,
686
+ out : _ArrayT ,
687
+ keepdims : bool | _NoValueType = ...,
688
+ ) -> _ArrayT : ...
689
+ @overload
690
+ def argmin ( # pyright: ignore[reportIncompatibleMethodOverride]
691
+ self ,
692
+ axis : CanIndex | None ,
693
+ fill_value : _ScalarLike_co | None ,
694
+ out : _ArrayT ,
695
+ * ,
696
+ keepdims : bool | _NoValueType = ...,
697
+ ) -> _ArrayT : ...
662
698
663
699
#
664
700
@override
@@ -669,15 +705,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
669
705
fill_value : Incomplete = ...,
670
706
keepdims : Incomplete = ...,
671
707
) -> Incomplete : ...
672
- @override
673
- def argmax ( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
708
+
709
+ # Keep in-sync with np.ma.argmax
710
+ @overload # type: ignore[override]
711
+ def argmax (
674
712
self ,
675
- axis : Incomplete = ... ,
676
- fill_value : Incomplete = ... ,
677
- out : Incomplete = ... ,
713
+ axis : None = None ,
714
+ fill_value : _ScalarLike_co | None = None ,
715
+ out : None = None ,
678
716
* ,
679
- keepdims : Incomplete = ...,
680
- ) -> Incomplete : ...
717
+ keepdims : L [False ] | _NoValueType = ...,
718
+ ) -> intp : ...
719
+ @overload
720
+ def argmax (
721
+ self ,
722
+ axis : CanIndex | None = None ,
723
+ fill_value : _ScalarLike_co | None = None ,
724
+ out : None = None ,
725
+ * ,
726
+ keepdims : bool | _NoValueType = ...,
727
+ ) -> Any : ...
728
+ @overload
729
+ def argmax (
730
+ self ,
731
+ axis : CanIndex | None = None ,
732
+ fill_value : _ScalarLike_co | None = None ,
733
+ * ,
734
+ out : _ArrayT ,
735
+ keepdims : bool | _NoValueType = ...,
736
+ ) -> _ArrayT : ...
737
+ @overload
738
+ def argmax ( # pyright: ignore[reportIncompatibleMethodOverride]
739
+ self ,
740
+ axis : CanIndex | None ,
741
+ fill_value : _ScalarLike_co | None ,
742
+ out : _ArrayT ,
743
+ * ,
744
+ keepdims : bool | _NoValueType = ...,
745
+ ) -> _ArrayT : ...
681
746
682
747
#
683
748
@override
@@ -1066,8 +1131,81 @@ swapaxes: _frommethod
1066
1131
trace : _frommethod
1067
1132
var : _frommethod
1068
1133
count : _frommethod
1069
- argmin : _frommethod
1070
- argmax : _frommethod
1071
-
1072
1134
minimum : _extrema_operation
1073
1135
maximum : _extrema_operation
1136
+
1137
+ #
1138
+ @overload
1139
+ def argmin (
1140
+ a : _ArrayT , # pyright: ignore[reportInvalidTypeVarUse]
1141
+ axis : None = None ,
1142
+ fill_value : _ScalarLike_co | None = None ,
1143
+ out : None = None ,
1144
+ * ,
1145
+ keepdims : L [False ] | _NoValueType = ...,
1146
+ ) -> intp : ...
1147
+ @overload
1148
+ def argmin (
1149
+ a : _ArrayT , # pyright: ignore[reportInvalidTypeVarUse]
1150
+ axis : CanIndex | None = None ,
1151
+ fill_value : _ScalarLike_co | None = None ,
1152
+ out : None = None ,
1153
+ * ,
1154
+ keepdims : bool | _NoValueType = ...,
1155
+ ) -> Any : ...
1156
+ @overload
1157
+ def argmin (
1158
+ a : _ArrayT ,
1159
+ axis : CanIndex | None = None ,
1160
+ fill_value : _ScalarLike_co | None = None ,
1161
+ * ,
1162
+ out : _ArrayT ,
1163
+ keepdims : bool | _NoValueType = ...,
1164
+ ) -> _ArrayT : ...
1165
+ @overload
1166
+ def argmin (
1167
+ a : _ArrayT ,
1168
+ axis : CanIndex | None ,
1169
+ fill_value : _ScalarLike_co | None ,
1170
+ out : _ArrayT ,
1171
+ * ,
1172
+ keepdims : bool | _NoValueType = ...,
1173
+ ) -> _ArrayT : ...
1174
+
1175
+ #
1176
+ @overload
1177
+ def argmax (
1178
+ a : _ArrayT , # pyright: ignore[reportInvalidTypeVarUse]
1179
+ axis : None = None ,
1180
+ fill_value : _ScalarLike_co | None = None ,
1181
+ out : None = None ,
1182
+ * ,
1183
+ keepdims : L [False ] | _NoValueType = ...,
1184
+ ) -> intp : ...
1185
+ @overload
1186
+ def argmax (
1187
+ a : _ArrayT , # pyright: ignore[reportInvalidTypeVarUse]
1188
+ axis : CanIndex | None = None ,
1189
+ fill_value : _ScalarLike_co | None = None ,
1190
+ out : None = None ,
1191
+ * ,
1192
+ keepdims : bool | _NoValueType = ...,
1193
+ ) -> Any : ...
1194
+ @overload
1195
+ def argmax (
1196
+ a : _ArrayT ,
1197
+ axis : CanIndex | None = None ,
1198
+ fill_value : _ScalarLike_co | None = None ,
1199
+ * ,
1200
+ out : _ArrayT ,
1201
+ keepdims : bool | _NoValueType = ...,
1202
+ ) -> _ArrayT : ...
1203
+ @overload
1204
+ def argmax (
1205
+ a : _ArrayT ,
1206
+ axis : CanIndex | None ,
1207
+ fill_value : _ScalarLike_co | None ,
1208
+ out : _ArrayT ,
1209
+ * ,
1210
+ keepdims : bool | _NoValueType = ...,
1211
+ ) -> _ArrayT : ...
0 commit comments