Skip to content

Commit e2ca0e9

Browse files
committed
Revert "Infer correct types with overloads of Type[Guard | Is] (#17678)"
This reverts commit 43ea203. The commit caused a regression (#19139). If we can't fix the regression soon enough, reverting the original change temporarily will at least unblock the mypy public release. The reverted PR can be merged again once the regression is fixed.
1 parent 537fc55 commit e2ca0e9

File tree

4 files changed

+14
-268
lines changed

4 files changed

+14
-268
lines changed

mypy/checker.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6160,31 +6160,15 @@ def find_isinstance_check_helper(
61606160
# considered "always right" (i.e. even if the types are not overlapping).
61616161
# Also note that a care must be taken to unwrap this back at read places
61626162
# where we use this to narrow down declared type.
6163-
with self.msg.filter_errors(), self.local_type_map():
6164-
# `node.callee` can be an `overload`ed function,
6165-
# we need to resolve the real `overload` case.
6166-
_, real_func = self.expr_checker.check_call(
6167-
get_proper_type(self.lookup_type(node.callee)),
6168-
node.args,
6169-
node.arg_kinds,
6170-
node,
6171-
node.arg_names,
6172-
)
6173-
real_func = get_proper_type(real_func)
6174-
if not isinstance(real_func, CallableType) or not (
6175-
real_func.type_guard or real_func.type_is
6176-
):
6177-
return {}, {}
6178-
6179-
if real_func.type_guard is not None:
6180-
return {expr: TypeGuardedType(real_func.type_guard)}, {}
6163+
if node.callee.type_guard is not None:
6164+
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
61816165
else:
6182-
assert real_func.type_is is not None
6166+
assert node.callee.type_is is not None
61836167
return conditional_types_to_typemaps(
61846168
expr,
61856169
*self.conditional_types_with_intersection(
61866170
self.lookup_type(expr),
6187-
[TypeRange(real_func.type_is, is_upper_bound=False)],
6171+
[TypeRange(node.callee.type_is, is_upper_bound=False)],
61886172
expr,
61896173
),
61906174
)

mypy/checkexpr.py

Lines changed: 10 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2925,37 +2925,16 @@ def infer_overload_return_type(
29252925
elif all_same_types([erase_type(typ) for typ in return_types]):
29262926
self.chk.store_types(type_maps[0])
29272927
return erase_type(return_types[0]), erase_type(inferred_types[0])
2928-
return self.check_call(
2929-
callee=AnyType(TypeOfAny.special_form),
2930-
args=args,
2931-
arg_kinds=arg_kinds,
2932-
arg_names=arg_names,
2933-
context=context,
2934-
callable_name=callable_name,
2935-
object_type=object_type,
2936-
)
2937-
elif not all_same_type_narrowers(matches):
2938-
# This is an example of how overloads can be:
2939-
#
2940-
# @overload
2941-
# def is_int(obj: float) -> TypeGuard[float]: ...
2942-
# @overload
2943-
# def is_int(obj: int) -> TypeGuard[int]: ...
2944-
#
2945-
# x: Any
2946-
# if is_int(x):
2947-
# reveal_type(x) # N: int | float
2948-
#
2949-
# So, we need to check that special case.
2950-
return self.check_call(
2951-
callee=self.combine_function_signatures(cast("list[ProperType]", matches)),
2952-
args=args,
2953-
arg_kinds=arg_kinds,
2954-
arg_names=arg_names,
2955-
context=context,
2956-
callable_name=callable_name,
2957-
object_type=object_type,
2958-
)
2928+
else:
2929+
return self.check_call(
2930+
callee=AnyType(TypeOfAny.special_form),
2931+
args=args,
2932+
arg_kinds=arg_kinds,
2933+
arg_names=arg_names,
2934+
context=context,
2935+
callable_name=callable_name,
2936+
object_type=object_type,
2937+
)
29592938
else:
29602939
# Success! No ambiguity; return the first match.
29612940
self.chk.store_types(type_maps[0])
@@ -3170,8 +3149,6 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
31703149
new_args: list[list[Type]] = [[] for _ in range(len(callables[0].arg_types))]
31713150
new_kinds = list(callables[0].arg_kinds)
31723151
new_returns: list[Type] = []
3173-
new_type_guards: list[Type] = []
3174-
new_type_narrowers: list[Type] = []
31753152

31763153
too_complex = False
31773154
for target in callables:
@@ -3198,25 +3175,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
31983175
for i, arg in enumerate(target.arg_types):
31993176
new_args[i].append(arg)
32003177
new_returns.append(target.ret_type)
3201-
if target.type_guard:
3202-
new_type_guards.append(target.type_guard)
3203-
if target.type_is:
3204-
new_type_narrowers.append(target.type_is)
3205-
3206-
if new_type_guards and new_type_narrowers:
3207-
# They cannot be defined at the same time,
3208-
# declaring this function as too complex!
3209-
too_complex = True
3210-
union_type_guard = None
3211-
union_type_is = None
3212-
else:
3213-
union_type_guard = make_simplified_union(new_type_guards) if new_type_guards else None
3214-
union_type_is = (
3215-
make_simplified_union(new_type_narrowers) if new_type_narrowers else None
3216-
)
32173178

32183179
union_return = make_simplified_union(new_returns)
3219-
32203180
if too_complex:
32213181
any = AnyType(TypeOfAny.special_form)
32223182
return callables[0].copy_modified(
@@ -3226,8 +3186,6 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
32263186
ret_type=union_return,
32273187
variables=variables,
32283188
implicit=True,
3229-
type_guard=union_type_guard,
3230-
type_is=union_type_is,
32313189
)
32323190

32333191
final_args = []
@@ -3241,8 +3199,6 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
32413199
ret_type=union_return,
32423200
variables=variables,
32433201
implicit=True,
3244-
type_guard=union_type_guard,
3245-
type_is=union_type_is,
32463202
)
32473203

32483204
def erased_signature_similarity(
@@ -6599,25 +6555,6 @@ def all_same_types(types: list[Type]) -> bool:
65996555
return all(is_same_type(t, types[0]) for t in types[1:])
66006556

66016557

6602-
def all_same_type_narrowers(types: list[CallableType]) -> bool:
6603-
if len(types) <= 1:
6604-
return True
6605-
6606-
type_guards: list[Type] = []
6607-
type_narrowers: list[Type] = []
6608-
6609-
for typ in types:
6610-
if typ.type_guard:
6611-
type_guards.append(typ.type_guard)
6612-
if typ.type_is:
6613-
type_narrowers.append(typ.type_is)
6614-
if type_guards and type_narrowers:
6615-
# Some overloads declare `TypeGuard` and some declare `TypeIs`,
6616-
# we cannot handle this in a union.
6617-
return False
6618-
return all_same_types(type_guards) and all_same_types(type_narrowers)
6619-
6620-
66216558
def merge_typevars_in_callables_by_name(
66226559
callables: Sequence[CallableType],
66236560
) -> tuple[list[CallableType], list[TypeVarType]]:

test-data/unit/check-typeguard.test

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -731,62 +731,6 @@ assert a(x=x)
731731
reveal_type(x) # N: Revealed type is "builtins.int"
732732
[builtins fixtures/tuple.pyi]
733733

734-
[case testTypeGuardInOverloads]
735-
from typing import Any, overload, Union
736-
from typing_extensions import TypeGuard
737-
738-
@overload
739-
def func1(x: str) -> TypeGuard[str]:
740-
...
741-
742-
@overload
743-
def func1(x: int) -> TypeGuard[int]:
744-
...
745-
746-
def func1(x: Any) -> Any:
747-
return True
748-
749-
def func2(val: Any):
750-
if func1(val):
751-
reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]"
752-
else:
753-
reveal_type(val) # N: Revealed type is "Any"
754-
755-
def func3(val: Union[int, str]):
756-
if func1(val):
757-
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
758-
else:
759-
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
760-
761-
def func4(val: int):
762-
if func1(val):
763-
reveal_type(val) # N: Revealed type is "builtins.int"
764-
else:
765-
reveal_type(val) # N: Revealed type is "builtins.int"
766-
[builtins fixtures/tuple.pyi]
767-
768-
[case testTypeIsInOverloadsSameReturn]
769-
from typing import Any, overload, Union
770-
from typing_extensions import TypeGuard
771-
772-
@overload
773-
def func1(x: str) -> TypeGuard[str]:
774-
...
775-
776-
@overload
777-
def func1(x: int) -> TypeGuard[str]:
778-
...
779-
780-
def func1(x: Any) -> Any:
781-
return True
782-
783-
def func2(val: Union[int, str]):
784-
if func1(val):
785-
reveal_type(val) # N: Revealed type is "builtins.str"
786-
else:
787-
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
788-
[builtins fixtures/tuple.pyi]
789-
790734
[case testTypeGuardRestrictAwaySingleInvariant]
791735
from typing import List
792736
from typing_extensions import TypeGuard

test-data/unit/check-typeis.test

Lines changed: 0 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -818,125 +818,6 @@ accept_typeguard(typeguard)
818818

819819
[builtins fixtures/tuple.pyi]
820820

821-
[case testTypeIsInOverloads]
822-
from typing import Any, overload, Union
823-
from typing_extensions import TypeIs
824-
825-
@overload
826-
def func1(x: str) -> TypeIs[str]:
827-
...
828-
829-
@overload
830-
def func1(x: int) -> TypeIs[int]:
831-
...
832-
833-
def func1(x: Any) -> Any:
834-
return True
835-
836-
def func2(val: Any):
837-
if func1(val):
838-
reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]"
839-
else:
840-
reveal_type(val) # N: Revealed type is "Any"
841-
842-
def func3(val: Union[int, str]):
843-
if func1(val):
844-
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
845-
else:
846-
reveal_type(val)
847-
848-
def func4(val: int):
849-
if func1(val):
850-
reveal_type(val) # N: Revealed type is "builtins.int"
851-
else:
852-
reveal_type(val)
853-
[builtins fixtures/tuple.pyi]
854-
855-
[case testTypeIsInOverloadsSameReturn]
856-
from typing import Any, overload, Union
857-
from typing_extensions import TypeIs
858-
859-
@overload
860-
def func1(x: str) -> TypeIs[str]:
861-
...
862-
863-
@overload
864-
def func1(x: int) -> TypeIs[str]: # type: ignore
865-
...
866-
867-
def func1(x: Any) -> Any:
868-
return True
869-
870-
def func2(val: Union[int, str]):
871-
if func1(val):
872-
reveal_type(val) # N: Revealed type is "builtins.str"
873-
else:
874-
reveal_type(val) # N: Revealed type is "builtins.int"
875-
[builtins fixtures/tuple.pyi]
876-
877-
[case testTypeIsInOverloadsUnionizeError]
878-
from typing import Any, overload, Union
879-
from typing_extensions import TypeIs, TypeGuard
880-
881-
@overload
882-
def func1(x: str) -> TypeIs[str]:
883-
...
884-
885-
@overload
886-
def func1(x: int) -> TypeGuard[int]:
887-
...
888-
889-
def func1(x: Any) -> Any:
890-
return True
891-
892-
def func2(val: Union[int, str]):
893-
if func1(val):
894-
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
895-
else:
896-
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
897-
[builtins fixtures/tuple.pyi]
898-
899-
[case testTypeIsInOverloadsUnionizeError2]
900-
from typing import Any, overload, Union
901-
from typing_extensions import TypeIs, TypeGuard
902-
903-
@overload
904-
def func1(x: int) -> TypeGuard[int]:
905-
...
906-
907-
@overload
908-
def func1(x: str) -> TypeIs[str]:
909-
...
910-
911-
def func1(x: Any) -> Any:
912-
return True
913-
914-
def func2(val: Union[int, str]):
915-
if func1(val):
916-
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
917-
else:
918-
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
919-
[builtins fixtures/tuple.pyi]
920-
921-
[case testTypeIsLikeIsDataclass]
922-
from typing import Any, overload, Union, Type
923-
from typing_extensions import TypeIs
924-
925-
class DataclassInstance: ...
926-
927-
@overload
928-
def is_dataclass(obj: type) -> TypeIs[Type[DataclassInstance]]: ...
929-
@overload
930-
def is_dataclass(obj: object) -> TypeIs[Union[DataclassInstance, Type[DataclassInstance]]]: ...
931-
932-
def is_dataclass(obj: Union[type, object]) -> bool:
933-
return False
934-
935-
def func(arg: Any) -> None:
936-
if is_dataclass(arg):
937-
reveal_type(arg) # N: Revealed type is "Union[Type[__main__.DataclassInstance], __main__.DataclassInstance]"
938-
[builtins fixtures/tuple.pyi]
939-
940821
[case testTypeIsEnumOverlappingUnionExcludesIrrelevant]
941822
from enum import Enum
942823
from typing import Literal

0 commit comments

Comments
 (0)