
Infer type of generic class from return type of (optionally awaitable) callable passed to constructor #19143

Open
DouweM opened this issue May 23, 2025 · 6 comments
Labels
bug mypy got something wrong topic-inference When to infer types or require explicit annotations

@DouweM

DouweM commented May 23, 2025

I've confirmed that the following works in pyright and pyrefly, but not in mypy:

from dataclasses import dataclass

from typing_extensions import (
    Awaitable,
    Callable,
    Generic,
    TypeVar,
    assert_type,
)

T = TypeVar("T")


@dataclass
class Agent(Generic[T]):
    output_type: Callable[..., T] | Callable[..., Awaitable[T]]


async def coro() -> bool:
    return True


def func() -> int:
    return 1


# works
assert_type(Agent(func), Agent[int])

# mypy - error: Argument 1 to "Agent" has incompatible type "Callable[[], Coroutine[Any, Any, bool]]"; expected "Callable[..., Never] | Callable[..., Awaitable[Never]]"  [arg-type]
coro_agent = Agent(coro)
# pyright, pyrefly - works
# mypy - error: Expression is of type "Agent[Any]", not "Agent[bool]"
assert_type(coro_agent, Agent[bool])

# works
assert_type(Agent[bool](coro), Agent[bool])

I want T to be inferred as the awaited result type when an async function is passed instead of a regular one, but I suppose it's ambiguous which side of the union is the best match: an async function such as coro is also a plain Callable[..., Coroutine[Any, Any, bool]], so it matches both arms.

It would be great to see this work in mypy, but I'm also open to suggestions to do this in a less ambiguous way!
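At runtime the distinction is easy to make: a value typed Callable[..., T] | Callable[..., Awaitable[T]] is typically consumed by calling it and awaiting the result only when it is awaitable, which is why T should be the awaited result type. A minimal sketch of that, assuming a hypothetical helper named call_output_type (not part of any library discussed here):

```python
import asyncio
import inspect
from typing import Awaitable, Callable, TypeVar, Union, cast

T = TypeVar("T")


async def call_output_type(
    output_type: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: object
) -> T:
    """Hypothetical helper: call output_type, awaiting the result only if needed."""
    result = output_type(*args)
    if inspect.isawaitable(result):
        # Async function: the value we want is the awaited result.
        return await result
    # Regular function: the return value is already the plain T.
    return cast(T, result)


async def coro() -> bool:
    return True


def func() -> int:
    return 1


async def main() -> None:
    print(await call_output_type(func))  # 1
    print(await call_output_type(coro))  # True


asyncio.run(main())
```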

@DouweM DouweM added the bug mypy got something wrong label May 23, 2025
@A5rocks
Collaborator

A5rocks commented May 23, 2025

As you note, this is ambiguous. As a workaround, overloads work.

I'm not sure if there's any principled way around this. I have been thinking about a "strict coloring" mode which treats async functions as different types from non-async ones, but that would be stricter than you would like, and not everyone would enable it.

Maybe mypy should special-case T | Awaitable[T] specifically, since that's the only case where I've seen this come up.

@A5rocks A5rocks added the topic-inference When to infer types or require explicit annotations label May 23, 2025
@DouweM
Author

DouweM commented May 23, 2025

@A5rocks I appreciate the quick response.

Naively, I'd imagine a general rule like "in case of multiple possible matches, pick the most specific one", which may be what pyright is doing (note that I haven't looked at the implementation).

Special-casing T | Awaitable[T] would work for me, but I'm curious why we couldn't do the same for any T | Foo[T] when given a Foo[T]. For example, I'd expect this to work:

from typing import Sequence, assert_type


def ensure_sequence[T](x: T | Sequence[T]) -> Sequence[T]:
    if isinstance(x, Sequence):
        return x
    return [x]


assert_type(ensure_sequence(1), Sequence[int])
assert_type(ensure_sequence([1, 2, 3]), Sequence[int])

mypy currently says this:

error: Expression is of type "Sequence[Never]", not "Sequence[int]"  [assert-type]
error: Argument 1 to "ensure_sequence" has incompatible type "list[int]"; expected "Sequence[Never]"  [arg-type]

Note that pyright doesn't fully accept this either, so maybe it is special-casing T | Awaitable[T]. It complains about x on the return x line, but it does let the assert_type calls pass:

error: Return type, "Sequence[Unknown]* | Sequence[T@ensure_sequence]", is partially unknown (reportUnknownVariableType)

@A5rocks
Collaborator

A5rocks commented May 23, 2025

Re: "in case of multiple possible matches, pick the most specific one"

I imagine this wouldn't do well if there are multiple possible matches with the same specificity, or even with something like:

from typing import Protocol, TypeVar

T = TypeVar("T")

class A(Protocol[T]):
    a: T

def f(x: T | A[T]) -> T: ...

(You could imagine A as Awaitable, with a: T standing in for def __await__(self) -> Generator[Any, Any, T], if you like.)

However, it is an improvement in some cases, so if we can isolate those, that sounds fine. But if we're adding special cases, I would rather be specific, e.g. only special-casing T | Awaitable[T].

Maybe a better method is to track the number of levels above the type variable and choose the highest one? Or maybe to discard conflicting constraints in union order? Neither sounds very performant, of course.
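The first idea above can be made concrete with a toy model. This is a hypothetical sketch, not mypy's solver: types are small frozen dataclasses (TVar and App are names invented here), matching is exact with no subtyping (so a coroutine is written as Awaitable[bool] rather than Coroutine[Any, Any, bool]), and solve_union keeps the solution from the matching union arm with the most constructor levels above the type variable.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class TVar:
    """A type variable such as T."""
    name: str


@dataclass(frozen=True)
class App:
    """A constructor applied to arguments, e.g. App("Awaitable", (TVar("T"),))."""
    ctor: str
    args: "tuple[Ty, ...]" = ()


Ty = Union[TVar, App]


def match(formal: Ty, actual: Ty, subst: dict[str, Ty]) -> Optional[dict[str, Ty]]:
    """Exact structural match of formal against actual; no subtyping is modelled."""
    if isinstance(formal, TVar):
        bound = subst.get(formal.name)
        if bound is None:
            return {**subst, formal.name: actual}
        return subst if bound == actual else None
    if (
        not isinstance(actual, App)
        or actual.ctor != formal.ctor
        or len(actual.args) != len(formal.args)
    ):
        return None
    for f_arg, a_arg in zip(formal.args, actual.args):
        result = match(f_arg, a_arg, subst)
        if result is None:
            return None
        subst = result
    return subst


def depth(t: Ty) -> int:
    """Constructor levels above the deepest type variable in t, or -1 if t has none."""
    if isinstance(t, TVar):
        return 0
    inner = [d for d in (depth(a) for a in t.args) if d >= 0]
    return 1 + max(inner) if inner else -1


def solve_union(arms: list[Ty], actual: Ty) -> Optional[dict[str, Ty]]:
    """Among the union arms that match, keep the solution from the deepest arm.

    max() returns the first maximal candidate, so ties fall back to union order.
    """
    candidates = []
    for arm in arms:
        solution = match(arm, actual, {})
        if solution is not None:
            candidates.append((depth(arm), solution))
    if not candidates:
        return None
    return max(candidates, key=lambda c: c[0])[1]


T = TVar("T")
arms = [T, App("Awaitable", (T,))]  # T | Awaitable[T]
print(solve_union(arms, App("Awaitable", (App("bool"),))))  # {'T': App(ctor='bool', args=())}
print(solve_union(arms, App("int")))  # {'T': App(ctor='int', args=())}
```

With T | Awaitable[T] and an awaitable argument, both arms match (giving T = Awaitable[bool] and T = bool), and the depth rule picks T = bool. Ties between arms of equal depth fall back to union order, and protocol matches like the A[T] case are not modelled at all, which is where the ambiguity raised above would resurface.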

@DouweM
Copy link
Author

DouweM commented May 23, 2025

@A5rocks Good point, a special case sounds reasonable then. Thanks for considering this!

@A5rocks
Collaborator

A5rocks commented May 25, 2025

And BTW, I saw you misinterpreted me in your comment on the PR: __init__ can be overloaded.
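A minimal sketch of what overloading __init__ could look like for the Agent example at the top of the issue, assuming a plain generic class instead of a dataclass so that __init__ can be written by hand. The inferred types in the comments are what the overloads are intended to produce, not output captured from a mypy run.

```python
from typing import Any, Awaitable, Callable, Generic, TypeVar, Union, overload

T = TypeVar("T")


class Agent(Generic[T]):
    output_type: Union[Callable[..., T], Callable[..., Awaitable[T]]]

    # Async overload first: an async function also matches Callable[..., T]
    # (with T bound to the coroutine type), so the narrower signature must win.
    @overload
    def __init__(self, output_type: Callable[..., Awaitable[T]]) -> None: ...
    @overload
    def __init__(self, output_type: Callable[..., T]) -> None: ...
    def __init__(self, output_type: Callable[..., Any]) -> None:
        self.output_type = output_type


async def coro() -> bool:
    return True


def func() -> int:
    return 1


int_agent = Agent(func)  # intended to be inferred as Agent[int]
bool_agent = Agent(coro)  # intended to be inferred as Agent[bool]
```

This only helps when the union sits directly in the __init__ signature; the reply below explains why that is harder when it is nested inside type aliases.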

@DouweM
Author

DouweM commented May 26, 2025

@A5rocks I don't think that'd work with the real OutputType, which has the Callable[..., T] | Callable[..., Awaitable[T]] union nested a few levels down:

from typing import Awaitable, Callable, Generic, Sequence, TypeVar, Union

from typing_extensions import TypeAliasType

T_co = TypeVar('T_co', covariant=True)


class ToolOutput(Generic[T_co]):
    """Placeholder stub for the real ToolOutput marker class, which is not shown here."""


# output_type=Type or output_type=function or output_type=object.method
SimpleOutputType = TypeAliasType(
    'SimpleOutputType', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,)
)
# output_type=ToolOutput(<see above>) or <see above>
SimpleOutputTypeOrMarker = TypeAliasType(
    'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,)
)
# output_type=<see above> or [<see above>, ...]
OutputType = TypeAliasType(
    'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,)
)

That means output_type can be a list mixing types, regular functions, and async functions:

def func(x: str, y: int) -> str:
    return f'{x} {y}'


async def coro(x: int, y: int) -> int:
    return x * y

complex_output_agent = Agent(output_type=[Foo, Bar, func, coro])
assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int])
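For comparison, at runtime the three kinds of entries are easy to tell apart, even though a type checker sees regular and async functions alike as Callable. A minimal sketch, assuming a hypothetical helper describe_output_types and placeholder Foo and Bar classes, that normalizes output_type to a list and labels each entry using inspect.isclass and inspect.iscoroutinefunction:

```python
import inspect
from collections.abc import Sequence
from typing import Any


def describe_output_types(output_type: Any) -> list[str]:
    """Hypothetical helper: normalize output_type to a list and label each entry."""
    # A single type or function becomes a one-element list; str is excluded
    # because it is itself a Sequence.
    if isinstance(output_type, Sequence) and not isinstance(output_type, str):
        entries = list(output_type)
    else:
        entries = [output_type]
    labels = []
    for entry in entries:
        if inspect.isclass(entry):
            labels.append(f"type {entry.__name__}")
        elif inspect.iscoroutinefunction(entry):
            labels.append(f"async function {entry.__name__}")
        elif callable(entry):
            labels.append(f"function {entry.__name__}")
        else:
            # Anything that is neither a class nor callable.
            labels.append(f"other {entry!r}")
    return labels


class Foo: ...


class Bar: ...


def func(x: str, y: int) -> str:
    return f"{x} {y}"


async def coro(x: int, y: int) -> int:
    return x * y


print(describe_output_types([Foo, Bar, func, coro]))
# ['type Foo', 'type Bar', 'function func', 'async function coro']
```

inspect.iscoroutinefunction only recognizes async def functions (and callables explicitly marked as such), so a regular function that happens to return an awaitable would be labeled as a regular function here.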

Note that using Sequence[...] with type[T] also runs into #19142.

To use overloads, I think I'd need to define a new marker class like OutputFunc, with overloads for regular functions and async functions. That's definitely an option, but it would make the API a bit less clean, so if mypy is planning to fix both this issue and the Sequence[type[T]] one, I'd rather keep the current design (which already works with pyright).

Is there another option I'm missing?
