Skip to content

Commit 4a10319

Browse files
committed
Support variadic functions.
1 parent 4e461ec commit 4a10319

File tree

3 files changed

+39
-30
lines changed

3 files changed

+39
-30
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
77
### Changed
88
* Nested literal types supported
99
* Variable-length and empty tuples supported
10+
* Variadic functions supported
1011

1112
## [2.0.2](https://pypi.org/project/multimethod/2.0.2/) - 2025-11-17
1213
### Changed

multimethod/__init__.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import types
99
import typing
1010
from collections.abc import Callable, Iterable, Iterator, Mapping
11-
from typing import Any, Literal, TypeVar, Union, get_type_hints
11+
from typing import Any, TypeVar, Union, get_type_hints
1212

1313

1414
class DispatchError(TypeError): ... # pragma: no branch
@@ -45,34 +45,33 @@ class subtype(abc.ABCMeta):
4545
__origin__: type
4646
__args__: tuple
4747

48-
def __new__(cls, tp, *args):
48+
def __new__(cls, tp):
4949
match tp:
5050
case typing.Any:
5151
return object
5252
case subtype(): # If already a subtype, return it directly
5353
return tp
5454
case typing.NewType():
55-
return cls(tp.__supertype__, *args)
55+
return cls(tp.__supertype__)
5656
case TypeVar():
57-
return cls(Union[tp.__constraints__], *args) if tp.__constraints__ else object
57+
return cls(Union[tp.__constraints__]) if tp.__constraints__ else object
5858
case typing._AnnotatedAlias():
59-
return cls(tp.__origin__, *args)
59+
return cls(tp.__origin__)
6060
if hasattr(typing, 'TypeAliasType') and isinstance(tp, typing.TypeAliasType):
61-
return cls(tp.__value__, *args)
61+
return cls(tp.__value__)
6262
origin = get_origin(tp) or tp
63-
args = tuple(map(cls, get_args(tp) or args))
63+
args = tuple(map(cls, get_args(tp)))
6464
if set(args) <= {object} and (origin is not tuple or tp is tuple):
6565
return origin
6666
bases = (origin,) if type(origin) in (type, abc.ABCMeta) else ()
67-
if origin is Literal:
68-
bases = (cls(Union[tuple(map(type, args))]),)
69-
if origin is Union or isinstance(tp, types.UnionType):
70-
origin = types.UnionType
71-
bases = common_bases(*args)[:1]
72-
if bases[0] in args:
73-
return bases[0]
74-
if origin is Callable and args[:1] == (...,):
75-
args = args[1:]
67+
match origin:
68+
case typing.Literal:
69+
bases = (cls(Union[tuple(map(type, args))]),)
70+
case typing.Union | types.UnionType:
71+
origin = types.UnionType
72+
bases = common_bases(*args)[:1]
73+
if bases[0] in args:
74+
return bases[0]
7675
namespace = {'__origin__': origin, '__args__': args}
7776
return type.__new__(cls, str(tp), bases, namespace)
7877

@@ -104,10 +103,14 @@ def __subclasscheck__(self, subclass):
104103
param = self.__args__[0]
105104
return all(arg is ... or issubclass(arg, param) for arg in args)
106105
case collections.abc.Callable:
106+
params = self.__args__[:-1]
107107
return (
108108
origin is Callable
109-
and signature(self.__args__[-1:]) <= signature(args[-1:]) # covariant return
110-
and signature(args[:-1]) <= signature(self.__args__[:-1]) # contravariant args
109+
and issubclass(args[-1], self.__args__[-1]) # covariant return
110+
and (
111+
... in params
112+
or (... not in args and signature(args[:-1]) <= signature(params))
113+
) # contravariant args
111114
)
112115
return ( # check args first to avoid recursion error: python/cpython#73407
113116
len(args) == len(self.__args__)
@@ -116,21 +119,24 @@ def __subclasscheck__(self, subclass):
116119
)
117120

118121
def __instancecheck__(self, instance):
119-
if self.__origin__ is Literal:
120-
return any(type(arg) is type(instance) and arg == instance for arg in self.__args__)
121-
if self.__origin__ is types.UnionType:
122-
return isinstance(instance, self.__args__)
123-
if hasattr(instance, '__orig_class__'): # user-defined generic type
122+
match self.__origin__:
123+
case typing.Literal:
124+
return any(type(arg) is type(instance) and arg == instance for arg in self.__args__)
125+
case types.UnionType:
126+
return isinstance(instance, self.__args__)
127+
case builtins.type:
128+
if isinstance(instance, typing.GenericAlias):
129+
return issubclass(subtype(instance), self.__args__)
130+
return inspect.isclass(instance) and issubclass(instance, self.__args__)
131+
if isinstance(instance, typing.Generic): # user-defined generic type
124132
return issubclass(instance.__orig_class__, self)
125-
if self.__origin__ is type: # a class argument is expected
126-
if isinstance(instance, types.GenericAlias):
127-
return issubclass(subtype(instance), self.__args__)
128-
return inspect.isclass(instance) and issubclass(instance, self.__args__)
129133
if not isinstance(instance, self.__origin__) or isinstance(instance, Iterator):
130134
return False
131135
if self.__origin__ is Callable:
132-
return issubclass(subtype(Callable, *get_type_hints(instance).values()), self)
133-
if self.__origin__ is tuple and self.__args__[-1:] != (...,):
136+
hints = get_type_hints(instance)
137+
args = [hints.get(name, object) for name in inspect.signature(instance).parameters]
138+
return issubclass(Callable[args, hints.get('return', object)], self)
139+
if self.__origin__ is tuple and ... not in self.__args__:
134140
if len(instance) != len(self.__args__):
135141
return False
136142
elif issubclass(self, Mapping):

tests/test_subscripts.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def f(arg: bool) -> int: ...
111111

112112
def g(arg: int) -> bool: ...
113113

114-
def h(arg) -> bool: ...
114+
def h(arg: float) -> bool: ...
115115

116116
@multimethod
117117
def func(arg: Callable[[bool], bool]):
@@ -134,6 +134,8 @@ def _(arg: Sequence[Callable[[bool], bool]]):
134134
assert func(g) == 'g'
135135
assert func([g]) == 'g0'
136136
assert func(h) is ...
137+
assert issubclass(Callable[[int], int], subtype(Callable[..., int]))
138+
assert not issubclass(Callable[..., int], subtype(Callable[[int], int]))
137139

138140

139141
def test_final():

0 commit comments

Comments
 (0)