88import types
99import typing
1010from collections .abc import Callable , Iterable , Iterator , Mapping
11- from typing import Any , Literal , TypeVar , Union , get_type_hints
11+ from typing import Any , TypeVar , Union , get_type_hints
1212
1313
1414class DispatchError (TypeError ): ... # pragma: no branch
@@ -45,34 +45,33 @@ class subtype(abc.ABCMeta):
4545 __origin__ : type
4646 __args__ : tuple
4747
48- def __new__ (cls , tp , * args ):
48+ def __new__ (cls , tp ):
4949 match tp :
5050 case typing .Any :
5151 return object
5252 case subtype (): # If already a subtype, return it directly
5353 return tp
5454 case typing .NewType ():
55- return cls (tp .__supertype__ , * args )
55+ return cls (tp .__supertype__ )
5656 case TypeVar ():
57- return cls (Union [tp .__constraints__ ], * args ) if tp .__constraints__ else object
57+ return cls (Union [tp .__constraints__ ]) if tp .__constraints__ else object
5858 case typing ._AnnotatedAlias ():
59- return cls (tp .__origin__ , * args )
59+ return cls (tp .__origin__ )
6060 if hasattr (typing , 'TypeAliasType' ) and isinstance (tp , typing .TypeAliasType ):
61- return cls (tp .__value__ , * args )
61+ return cls (tp .__value__ )
6262 origin = get_origin (tp ) or tp
63- args = tuple (map (cls , get_args (tp ) or args ))
63+ args = tuple (map (cls , get_args (tp )))
6464 if set (args ) <= {object } and (origin is not tuple or tp is tuple ):
6565 return origin
6666 bases = (origin ,) if type (origin ) in (type , abc .ABCMeta ) else ()
67- if origin is Literal :
68- bases = (cls (Union [tuple (map (type , args ))]),)
69- if origin is Union or isinstance (tp , types .UnionType ):
70- origin = types .UnionType
71- bases = common_bases (* args )[:1 ]
72- if bases [0 ] in args :
73- return bases [0 ]
74- if origin is Callable and args [:1 ] == (...,):
75- args = args [1 :]
67+ match origin :
68+ case typing .Literal :
69+ bases = (cls (Union [tuple (map (type , args ))]),)
70+ case typing .Union | types .UnionType :
71+ origin = types .UnionType
72+ bases = common_bases (* args )[:1 ]
73+ if bases [0 ] in args :
74+ return bases [0 ]
7675 namespace = {'__origin__' : origin , '__args__' : args }
7776 return type .__new__ (cls , str (tp ), bases , namespace )
7877
@@ -104,10 +103,14 @@ def __subclasscheck__(self, subclass):
104103 param = self .__args__ [0 ]
105104 return all (arg is ... or issubclass (arg , param ) for arg in args )
106105 case collections .abc .Callable :
106+ params = self .__args__ [:- 1 ]
107107 return (
108108 origin is Callable
109- and signature (self .__args__ [- 1 :]) <= signature (args [- 1 :]) # covariant return
110- and signature (args [:- 1 ]) <= signature (self .__args__ [:- 1 ]) # contravariant args
109+ and issubclass (args [- 1 ], self .__args__ [- 1 ]) # covariant return
110+ and (
111+ ... in params
112+ or (... not in args and signature (args [:- 1 ]) <= signature (params ))
113+ ) # contravariant args
111114 )
112115 return ( # check args first to avoid recursion error: python/cpython#73407
113116 len (args ) == len (self .__args__ )
@@ -116,21 +119,24 @@ def __subclasscheck__(self, subclass):
116119 )
117120
118121 def __instancecheck__ (self , instance ):
119- if self .__origin__ is Literal :
120- return any (type (arg ) is type (instance ) and arg == instance for arg in self .__args__ )
121- if self .__origin__ is types .UnionType :
122- return isinstance (instance , self .__args__ )
123- if hasattr (instance , '__orig_class__' ): # user-defined generic type
122+ match self .__origin__ :
123+ case typing .Literal :
124+ return any (type (arg ) is type (instance ) and arg == instance for arg in self .__args__ )
125+ case types .UnionType :
126+ return isinstance (instance , self .__args__ )
127+ case builtins .type :
128+ if isinstance (instance , typing .GenericAlias ):
129+ return issubclass (subtype (instance ), self .__args__ )
130+ return inspect .isclass (instance ) and issubclass (instance , self .__args__ )
131+ if isinstance (instance , typing .Generic ): # user-defined generic type
124132 return issubclass (instance .__orig_class__ , self )
125- if self .__origin__ is type : # a class argument is expected
126- if isinstance (instance , types .GenericAlias ):
127- return issubclass (subtype (instance ), self .__args__ )
128- return inspect .isclass (instance ) and issubclass (instance , self .__args__ )
129133 if not isinstance (instance , self .__origin__ ) or isinstance (instance , Iterator ):
130134 return False
131135 if self .__origin__ is Callable :
132- return issubclass (subtype (Callable , * get_type_hints (instance ).values ()), self )
133- if self .__origin__ is tuple and self .__args__ [- 1 :] != (...,):
136+ hints = get_type_hints (instance )
137+ args = [hints .get (name , object ) for name in inspect .signature (instance ).parameters ]
138+ return issubclass (Callable [args , hints .get ('return' , object )], self )
139+ if self .__origin__ is tuple and ... not in self .__args__ :
134140 if len (instance ) != len (self .__args__ ):
135141 return False
136142 elif issubclass (self , Mapping ):
0 commit comments