Skip to content

Commit cbcc0f3

Browse files
authored
1 parent d2ab53b commit cbcc0f3

File tree

4 files changed

+141
-49
lines changed

4 files changed

+141
-49
lines changed

src/analysis/forward.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
2626
frule_call = CC.abstract_call_gf_by_type(interp′,
2727
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)
2828
if frule_call.rt !== Const(nothing)
29+
@static if VERSION v"1.11.0-DEV.945"
30+
return CallMeta(primal_call.rt, primal_call.exct, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
31+
else
2932
return CallMeta(primal_call.rt, primal_call.effects, FRuleCallInfo(primal_call.info, frule_call))
33+
end
3034
else
3135
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
3236
end

src/stage1/compiler_utils.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,6 @@ end
5757
Base.lastindex(x::Core.Compiler.InstructionStream) =
5858
Core.Compiler.length(x)
5959

60-
# Solves an error after https://github.com/JuliaLang/julia/pull/46961
61-
# as does https://github.com/FluxML/IRTools.jl/pull/101
62-
if isdefined(Core.Compiler, :CallInfo)
63-
Base.convert(::Type{Core.Compiler.CallInfo}, ::Nothing) = Core.Compiler.NoCallInfo()
64-
end
65-
66-
6760
"""
6861
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)
6962

src/stage1/recurse.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
using Core.IR
12
using Core.Compiler:
2-
Argument, BasicBlock, CFG, CodeInfo, GotoIfNot, GotoNode, IRCode, IncrementalCompact,
3-
Instruction, MethodInstance, NewInstruction, NewvarNode, OldSSAValue, PhiNode,
4-
ReturnNode, SSAValue, SlotNumber, StmtRange,
3+
BasicBlock, CallInfo, CFG, IRCode, IncrementalCompact, Instruction, NewInstruction,
4+
NoCallInfo, OldSSAValue, StmtRange,
55
bbidxiter, cfg_delete_edge!, cfg_insert_edge!, compute_basic_blocks, complete,
66
construct_domtree, construct_ssa!, domsort_ssa!, finish, insert_node!,
77
insert_node_here!, effect_free_and_nothrow, non_dce_finish!, quoted, retrieve_code_info,
@@ -266,7 +266,7 @@ function optic_transform!(ci, mi, nargs, N)
266266

267267
meta = Expr[]
268268
ir = IRCode(Core.Compiler.InstructionStream(code, Any[],
269-
Any[nothing for i = 1:length(code)],
269+
CallInfo[NoCallInfo() for i = 1:length(code)],
270270
ci.codelocs, UInt8[0 for i = 1:length(code)]), cfg, Core.LineInfoNode[ci.linetable...],
271271
Any[Any for i = 1:2], meta, sptypes(sparams))
272272

src/stage2/abstractinterpret.jl

Lines changed: 133 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using .CC: Const, isconstType, argtypes_to_type, tuple_tfunc, Const,
33
getfield_tfunc, _methods_by_ftype, VarTable, nfields_tfunc,
44
ArgInfo, singleton_type, CallMeta, MethodMatchInfo, specialize_method,
55
PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc,
6-
StmtInfo
6+
StmtInfo, NoCallInfo
77
using Core: PartialStruct
88
using Base.Meta
99

@@ -41,7 +41,11 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
4141
else
4242
rt2 = obtype
4343
end
44+
@static if VERSION v"1.11.0-DEV.945"
45+
return CallMeta(rt2, call.exct, call.effects, RecurseInfo(call.info))
46+
else
4447
return CallMeta(rt2, call.effects, RecurseInfo(call.info))
48+
end
4549
end
4650

4751
# Check if there is a rrule for this function
@@ -56,7 +60,12 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
5660
end
5761
call = abstract_call_gf_by_type(lower_level(interp), ChainRules.rrule, ArgInfo(nothing, rrule_argtypes), rrule_atype, sv, -1)
5862
if call.rt != Const(nothing)
59-
return CallMeta(getfield_tfunc(call.rt, Const(1)), call.effects, RRuleInfo(call.rt, call.info))
63+
newrt = getfield_tfunc(call.rt, Const(1))
64+
@static if VERSION v"1.11.0-DEV.945"
65+
return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info))
66+
else
67+
return CallMeta(newrt, call.exct, call.effects, RRuleInfo(call.rt, call.info))
68+
end
6069
end
6170
end
6271
end
@@ -74,26 +83,39 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
7483
return ret
7584
end
7685

77-
function abstract_accum(interp::AbstractInterpreter, args::Vector{Any}, sv::InferenceState)
78-
args = filter(x->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), args)
86+
function abstract_accum(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
87+
argtypes = filter(@nospecialize(x)->!(widenconst(x) <: Union{ZeroTangent, NoTangent}), argtypes)
7988

80-
if length(args) == 0
81-
return CallMeta(ZeroTangent, Effects(), nothing)
89+
if length(argtypes) == 0
90+
@static if VERSION v"1.11.0-DEV.945"
91+
return CallMeta(ZeroTangent, Any, Effects(), NoCallInfo())
92+
else
93+
return CallMeta(ZeroTangent, Effects(), NoCallInfo())
94+
end
8295
end
8396

84-
if length(args) == 1
85-
return CallMeta(args[1], Effects(), nothing)
97+
if length(argtypes) == 1
98+
@static if VERSION v"1.11.0-DEV.945"
99+
return CallMeta(argtypes[1], Any, Effects(), NoCallInfo())
100+
else
101+
return CallMeta(argtypes[1], Effects(), NoCallInfo())
102+
end
86103
end
87104

88-
rtype = reduce(tmerge, args)
105+
rtype = reduce(tmerge, argtypes)
89106
if widenconst(rtype) <: Tuple
90107
targs = Any[]
91108
for i = 1:nfields_tfunc(rtype).val
92-
push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in args], sv).rt)
109+
push!(targs, abstract_accum(interp, Any[getfield_tfunc(arg, Const(i)) for arg in argtypes], sv).rt)
110+
end
111+
rt = tuple_tfunc(targs)
112+
@static if VERSION v"1.11.0-DEV.945"
113+
return CallMeta(rt, Any, Effects(), NoCallInfo())
114+
else
115+
return CallMeta(rt, Effects(), NoCallInfo())
93116
end
94-
return CallMeta(tuple_tfunc(targs), nothing)
95117
end
96-
call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), args...],
118+
call = abstract_call(change_level(interp, 0), nothing, Any[typeof(accum), argtypes...],
97119
sv::InferenceState)
98120
return call
99121
end
@@ -249,7 +271,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
249271
ft = argextype(inst.args[1], primal, primal.sptypes)
250272
f = singleton_type(ft)
251273
if isa(f, Core.Builtin)
252-
call = CallMeta(backwards_tfunc(f, primal, inst, Δ), nothing)
274+
rt = backwards_tfunc(f, primal, inst, Δ)
275+
@static if VERSION v"1.11.0-DEV.945"
276+
call = CallMeta(rt, Any, Effects(), NoCallInfo())
277+
else
278+
call = CallMeta(rt, Effects(), NoCallInfo())
279+
end
253280
else
254281
bail!(inst)
255282
continue
@@ -265,7 +292,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
265292
arg = getfield_tfunc(Δ, Const(1))
266293
call = abstract_call(interp, nothing, Any[clos, arg], sv)
267294
# No derivative wrt the functor
268-
call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info))
295+
rt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...])
296+
@static if VERSION v"1.11.0-DEV.945"
297+
call = CallMeta(rt, Any, Effects(), ReifyInfo(call.info))
298+
else
299+
call = CallMeta(rt, Effects(), ReifyInfo(call.info))
300+
end
269301
else
270302
(level, close) = derive_closure_type(call_info)
271303
call = abstract_call(change_level(interp, level), ArgInfo(nothing, Any[close, Δ]), sv)
@@ -274,13 +306,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
274306

275307
if isa(info, UnionSplitApplyCallInfo)
276308
argts = Any[argextype(inst.args[i], primal, primal.sptypes) for i = 4:length(inst.args)]
277-
call = CallMeta(repackage_apply_rt(info, call.rt, argts),
278-
UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]))
309+
rt = repackage_apply_rt(info, call.rt, argts)
310+
newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])
311+
@static if VERSION v"1.11.0-DEV.945"
312+
call = CallMeta(rt, Any, Effects(), newinfo)
313+
else
314+
call = CallMeta(rt, Effects(), newinfo)
315+
end
279316
end
280317

281318
if isa(call_info, ReifyInfo)
282319
new_rt = tuple_tfunc(Any[derive_closure_type(call.info)[2]; call.rt])
283-
call = CallMeta(new_rt, RecurseInfo(call.info))
320+
newinfo = RecurseInfo(call.info)
321+
@static if VERSION v"1.11.0-DEV.945"
322+
call = CallMeta(new_rt, Any, Effects(), newinfo)
323+
else
324+
call = CallMeta(new_rt, Effects(), newinfo)
325+
end
284326
end
285327

286328
if call.rt === Union{}
@@ -312,15 +354,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
312354
accum_call = abstract_accum(interp, this_arg_typs, sv)
313355
if accum_call.rt == Union{}
314356
@show accum_call.rt
315-
return CallMeta(Union{}, false)
357+
@static if VERSION v"1.11.0-DEV.945"
358+
return CallMeta(Union{}, Any, Effects(), NoCallInfo())
359+
else
360+
return CallMeta(Union{}, Effects(), NoCallInfo())
361+
end
316362
end
317363
push!(arg_accums, accum_call)
318364
tup_push!(tup_elemns, accum_call.rt)
319365
end
320366
end
321367

322368
rt = tuple_tfunc(Any[tup_elemns...])
369+
@static if VERSION v"1.11.0-DEV.945"
370+
return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos))
371+
else
323372
return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos))
373+
end
324374
end
325375

326376
function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospecialize(cc_Δ), sv::InferenceState)
@@ -389,7 +439,11 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
389439

390440
if isa(inst, ReturnNode)
391441
rt = accum_arg(inst.val)
392-
return CallMeta(rt, CompClosInfo(cc, ssa_infos))
442+
@static if VERSION v"1.11.0-DEV.945"
443+
return CallMeta(rt, Any, Effects(), CompClosInfo(cc, ssa_infos))
444+
else
445+
return CallMeta(rt, Effects(), CompClosInfo(cc, ssa_infos))
446+
end
393447
end
394448

395449
args = Any[]
@@ -451,7 +505,12 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
451505
arg = getfield_tfunc(Δ, Const(2))
452506
call = abstract_call(interp, nothing, Any[clos, arg], sv)
453507
# No derivative wrt the functor
454-
call = CallMeta(tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...]), ReifyInfo(call.info))
508+
newrt = tuple_tfunc(Any[NoTangent; tuple_type_fields(call.rt)...])
509+
@static if VERSION v"1.11.0-DEV.945"
510+
call = CallMeta(newrt, Any, Effects(), ReifyInfo(call.info))
511+
else
512+
call = CallMeta(newrt, Effects(), ReifyInfo(call.info))
513+
end
455514
#error()
456515
else
457516
(level, clos) = derive_closure_type(call_info)
@@ -461,11 +520,20 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
461520

462521
if isa(call_info, ReifyInfo)
463522
new_rt = tuple_tfunc(Any[call.rt; derive_closure_type(call.info)[2]])
464-
call = CallMeta(new_rt, RecurseInfo())
523+
@static if VERSION v"1.11.0-DEV.945"
524+
call = CallMeta(new_rt, Any, Effects(), RecurseInfo())
525+
else
526+
call = CallMeta(new_rt, Effects(), RecurseInfo())
527+
end
465528
end
466529

467530
if isa(info, UnionSplitApplyCallInfo)
468-
call = CallMeta(call.rt, UnionSplitApplyCallInfo([ApplyCallInfo(call.info)]))
531+
newinfo = UnionSplitApplyCallInfo([ApplyCallInfo(call.info)])
532+
@static if VERSION v"1.11.0-DEV.945"
533+
call = CallMeta(call.rt, call.exct, Effects(), newinfo)
534+
else
535+
call = CallMeta(call.rt, Effects(), newinfo)
536+
end
469537
end
470538

471539
accums[i] = call.rt
@@ -485,13 +553,16 @@ function infer_comp_closure(interp::ADInterpreter, cc::AbstractCompClosure, @nos
485553
end
486554

487555
function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecialize(Δ), sv::InferenceState)
488-
@show ("enter", pc)
489-
490556
if pc.seq == 1
491557
call = abstract_call(change_level(interp, pc.order), nothing, Any[pc.dual, Δ], sv)
492558
rt = call.rt
493559
@show (pc, Δ, rt)
494-
return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
560+
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
561+
@static if VERSION v"1.11.0-DEV.945"
562+
return CallMeta(call.rt, call.exct, Effects(), newinfo)
563+
else
564+
return CallMeta(call.rt, Effects(), newinfo)
565+
end
495566
elseif pc.seq == 2
496567
ni = change_level(interp, pc.order)
497568
mi′ = specialize_method(pc.info_below.results.matches[1], true)
@@ -500,8 +571,12 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
500571
call = infer_comp_closure(ni, cc, Δ, sv)
501572
rt = getfield_tfunc(call.rt, Const(2))
502573
@show (pc, Δ, rt)
503-
return CallMeta(rt,
504-
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)))
574+
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))
575+
@static if VERSION v"1.11.0-DEV.945"
576+
return CallMeta(rt, Any, Effects(), newinfo)
577+
else
578+
return CallMeta(rt, Effects(), newinfo)
579+
end
505580
elseif pc.seq == 3
506581
ni = change_level(interp, pc.order)
507582
mi′ = specialize_method(pc.info_carried.info.results.matches[1], true)
@@ -511,41 +586,62 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
511586
Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv)
512587
rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...])
513588
@show (pc, Δ, rt)
514-
return CallMeta(rt,
515-
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
589+
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
590+
@static if VERSION v"1.11.0-DEV.945"
591+
return CallMeta(rt, Any, Effects(), newinfo)
592+
else
593+
return CallMeta(rt, Effects(), newinfo)
594+
end
516595
elseif mod(pc.seq, 4) == 0
517596
info = pc.info_below
518597
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
519-
520598
# Add back gradient w.r.t. rrule
521599
Δ = tuple_tfunc(Any[NoTangent, tuple_type_fields(Δ)...])
522600
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv)
523601
rt = getfield_tfunc(call.rt, Const(1))
524602
@show (pc, Δ, rt)
525-
return CallMeta(rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried)))
603+
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(2)), call.info, pc.info_carried))
604+
@static if VERSION v"1.11.0-DEV.945"
605+
return CallMeta(rt, Any, Effects(), newinfo)
606+
else
607+
return CallMeta(rt, Effects(), newinfo)
608+
end
526609
elseif mod(pc.seq, 4) == 1
527610
info = pc.info_carried
528611
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
529612
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[pc.dual, Δ])], sv)
530613
rt = call.rt
531614
@show (pc, Δ, rt)
532-
return CallMeta(call.rt, PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
615+
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
616+
@static if VERSION v"1.11.0-DEV.945"
617+
return CallMeta(rt, Any, Effects(), newinfo)
618+
else
619+
return CallMeta(rt, Effects(), newinfo)
620+
end
533621
elseif mod(pc.seq, 4) == 2
534622
info = pc.info_below
535623
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
536624
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, Δ], sv)
537625
rt = getfield_tfunc(call.rt, Const(2))
538626
@show (pc, Δ, rt)
539-
return CallMeta(rt,
540-
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried)))
627+
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, getfield_tfunc(call.rt, Const(1)), call.info, pc.info_carried))
628+
@static if VERSION v"1.11.0-DEV.945"
629+
return CallMeta(rt, Any, Effects(), newinfo)
630+
else
631+
return CallMeta(rt, Effects(), newinfo)
632+
end
541633
elseif mod(pc.seq, 4) == 3
542634
info = pc.info_carried
543635
clos = AbstractCompClosure(info.clos.order, info.clos.seq + 1, info.clos.primal_info, info.infos)
544636
call = abstract_call(change_level(interp, pc.order), nothing, Any[clos, tuple_tfunc(Any[Δ, pc.dual])], sv)
545637
rt = tuple_tfunc(Any[tuple_type_fields(call.rt)[2:end]...])
546638
@show (pc, Δ, rt)
547-
return CallMeta(rt,
548-
PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below)))
639+
newinfo = PrimClosInfo(PrimClosure(pc.name, pc.order, pc.seq + 1, nothing, call.info, pc.info_below))
640+
@static if VERSION v"1.11.0-DEV.945"
641+
return CallMeta(rt, Any, Effects(), newinfo)
642+
else
643+
return CallMeta(rt, Effects(), newinfo)
644+
end
549645
end
550646
error()
551647
end
@@ -556,8 +652,7 @@ function CC.abstract_call_opaque_closure(interp::ADInterpreter,
556652
if isa(closure.source, AbstractCompClosure)
557653
(;argtypes) = arginfo
558654
if length(argtypes) !== 2
559-
error()
560-
return CallMeta(Union{}, false)
655+
error("bad argtypes")
561656
end
562657
return infer_comp_closure(interp, closure.source, argtypes[2], sv)
563658
elseif isa(closure.source, PrimClosure)

0 commit comments

Comments
 (0)