From b1d5b923615d8b6d348227e5ae52e909b35400e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Roberto=20Fern=C3=A1ndez=20Iglesias?= <roberfi@gmail.com>
Date: Sat, 1 Jun 2024 14:46:41 +0200
Subject: [PATCH 1/3] Fix union inference of generic class and its generic type

---
 mypy/constraints.py                 | 16 ++++++++++++++-
 test-data/unit/check-inference.test | 30 ++++++++++++++++++++++-------
 2 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/mypy/constraints.py b/mypy/constraints.py
index cdfa39ac45f3..922f26a9de17 100644
--- a/mypy/constraints.py
+++ b/mypy/constraints.py
@@ -390,10 +390,24 @@ def _infer_constraints(
         # When the template is a union, we are okay with leaving some
         # type variables indeterminate. This helps with some special
         # cases, though this isn't very principled.
+
+        def _is_item_being_overlaped_by_other(item: Type) -> bool:
+            # It returns true if the item is an argument of other item
+            # that is subtype of the actual type
+            return any(
+                isinstance(p_type := get_proper_type(item_to_compare), Instance)
+                and mypy.subtypes.is_subtype(actual, erase_typevars(p_type))
+                and item in p_type.args
+                for item_to_compare in template.items
+                if item is not item_to_compare
+            )
+
         result = any_constraints(
             [
                 infer_constraints_if_possible(t_item, actual, direction)
-                for t_item in template.items
+                for t_item in [
+                    item for item in template.items if not _is_item_being_overlaped_by_other(item)
+                ]
             ],
             eager=False,
         )
diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test
index 08b53ab16972..ca05cb73c335 100644
--- a/test-data/unit/check-inference.test
+++ b/test-data/unit/check-inference.test
@@ -873,13 +873,7 @@ def g(x: Union[T, List[T]]) -> List[T]: pass
 def h(x: List[str]) -> None: pass
 g('a')() # E: "List[str]" not callable
 
-# The next line is a case where there are multiple ways to satisfy a constraint
-# involving a Union. Either T = List[str] or T = str would turn out to be valid,
-# but mypy doesn't know how to branch on these two options (and potentially have
-# to backtrack later) and defaults to T = Never. The result is an
-# awkward error message. Either a better error message, or simply accepting the
-# call, would be preferable here.
-g(['a']) # E: Argument 1 to "g" has incompatible type "List[str]"; expected "List[Never]"
+g(['a'])
 
 h(g(['a']))
 
@@ -891,6 +885,28 @@ i(b, a, b)
 i(a, b, b) # E: Argument 1 to "i" has incompatible type "List[int]"; expected "List[str]"
 [builtins fixtures/list.pyi]
 
+[case testUnionInferenceOfGenericClassAndItsGenericType]
+from typing import Generic, TypeVar, Union
+
+T = TypeVar('T')
+
+class GenericClass(Generic[T]):
+    def __init__(self, value: T) -> None:
+        self.value = value
+
+def method_with_union(arg: Union[GenericClass[T], T]) -> GenericClass[T]:
+    if not isinstance(arg, GenericClass):
+        arg = GenericClass(arg)
+    return arg
+
+result_1 = method_with_union(GenericClass("test"))
+reveal_type(result_1) # N: Revealed type is "__main__.GenericClass[builtins.str]"
+
+result_2 = method_with_union("test")
+reveal_type(result_2) # N: Revealed type is "__main__.GenericClass[builtins.str]"
+
+[builtins fixtures/isinstance.pyi]
+
 [case testCallableListJoinInference]
 from typing import Any, Callable
 

From 2314852da781cca51f5306d2cb046de3c88719ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Roberto=20Fern=C3=A1ndez=20Iglesias?= <roberfi@gmail.com>
Date: Sun, 2 Jun 2024 17:13:26 +0200
Subject: [PATCH 2/3] Improve inference of union of generic types when one of
 the types is the generic type of the other

---
 mypy/constraints.py                 | 69 ++++++++++++++++++++++++-----
 test-data/unit/check-inference.test | 27 ++++++++++-
 2 files changed, 84 insertions(+), 12 deletions(-)

diff --git a/mypy/constraints.py b/mypy/constraints.py
index 922f26a9de17..5f4c3c5437b8 100644
--- a/mypy/constraints.py
+++ b/mypy/constraints.py
@@ -271,7 +271,11 @@ def infer_constraints_for_callable(
 
 
 def infer_constraints(
-    template: Type, actual: Type, direction: int, skip_neg_op: bool = False
+    template: Type,
+    actual: Type,
+    direction: int,
+    skip_neg_op: bool = False,
+    can_have_union_overlaping: bool = True,
 ) -> list[Constraint]:
     """Infer type constraints.
 
@@ -311,11 +315,15 @@ def infer_constraints(
         res = _infer_constraints(template, actual, direction, skip_neg_op)
         type_state.inferring.pop()
         return res
-    return _infer_constraints(template, actual, direction, skip_neg_op)
+    return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlaping)
 
 
 def _infer_constraints(
-    template: Type, actual: Type, direction: int, skip_neg_op: bool
+    template: Type,
+    actual: Type,
+    direction: int,
+    skip_neg_op: bool,
+    can_have_union_overlaping: bool = True,
 ) -> list[Constraint]:
     orig_template = template
     template = get_proper_type(template)
@@ -368,8 +376,41 @@ def _infer_constraints(
         return res
     if direction == SUPERTYPE_OF and isinstance(actual, UnionType):
         res = []
+
+        def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool:
+            # There is a special overlaping case, where we have a Union of where two types
+            # are the same, but one of them contains the other.
+            # For example, we have Union[Sequence[T], Sequence[Sequence[T]]]
+            # In this case, only the second one can have overlaping because it contains the other.
+            # So, in case of list[list[int]], second one would be chosen.
+            if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args:
+                other_items = [o_item for o_item in _actual.items if o_item is not a_item]
+
+                if len(other_items) == 1 and other_items[0] in p_item.args:
+                    return True
+
+                if len(other_items) > 1:
+                    union_args = [
+                        p_arg
+                        for arg in p_item.args
+                        if isinstance(p_arg := get_proper_type(arg), UnionType)
+                    ]
+
+                    for union_arg in union_args:
+                        if all(o_item in union_arg.items for o_item in other_items):
+                            return True
+
+            return False
+
         for a_item in actual.items:
-            res.extend(infer_constraints(orig_template, a_item, direction))
+            res.extend(
+                infer_constraints(
+                    orig_template,
+                    a_item,
+                    direction,
+                    can_have_union_overlaping=_can_have_overlaping(a_item, actual),
+                )
+            )
         return res
 
     # Now the potential subtype is known not to be a Union or a type
@@ -391,22 +432,28 @@ def _infer_constraints(
         # type variables indeterminate. This helps with some special
         # cases, though this isn't very principled.
 
-        def _is_item_being_overlaped_by_other(item: Type) -> bool:
-            # It returns true if the item is an argument of other item
+        def _is_item_overlaping_actual_type(_item: Type) -> bool:
+            # Overlaping occurs when we have a Union where two types are
+            # compatible and the more generic one is chosen.
+            # For example, in Union[T, Sequence[T]], we have to choose
+            # Sequence[T] if actual type is list[int].
+            # This returns true if the item is an argument of other item
             # that is subtype of the actual type
             return any(
-                isinstance(p_type := get_proper_type(item_to_compare), Instance)
-                and mypy.subtypes.is_subtype(actual, erase_typevars(p_type))
-                and item in p_type.args
+                isinstance(p_item_to_compare := get_proper_type(item_to_compare), Instance)
+                and mypy.subtypes.is_subtype(actual, erase_typevars(p_item_to_compare))
+                and _item in p_item_to_compare.args
                 for item_to_compare in template.items
-                if item is not item_to_compare
+                if _item is not item_to_compare
             )
 
         result = any_constraints(
             [
                 infer_constraints_if_possible(t_item, actual, direction)
                 for t_item in [
-                    item for item in template.items if not _is_item_being_overlaped_by_other(item)
+                    item
+                    for item in template.items
+                    if not (can_have_union_overlaping and _is_item_overlaping_actual_type(item))
                 ]
             ],
             eager=False,
diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test
index ca05cb73c335..96d42b47d7ee 100644
--- a/test-data/unit/check-inference.test
+++ b/test-data/unit/check-inference.test
@@ -885,7 +885,7 @@ i(b, a, b)
 i(a, b, b) # E: Argument 1 to "i" has incompatible type "List[int]"; expected "List[str]"
 [builtins fixtures/list.pyi]
 
-[case testUnionInferenceOfGenericClassAndItsGenericType]
+[case testInferenceOfUnionOfGenericClassAndItsGenericType]
 from typing import Generic, TypeVar, Union
 
 T = TypeVar('T')
@@ -907,6 +907,31 @@ reveal_type(result_2) # N: Revealed type is "__main__.GenericClass[builtins.str]
 
 [builtins fixtures/isinstance.pyi]
 
+[case testInferenceOfUnionOfSequenceOfAnyAndSequenceOfSequence]
+from typing import Sequence, Iterable, TypeVar, Union
+
+T = TypeVar("T")
+S = TypeVar("S")
+
+def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]:
+    pass
+
+def method(value: Union[Sequence[T], Sequence[Sequence[T]]]) -> None:
+    reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[typing.Sequence[T`-1]]"
+
+[case testInferenceOfUnionOfUnionWithSequenceAndSequenceOfThatUnion]
+from typing import Sequence, Iterable, TypeVar, Union
+
+T = Union[str, Sequence[int]]
+S = TypeVar("S", bound=T)
+
+def sub_method(value: Union[S, Iterable[S]]) -> Iterable[S]:
+    pass
+
+def method(value: Union[T, Sequence[T]]) -> None:
+    reveal_type(sub_method(value)) # N: Revealed type is "typing.Iterable[Union[builtins.str, typing.Sequence[builtins.int]]]"
+
+
 [case testCallableListJoinInference]
 from typing import Any, Callable
 

From 133c49b458e980b5ad154d4050b681fb2ac4eafe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Roberto=20Fern=C3=A1ndez=20Iglesias?= <roberfi@gmail.com>
Date: Sat, 26 Oct 2024 11:57:45 +0200
Subject: [PATCH 3/3] Fix 'overlapping' typo

---
 mypy/constraints.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mypy/constraints.py b/mypy/constraints.py
index 1a9b7c9f0ab6..e580aca80f24 100644
--- a/mypy/constraints.py
+++ b/mypy/constraints.py
@@ -278,7 +278,7 @@ def infer_constraints(
     actual: Type,
     direction: int,
     skip_neg_op: bool = False,
-    can_have_union_overlaping: bool = True,
+    can_have_union_overlapping: bool = True,
 ) -> list[Constraint]:
     """Infer type constraints.
 
@@ -318,7 +318,7 @@ def infer_constraints(
         res = _infer_constraints(template, actual, direction, skip_neg_op)
         type_state.inferring.pop()
         return res
-    return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlaping)
+    return _infer_constraints(template, actual, direction, skip_neg_op, can_have_union_overlapping)
 
 
 def _infer_constraints(
@@ -326,7 +326,7 @@ def _infer_constraints(
     actual: Type,
     direction: int,
     skip_neg_op: bool,
-    can_have_union_overlaping: bool = True,
+    can_have_union_overlapping: bool = True,
 ) -> list[Constraint]:
     orig_template = template
     template = get_proper_type(template)
@@ -380,11 +380,11 @@ def _infer_constraints(
     if direction == SUPERTYPE_OF and isinstance(actual, UnionType):
         res = []
 
-        def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool:
-            # There is a special overlaping case, where we have a Union of where two types
+        def _can_have_overlapping(_item: Type, _actual: UnionType) -> bool:
+            # There is a special overlapping case, where we have a Union of where two types
             # are the same, but one of them contains the other.
             # For example, we have Union[Sequence[T], Sequence[Sequence[T]]]
-            # In this case, only the second one can have overlaping because it contains the other.
+            # In this case, only the second one can have overlapping because it contains the other.
             # So, in case of list[list[int]], second one would be chosen.
             if isinstance(p_item := get_proper_type(_item), Instance) and p_item.args:
                 other_items = [o_item for o_item in _actual.items if o_item is not a_item]
@@ -411,7 +411,7 @@ def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool:
                     orig_template,
                     a_item,
                     direction,
-                    can_have_union_overlaping=_can_have_overlaping(a_item, actual),
+                    can_have_union_overlapping=_can_have_overlapping(a_item, actual),
                 )
             )
         return res
@@ -435,8 +435,8 @@ def _can_have_overlaping(_item: Type, _actual: UnionType) -> bool:
         # type variables indeterminate. This helps with some special
         # cases, though this isn't very principled.
 
-        def _is_item_overlaping_actual_type(_item: Type) -> bool:
-            # Overlaping occurs when we have a Union where two types are
+        def _is_item_overlapping_actual_type(_item: Type) -> bool:
+            # Overlapping occurs when we have a Union where two types are
             # compatible and the more generic one is chosen.
             # For example, in Union[T, Sequence[T]], we have to choose
             # Sequence[T] if actual type is list[int].
@@ -456,7 +456,7 @@ def _is_item_overlaping_actual_type(_item: Type) -> bool:
                 for t_item in [
                     item
                     for item in template.items
-                    if not (can_have_union_overlaping and _is_item_overlaping_actual_type(item))
+                    if not (can_have_union_overlapping and _is_item_overlapping_actual_type(item))
                 ]
             ],
             eager=False,