25
25
UninhabitedType ,
26
26
AnyType ,
27
27
TypeOfAny ,
28
+ get_proper_type ,
29
+ get_proper_types ,
28
30
)
29
31
from mypy .typeops import make_simplified_union
30
32
from mypy .checker import TypeChecker
@@ -68,11 +70,10 @@ def args_invariant_decorator_callback(ctx: FunctionContext) -> Type:
68
70
"""
69
71
# (adapted from the @contextmanager support in mypy's builtin plugin)
70
72
if ctx .arg_types and len (ctx .arg_types [0 ]) == 1 :
71
- arg_type = ctx .arg_types [0 ][0 ]
72
- if isinstance (arg_type , CallableType ) and isinstance (
73
- ctx .default_return_type , CallableType
74
- ):
75
- return ctx .default_return_type .copy_modified (
73
+ arg_type = get_proper_type (ctx .arg_types [0 ][0 ])
74
+ ret_type = get_proper_type (ctx .default_return_type )
75
+ if isinstance (arg_type , CallableType ) and isinstance (ret_type , CallableType ):
76
+ return ret_type .copy_modified (
76
77
arg_types = arg_type .arg_types ,
77
78
arg_kinds = arg_type .arg_kinds ,
78
79
arg_names = arg_type .arg_names ,
@@ -166,32 +167,34 @@ def decode_agen_types_from_return_type(
166
167
"""
167
168
168
169
arms = [] # type: Sequence[Type]
169
- if isinstance (original_async_return_type , UnionType ):
170
- arms = original_async_return_type .items
170
+ resolved_async_return_type = get_proper_type (original_async_return_type )
171
+ if isinstance (resolved_async_return_type , UnionType ):
172
+ arms = resolved_async_return_type .items
171
173
else :
172
174
arms = [original_async_return_type ]
173
175
yield_type = None # type: Optional[Type]
174
176
send_type = None # type: Optional[Type]
175
177
other_arms = [] # type: List[Type]
176
178
try :
177
- for arm in arms :
179
+ for orig_arm in arms :
180
+ arm = get_proper_type (orig_arm )
178
181
if isinstance (arm , Instance ):
179
- if arm .type .fullname () == "trio_typing.YieldType" :
182
+ if arm .type .fullname == "trio_typing.YieldType" :
180
183
if len (arm .args ) != 1 :
181
184
raise ValueError ("YieldType must take one argument" )
182
185
if yield_type is not None :
183
186
raise ValueError ("YieldType specified multiple times" )
184
187
yield_type = arm .args [0 ]
185
- elif arm .type .fullname () == "trio_typing.SendType" :
188
+ elif arm .type .fullname == "trio_typing.SendType" :
186
189
if len (arm .args ) != 1 :
187
190
raise ValueError ("SendType must take one argument" )
188
191
if send_type is not None :
189
192
raise ValueError ("SendType specified multiple times" )
190
193
send_type = arm .args [0 ]
191
194
else :
192
- other_arms .append (arm )
195
+ other_arms .append (orig_arm )
193
196
else :
194
- other_arms .append (arm )
197
+ other_arms .append (orig_arm )
195
198
except ValueError as ex :
196
199
ctx .api .fail ("invalid @async_generator return type: {}" .format (ex ), ctx .context )
197
200
return (
@@ -248,19 +251,18 @@ async def example() -> Union[str, YieldType[bool], SendType[int]]:
248
251
# Apply the common logic to not change the arguments of the
249
252
# decorated function
250
253
new_return_type = args_invariant_decorator_callback (ctx )
254
+ if not isinstance (new_return_type , CallableType ):
255
+ return new_return_type
256
+ agen_return_type = get_proper_type (new_return_type .ret_type )
251
257
if (
252
- isinstance (new_return_type , CallableType )
253
- and isinstance (new_return_type .ret_type , Instance )
254
- and new_return_type .ret_type .type .fullname ()
255
- == ("trio_typing.CompatAsyncGenerator" )
256
- and len (new_return_type .ret_type .args ) == 3
258
+ isinstance (agen_return_type , Instance )
259
+ and agen_return_type .type .fullname == "trio_typing.CompatAsyncGenerator"
260
+ and len (agen_return_type .args ) == 3
257
261
):
258
262
return new_return_type .copy_modified (
259
- ret_type = new_return_type . ret_type .copy_modified (
263
+ ret_type = agen_return_type .copy_modified (
260
264
args = list (
261
- decode_agen_types_from_return_type (
262
- ctx , new_return_type .ret_type .args [2 ]
263
- )
265
+ decode_agen_types_from_return_type (ctx , agen_return_type .args [2 ])
264
266
)
265
267
)
266
268
)
@@ -288,10 +290,14 @@ def decode_enclosing_agen_types(ctx: FunctionContext) -> Tuple[Type, Type]:
288
290
)
289
291
return AnyType (TypeOfAny .from_error ), AnyType (TypeOfAny .from_error )
290
292
293
+ # The enclosing function type Callable[...] and its return type
294
+ # Coroutine[...] were both produced by mypy, rather than typed by
295
+ # the user, so they can't be type aliases; thus there's no need to
296
+ # use get_proper_type() here.
291
297
if (
292
298
isinstance (enclosing_func .type , CallableType )
293
299
and isinstance (enclosing_func .type .ret_type , Instance )
294
- and enclosing_func .type .ret_type .type .fullname () == "typing.Coroutine"
300
+ and enclosing_func .type .ret_type .type .fullname == "typing.Coroutine"
295
301
and len (enclosing_func .type .ret_type .args ) == 3
296
302
):
297
303
yield_type , send_type , _ = decode_agen_types_from_return_type (
@@ -334,7 +340,7 @@ def yield_callback(ctx: FunctionContext) -> Type:
334
340
def yield_from_callback (ctx : FunctionContext ) -> Type :
335
341
"""Provide a better typecheck for yield_from_()."""
336
342
if ctx .arg_types and len (ctx .arg_types [0 ]) == 1 :
337
- arg_type = ctx .arg_types [0 ][0 ]
343
+ arg_type = get_proper_type ( ctx .arg_types [0 ][0 ])
338
344
else :
339
345
return ctx .default_return_type
340
346
@@ -345,7 +351,7 @@ def yield_from_callback(ctx: FunctionContext) -> Type:
345
351
346
352
if (
347
353
isinstance (arg_type , Instance )
348
- and arg_type .type .fullname ()
354
+ and arg_type .type .fullname
349
355
in (
350
356
"trio_typing.CompatAsyncGenerator" ,
351
357
"trio_typing.AsyncGenerator" ,
@@ -386,11 +392,12 @@ def started_callback(ctx: MethodContext) -> Type:
386
392
"""Raise an error if task_status.started() is called without an argument
387
393
and the TaskStatus is not declared to accept a result of type None.
388
394
"""
395
+ self_type = get_proper_type (ctx .type )
389
396
if (
390
397
(not ctx .arg_types or not ctx .arg_types [0 ])
391
- and isinstance (ctx . type , Instance )
392
- and ctx . type .args
393
- and not isinstance (ctx . type . args [0 ], NoneTyp )
398
+ and isinstance (self_type , Instance )
399
+ and self_type .args
400
+ and not isinstance (get_proper_type ( self_type . args [0 ]) , NoneTyp )
394
401
):
395
402
ctx .api .fail (
396
403
"TaskStatus.started() requires an argument for types other than "
@@ -446,21 +453,22 @@ def start_soon(
446
453
447
454
"""
448
455
try :
449
- if (
450
- not ctx .arg_types
451
- or len (ctx .arg_types [0 ]) != 1
452
- or not isinstance (ctx .arg_types [0 ][0 ], CallableType )
453
- or not isinstance (ctx .default_return_type , CallableType )
456
+ if not ctx .arg_types or len (ctx .arg_types [0 ]) != 1 :
457
+ raise ValueError ("must be used as a decorator" )
458
+
459
+ fn_type = get_proper_type (ctx .arg_types [0 ][0 ])
460
+ if not isinstance (fn_type , CallableType ) or not isinstance (
461
+ get_proper_type (ctx .default_return_type ), CallableType
454
462
):
455
463
raise ValueError ("must be used as a decorator" )
456
464
457
- fn_type = ctx .arg_types [0 ][0 ] # type: CallableType
458
465
callable_idx = - 1 # index in function arguments of the callable
459
466
callable_args_idx = - 1 # index in callable arguments of the StarArgs
460
467
callable_ty = None # type: Optional[CallableType]
461
468
args_idx = - 1 # index in function arguments of the StarArgs
462
469
463
470
for idx , (kind , ty ) in enumerate (zip (fn_type .arg_kinds , fn_type .arg_types )):
471
+ ty = get_proper_type (ty )
464
472
if isinstance (ty , AnyType ) and kind == ARG_STAR :
465
473
assert args_idx == - 1
466
474
args_idx = idx
@@ -469,7 +477,7 @@ def start_soon(
469
477
# into Callable[[VarArg()], T]
470
478
# (the union makes it not fail when the plugin is not being used)
471
479
if isinstance (ty , UnionType ):
472
- for arm in ty .items :
480
+ for arm in get_proper_types ( ty .items ) :
473
481
if (
474
482
isinstance (arm , CallableType )
475
483
and not arm .is_ellipsis_args
@@ -481,6 +489,7 @@ def start_soon(
481
489
continue
482
490
483
491
for idx_ , (kind_ , ty_ ) in enumerate (zip (ty .arg_kinds , ty .arg_types )):
492
+ ty_ = get_proper_type (ty_ )
484
493
if isinstance (ty_ , AnyType ) and kind_ == ARG_STAR :
485
494
if callable_idx != - 1 :
486
495
raise ValueError (
0 commit comments