@@ -3,7 +3,7 @@ using .CC: Const, isconstType, argtypes_to_type, tuple_tfunc, Const,
3
3
getfield_tfunc, _methods_by_ftype, VarTable, nfields_tfunc,
4
4
ArgInfo, singleton_type, CallMeta, MethodMatchInfo, specialize_method,
5
5
PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc,
6
- StmtInfo
6
+ StmtInfo, NoCallInfo
7
7
using Core: PartialStruct
8
8
using Base. Meta
9
9
@@ -41,7 +41,11 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
41
41
else
42
42
rt2 = obtype
43
43
end
44
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
45
+ return CallMeta (rt2, call. exct, call. effects, RecurseInfo (call. info))
46
+ else
44
47
return CallMeta (rt2, call. effects, RecurseInfo (call. info))
48
+ end
45
49
end
46
50
47
51
# Check if there is a rrule for this function
@@ -56,7 +60,12 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
56
60
end
57
61
call = abstract_call_gf_by_type (lower_level (interp), ChainRules. rrule, ArgInfo (nothing , rrule_argtypes), rrule_atype, sv, - 1 )
58
62
if call. rt != Const (nothing )
59
- return CallMeta (getfield_tfunc (call. rt, Const (1 )), call. effects, RRuleInfo (call. rt, call. info))
63
+ newrt = getfield_tfunc (call. rt, Const (1 ))
64
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
65
+ return CallMeta (newrt, call. exct, call. effects, RRuleInfo (call. rt, call. info))
66
+ else
67
+ return CallMeta (newrt, call. exct, call. effects, RRuleInfo (call. rt, call. info))
68
+ end
60
69
end
61
70
end
62
71
end
@@ -74,26 +83,39 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
74
83
return ret
75
84
end
76
85
77
- function abstract_accum (interp:: AbstractInterpreter , args :: Vector{Any} , sv:: InferenceState )
78
- args = filter (x -> ! (widenconst (x) <: Union{ZeroTangent, NoTangent} ), args )
86
+ function abstract_accum (interp:: AbstractInterpreter , argtypes :: Vector{Any} , sv:: InferenceState )
87
+ argtypes = filter (@nospecialize (x) -> ! (widenconst (x) <: Union{ZeroTangent, NoTangent} ), argtypes )
79
88
80
- if length (args) == 0
81
- return CallMeta (ZeroTangent, Effects (), nothing )
89
+ if length (argtypes) == 0
90
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
91
+ return CallMeta (ZeroTangent, Any, Effects (), NoCallInfo ())
92
+ else
93
+ return CallMeta (ZeroTangent, Effects (), NoCallInfo ())
94
+ end
82
95
end
83
96
84
- if length (args) == 1
85
- return CallMeta (args[1 ], Effects (), nothing )
97
+ if length (argtypes) == 1
98
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
99
+ return CallMeta (argtypes[1 ], Any, Effects (), NoCallInfo ())
100
+ else
101
+ return CallMeta (argtypes[1 ], Effects (), NoCallInfo ())
102
+ end
86
103
end
87
104
88
- rtype = reduce (tmerge, args )
105
+ rtype = reduce (tmerge, argtypes )
89
106
if widenconst (rtype) <: Tuple
90
107
targs = Any[]
91
108
for i = 1 : nfields_tfunc (rtype). val
92
- push! (targs, abstract_accum (interp, Any[getfield_tfunc (arg, Const (i)) for arg in args], sv). rt)
109
+ push! (targs, abstract_accum (interp, Any[getfield_tfunc (arg, Const (i)) for arg in argtypes], sv). rt)
110
+ end
111
+ rt = tuple_tfunc (targs)
112
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
113
+ return CallMeta (rt, Any, Effects (), NoCallInfo ())
114
+ else
115
+ return CallMeta (rt, Effects (), NoCallInfo ())
93
116
end
94
- return CallMeta (tuple_tfunc (targs), nothing )
95
117
end
96
- call = abstract_call (change_level (interp, 0 ), nothing , Any[typeof (accum), args ... ],
118
+ call = abstract_call (change_level (interp, 0 ), nothing , Any[typeof (accum), argtypes ... ],
97
119
sv:: InferenceState )
98
120
return call
99
121
end
@@ -249,7 +271,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
249
271
ft = argextype (inst. args[1 ], primal, primal. sptypes)
250
272
f = singleton_type (ft)
251
273
if isa (f, Core. Builtin)
252
- call = CallMeta (backwards_tfunc (f, primal, inst, Δ), nothing )
274
+ rt = backwards_tfunc (f, primal, inst, Δ)
275
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
276
+ call = CallMeta (rt, Any, Effects (), NoCallInfo ())
277
+ else
278
+ call = CallMeta (rt, Effects (), NoCallInfo ())
279
+ end
253
280
else
254
281
bail! (inst)
255
282
continue
@@ -265,7 +292,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
265
292
arg = getfield_tfunc (Δ, Const (1 ))
266
293
call = abstract_call (interp, nothing , Any[clos, arg], sv)
267
294
# No derivative wrt the functor
268
- call = CallMeta (tuple_tfunc (Any[NoTangent; tuple_type_fields (call. rt)... ]), ReifyInfo (call. info))
295
+ rt = tuple_tfunc (Any[NoTangent; tuple_type_fields (call. rt)... ])
296
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
297
+ call = CallMeta (rt, Any, Effects (), ReifyInfo (call. info))
298
+ else
299
+ call = CallMeta (rt, Effects (), ReifyInfo (call. info))
300
+ end
269
301
else
270
302
(level, close) = derive_closure_type (call_info)
271
303
call = abstract_call (change_level (interp, level), ArgInfo (nothing , Any[close, Δ]), sv)
@@ -274,13 +306,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
274
306
275
307
if isa (info, UnionSplitApplyCallInfo)
276
308
argts = Any[argextype (inst. args[i], primal, primal. sptypes) for i = 4 : length (inst. args)]
277
- call = CallMeta (repackage_apply_rt (info, call. rt, argts),
278
- UnionSplitApplyCallInfo ([ApplyCallInfo (call. info)]))
309
+ rt = repackage_apply_rt (info, call. rt, argts)
310
+ newinfo = UnionSplitApplyCallInfo ([ApplyCallInfo (call. info)])
311
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
312
+ call = CallMeta (rt, Any, Effects (), newinfo)
313
+ else
314
+ call = CallMeta (rt, Effects (), newinfo)
315
+ end
279
316
end
280
317
281
318
if isa (call_info, ReifyInfo)
282
319
new_rt = tuple_tfunc (Any[derive_closure_type (call. info)[2 ]; call. rt])
283
- call = CallMeta (new_rt, RecurseInfo (call. info))
320
+ newinfo = RecurseInfo (call. info)
321
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
322
+ call = CallMeta (new_rt, Any, Effects (), newinfo)
323
+ else
324
+ call = CallMeta (new_rt, Effects (), newinfo)
325
+ end
284
326
end
285
327
286
328
if call. rt === Union{}
@@ -312,15 +354,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
312
354
accum_call = abstract_accum (interp, this_arg_typs, sv)
313
355
if accum_call. rt == Union{}
314
356
@show accum_call. rt
315
- return CallMeta (Union{}, false )
357
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
358
+ return CallMeta (Union{}, Any, Effects (), NoCallInfo ())
359
+ else
360
+ return CallMeta (Union{}, Effects (), NoCallInfo ())
361
+ end
316
362
end
317
363
push! (arg_accums, accum_call)
318
364
tup_push! (tup_elemns, accum_call. rt)
319
365
end
320
366
end
321
367
322
368
rt = tuple_tfunc (Any[tup_elemns... ])
369
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
370
+ return CallMeta (rt, Any, Effects (), CompClosInfo (cc, ssa_infos))
371
+ else
323
372
return CallMeta (rt, Effects (), CompClosInfo (cc, ssa_infos))
373
+ end
324
374
end
325
375
326
376
function infer_cc_forward (interp:: ADInterpreter , cc:: AbstractCompClosure , @nospecialize (cc_Δ), sv:: InferenceState )
@@ -389,7 +439,11 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
389
439
390
440
if isa (inst, ReturnNode)
391
441
rt = accum_arg (inst. val)
392
- return CallMeta (rt, CompClosInfo (cc, ssa_infos))
442
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
443
+ return CallMeta (rt, Any, Effects (), CompClosInfo (cc, ssa_infos))
444
+ else
445
+ return CallMeta (rt, Effects (), CompClosInfo (cc, ssa_infos))
446
+ end
393
447
end
394
448
395
449
args = Any[]
@@ -451,7 +505,12 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
451
505
arg = getfield_tfunc (Δ, Const (2 ))
452
506
call = abstract_call (interp, nothing , Any[clos, arg], sv)
453
507
# No derivative wrt the functor
454
- call = CallMeta (tuple_tfunc (Any[NoTangent; tuple_type_fields (call. rt)... ]), ReifyInfo (call. info))
508
+ newrt = tuple_tfunc (Any[NoTangent; tuple_type_fields (call. rt)... ])
509
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
510
+ call = CallMeta (newrt, Any, Effects (), ReifyInfo (call. info))
511
+ else
512
+ call = CallMeta (newrt, Effects (), ReifyInfo (call. info))
513
+ end
455
514
# error()
456
515
else
457
516
(level, clos) = derive_closure_type (call_info)
@@ -461,11 +520,20 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
461
520
462
521
if isa (call_info, ReifyInfo)
463
522
new_rt = tuple_tfunc (Any[call. rt; derive_closure_type (call. info)[2 ]])
464
- call = CallMeta (new_rt, RecurseInfo ())
523
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
524
+ call = CallMeta (new_rt, Any, Effects (), RecurseInfo ())
525
+ else
526
+ call = CallMeta (new_rt, Effects (), RecurseInfo ())
527
+ end
465
528
end
466
529
467
530
if isa (info, UnionSplitApplyCallInfo)
468
- call = CallMeta (call. rt, UnionSplitApplyCallInfo ([ApplyCallInfo (call. info)]))
531
+ newinfo = UnionSplitApplyCallInfo ([ApplyCallInfo (call. info)])
532
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
533
+ call = CallMeta (call. rt, call. exct, Effects (), newinfo)
534
+ else
535
+ call = CallMeta (call. rt, Effects (), newinfo)
536
+ end
469
537
end
470
538
471
539
accums[i] = call. rt
@@ -485,13 +553,16 @@ function infer_comp_closure(interp::ADInterpreter, cc::AbstractCompClosure, @nos
485
553
end
486
554
487
555
function infer_prim_closure (interp:: ADInterpreter , pc:: PrimClosure , @nospecialize (Δ), sv:: InferenceState )
488
- @show (" enter" , pc)
489
-
490
556
if pc. seq == 1
491
557
call = abstract_call (change_level (interp, pc. order), nothing , Any[pc. dual, Δ], sv)
492
558
rt = call. rt
493
559
@show (pc, Δ, rt)
494
- return CallMeta (call. rt, PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below)))
560
+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below))
561
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
562
+ return CallMeta (call. rt, call. exct, Effects (), newinfo)
563
+ else
564
+ return CallMeta (call. rt, Effects (), newinfo)
565
+ end
495
566
elseif pc. seq == 2
496
567
ni = change_level (interp, pc. order)
497
568
mi′ = specialize_method (pc. info_below. results. matches[1 ], true )
@@ -500,8 +571,12 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
500
571
call = infer_comp_closure (ni, cc, Δ, sv)
501
572
rt = getfield_tfunc (call. rt, Const (2 ))
502
573
@show (pc, Δ, rt)
503
- return CallMeta (rt,
504
- PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (1 )), call. info, pc. info_carried)))
574
+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (1 )), call. info, pc. info_carried))
575
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
576
+ return CallMeta (rt, Any, Effects (), newinfo)
577
+ else
578
+ return CallMeta (rt, Effects (), newinfo)
579
+ end
505
580
elseif pc. seq == 3
506
581
ni = change_level (interp, pc. order)
507
582
mi′ = specialize_method (pc. info_carried. info. results. matches[1 ], true )
@@ -511,41 +586,62 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
511
586
Any[clos, tuple_tfunc (Any[Δ, pc. dual])], sv)
512
587
rt = tuple_tfunc (Any[tuple_type_fields (call. rt)[2 : end ]. .. ])
513
588
@show (pc, Δ, rt)
514
- return CallMeta (rt,
515
- PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below)))
589
+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below))
590
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
591
+ return CallMeta (rt, Any, Effects (), newinfo)
592
+ else
593
+ return CallMeta (rt, Effects (), newinfo)
594
+ end
516
595
elseif mod (pc. seq, 4 ) == 0
517
596
info = pc. info_below
518
597
clos = AbstractCompClosure (info. clos. order, info. clos. seq + 1 , info. clos. primal_info, info. infos)
519
-
520
598
# Add back gradient w.r.t. rrule
521
599
Δ = tuple_tfunc (Any[NoTangent, tuple_type_fields (Δ)... ])
522
600
call = abstract_call (change_level (interp, pc. order), nothing , Any[clos, Δ], sv)
523
601
rt = getfield_tfunc (call. rt, Const (1 ))
524
602
@show (pc, Δ, rt)
525
- return CallMeta (rt, PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (2 )), call. info, pc. info_carried)))
603
+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (2 )), call. info, pc. info_carried))
604
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
605
+ return CallMeta (rt, Any, Effects (), newinfo)
606
+ else
607
+ return CallMeta (rt, Effects (), newinfo)
608
+ end
526
609
elseif mod (pc. seq, 4 ) == 1
527
610
info = pc. info_carried
528
611
clos = AbstractCompClosure (info. clos. order, info. clos. seq + 1 , info. clos. primal_info, info. infos)
529
612
call = abstract_call (change_level (interp, pc. order), nothing , Any[clos, tuple_tfunc (Any[pc. dual, Δ])], sv)
530
613
rt = call. rt
531
614
@show (pc, Δ, rt)
532
- return CallMeta (call. rt, PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below)))
615
+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below))
616
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
617
+ return CallMeta (rt, Any, Effects (), newinfo)
618
+ else
619
+ return CallMeta (rt, Effects (), newinfo)
620
+ end
533
621
elseif mod (pc. seq, 4 ) == 2
534
622
info = pc. info_below
535
623
clos = AbstractCompClosure (info. clos. order, info. clos. seq + 1 , info. clos. primal_info, info. infos)
536
624
call = abstract_call (change_level (interp, pc. order), nothing , Any[clos, Δ], sv)
537
625
rt = getfield_tfunc (call. rt, Const (2 ))
538
626
@show (pc, Δ, rt)
539
- return CallMeta (rt,
540
- PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (1 )), call. info, pc. info_carried)))
627
+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (1 )), call. info, pc. info_carried))
628
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
629
+ return CallMeta (rt, Any, Effects (), newinfo)
630
+ else
631
+ return CallMeta (rt, Effects (), newinfo)
632
+ end
541
633
elseif mod (pc. seq, 4 ) == 3
542
634
info = pc. info_carried
543
635
clos = AbstractCompClosure (info. clos. order, info. clos. seq + 1 , info. clos. primal_info, info. infos)
544
636
call = abstract_call (change_level (interp, pc. order), nothing , Any[clos, tuple_tfunc (Any[Δ, pc. dual])], sv)
545
637
rt = tuple_tfunc (Any[tuple_type_fields (call. rt)[2 : end ]. .. ])
546
638
@show (pc, Δ, rt)
547
- return CallMeta (rt,
548
- PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below)))
639
+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below))
640
+ @static if VERSION ≥ v " 1.11.0-DEV.945"
641
+ return CallMeta (rt, Any, Effects (), newinfo)
642
+ else
643
+ return CallMeta (rt, Effects (), newinfo)
644
+ end
549
645
end
550
646
error ()
551
647
end
@@ -556,8 +652,7 @@ function CC.abstract_call_opaque_closure(interp::ADInterpreter,
556
652
if isa (closure. source, AbstractCompClosure)
557
653
(;argtypes) = arginfo
558
654
if length (argtypes) != = 2
559
- error ()
560
- return CallMeta (Union{}, false )
655
+ error (" bad argtypes" )
561
656
end
562
657
return infer_comp_closure (interp, closure. source, argtypes[2 ], sv)
563
658
elseif isa (closure. source, PrimClosure)
0 commit comments