Skip to content

Refactor: compact and optimize infer_overload_return_type while preserving behavior and comments #19198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 46 additions & 48 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2876,69 +2876,67 @@ def infer_overload_return_type(

Assumes all of the given targets have argument counts compatible with the caller.
"""

matches: list[CallableType] = []
return_types: list[Type] = []
inferred_types: list[Type] = []
args_contain_any = any(map(has_any_type, arg_types))
type_maps: list[dict[Expression, Type]] = []
args_contain_any = any(map(has_any_type, arg_types))

for typ in plausible_targets:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w:
with self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
is_match = not w.has_new_errors()
if is_match:
# Return early if possible; otherwise record info, so we can
# check for ambiguity due to 'Any' below.
if not args_contain_any:
self.chk.store_types(m)
return ret_type, infer_type
p_infer_type = get_proper_type(infer_type)
if isinstance(p_infer_type, CallableType):
# Prefer inferred types if possible, this will avoid false triggers for
# Any-ambiguity caused by arguments with Any passed to generic overloads.
matches.append(p_infer_type)
else:
matches.append(typ)
return_types.append(ret_type)
inferred_types.append(infer_type)
type_maps.append(m)
with self.msg.filter_errors() as w, self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
if w.has_new_errors():
continue

# Return early if possible; otherwise record info, so we can
# check for ambiguity due to 'Any' below.
if not args_contain_any:
self.chk.store_types(m)
return ret_type, infer_type

# Prefer inferred types if possible, this will avoid false triggers for
# Any-ambiguity caused by arguments with Any passed to generic overloads.
p = get_proper_type(infer_type)
matches.append(p if isinstance(p, CallableType) else typ)
return_types.append(ret_type)
inferred_types.append(infer_type)
type_maps.append(m)

if not matches:
return None
elif any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names):
# An argument of type or containing the type 'Any' caused ambiguity.
# We try returning a precise type if we can. If not, we give up and just return 'Any'.

# An argument of type or containing the type 'Any' caused ambiguity.
# We try returning a precise type if we can. If not, we give up and just return 'Any'.
if any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names):
if all_same_types(return_types):
self.chk.store_types(type_maps[0])
return return_types[0], inferred_types[0]
elif all_same_types([erase_type(typ) for typ in return_types]):
erased = [erase_type(t) for t in return_types]
if all_same_types(cast(list[Type], erased)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why new cast?

self.chk.store_types(type_maps[0])
return erase_type(return_types[0]), erase_type(inferred_types[0])
else:
return self.check_call(
callee=AnyType(TypeOfAny.special_form),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
else:
# Success! No ambiguity; return the first match.
self.chk.store_types(type_maps[0])
return return_types[0], inferred_types[0]
return self.check_call(
callee=AnyType(TypeOfAny.special_form),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)

# Success! No ambiguity; return the first match.
self.chk.store_types(type_maps[0])
return return_types[0], inferred_types[0]

def overload_erased_call_targets(
self,
Expand Down