Skip to content

Commit 374f92b

Browse files
aviateskKeno
andauthored
adjustments to the latest master (#284)
Also fixes a bunch of tests. --------- Co-authored-by: Keno Fischer <[email protected]>
1 parent a444b7f commit 374f92b

File tree

15 files changed

+311
-201
lines changed

15 files changed

+311
-201
lines changed

src/Diffractor.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
module Diffractor
22

3+
export ∂⃖, gradient
4+
35
using StructArrays
46
using PrecompileTools
57

6-
export ∂⃖, gradient
7-
88
const CC = Core.Compiler
9+
using Core.IR
910

1011
@static if VERSION v"1.11.0-DEV.1498"
1112
import .CC: get_inference_world
@@ -33,7 +34,6 @@ end
3334
include("stage2/tfuncs.jl")
3435
include("stage2/forward.jl")
3536

36-
include("codegen/forward.jl")
3737
include("analysis/forward.jl")
3838
include("codegen/forward_demand.jl")
3939
include("codegen/reverse.jl")
@@ -48,4 +48,4 @@ end
4848
include("AbstractDifferentiation.jl")
4949
end
5050

51-
end
51+
end # module Diffractor

src/codegen/forward.jl

Lines changed: 0 additions & 117 deletions
This file was deleted.

src/codegen/reverse.jl

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
# Codegen shared by both stage1 and stage2
22

3-
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, cis, revs...)
3+
function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci, revs...)
44
if interp !== nothing
5-
cis.inferred = true
6-
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
7-
typ, Union{}, cis.rettype, @__MODULE__, cis, lno.line, lno.file, meth_nargs, isva, ()).source
8-
return Expr(:new_opaque_closure, typ, Union{}, Any,
9-
ocm, revs...)
5+
@static if VERSION v"1.12.0-DEV.15"
6+
rettype = Any # ci.rettype # TODO revisit
7+
else
8+
ci.inferred = true
9+
rettype = ci.rettype
10+
end
11+
@static if VERSION v"1.12.0-DEV.15"
12+
ocm = Core.OpaqueClosure(ci; rettype, nargs=meth_nargs, isva, sig=typ).source
13+
else
14+
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
15+
typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source
16+
end
17+
return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...)
1018
else
1119
oc_nargs = Int64(meth_nargs)
1220
Expr(:new_opaque_closure, typ, Union{}, Any,
13-
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, cis), revs...)
21+
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...)
1422
end
1523
end
1624

@@ -107,8 +115,12 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
107115
opaque_ci.slotnames = [Symbol("#oc#"), ci.slotnames...]
108116
opaque_ci.slotflags = UInt8[0, ci.slotflags...]
109117
end
110-
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
111-
opaque_ci.inferred = false
118+
@static if VERSION v"1.12.0-DEV.173"
119+
opaque_ci.debuginfo = ci.debuginfo
120+
else
121+
opaque_ci.linetable = Core.LineInfoNode[ci.linetable[1]]
122+
opaque_ci.inferred = false
123+
end
112124
opaque_ci
113125
end
114126

@@ -393,12 +405,17 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
393405
code = opaque_ci.code = expand_switch(code, bb_ranges, slot_map)
394406
end
395407

396-
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
408+
@static if VERSION v"1.12.0-DEV.173"
409+
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
410+
debuginfo.def = :var"N/A"
411+
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
412+
else
413+
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
414+
end
397415
opaque_ci.ssavaluetypes = length(code)
398-
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
416+
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
399417
end
400418

401-
402419
for nc = 2:2:n_closures
403420
fwds = Any[nothing for i = 1:length(ir.stmts)]
404421

@@ -475,9 +492,15 @@ function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::I
475492
end
476493
end
477494

478-
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
495+
@static if VERSION v"1.12.0-DEV.173"
496+
debuginfo = Core.Compiler.DebugInfoStream(nothing, opaque_ci.debuginfo, length(code))
497+
debuginfo.def = :var"N/A"
498+
opaque_ci.debuginfo = Core.DebugInfo(debuginfo, length(code))
499+
else
500+
opaque_ci.codelocs = Int32[0 for i=1:length(code)]
501+
end
479502
opaque_ci.ssavaluetypes = length(code)
480-
opaque_ci.ssaflags = UInt8[0 for i=1:length(code)]
503+
opaque_ci.ssaflags = SSAFlagType[zero(SSAFlagType) for i=1:length(code)]
481504
end
482505

483506
# TODO: This is absolutely aweful, but the best we can do given the data structures we have

src/higher_fwd_rules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@ using Base.Iterators
44

55
function njet(::Val{N}, ::typeof(sin), x₀) where {N}
66
(s, c) = sincos(x₀)
7-
Jet(x₀, s, tuple(take(cycle((c, -s, -c, s)), N)...))
7+
Jet(convert(typeof(s), x₀), s, tuple(take(cycle((c, -s, -c, s)), N)...))
88
end
99

1010
function njet(::Val{N}, ::typeof(cos), x₀) where {N}
1111
(s, c) = sincos(x₀)
12-
Jet(x₀, s, tuple(take(cycle((-s, -c, s, c)), N)...))
12+
Jet(convert(typeof(s), x₀), s, tuple(take(cycle((-s, -c, s, c)), N)...))
1313
end
1414

1515
function njet(::Val{N}, ::typeof(exp), x₀) where {N}
1616
exped = exp(x₀)
17-
Jet(x₀, exped, tuple(take(repeated(exped), N)...))
17+
Jet(convert(typeof(exped), x₀), exped, tuple(take(repeated(exped), N)...))
1818
end
1919

2020
jeval(j, x) = j(x)

src/jet.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function Base.show(io::IO, j::Jet)
104104
end
105105

106106
function domain_check(j::Jet, x)
107-
if j.a !== x
107+
if j.a !== convert(typeof(j.a), x)
108108
throw(DomainError("Evaluation is only valid at a"))
109109
end
110110
end
@@ -153,11 +153,17 @@ function ChainRulesCore.rrule(j::Jet, x)
153153
end
154154

155155
function ChainRulesCore.rrule(::typeof(map), ::typeof(*), a, b)
156-
map(*, a, b), Δ->(NoTangent(), NoTangent(), map(*, Δ, b), map(*, a, Δ))
156+
map(*, a, b), Δ->let Δ=unthunk(Δ)
157+
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent(), NoTangent())
158+
(NoTangent(), NoTangent(), map(*, Δ, b), map(*, a, Δ))
159+
end
157160
end
158161

159162
ChainRulesCore.rrule(::typeof(map), ::typeof(integrate), js::Array{<:Jet}) =
160-
map(integrate, js), Δ->(NoTangent(), NoTangent(), map(deriv, Δ))
163+
map(integrate, js), Δ->let Δ=unthunk(Δ)
164+
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent())
165+
(NoTangent(), NoTangent(), map(deriv, Δ))
166+
end
161167

162168
struct derivBack
163169
js
@@ -177,7 +183,10 @@ end
177183

178184
function ChainRulesCore.rrule(::typeof(mapev), js::Array{<:Jet}, xs::AbstractArray)
179185
mapev(js, xs), let djs=map(deriv, js)
180-
Δ->(NoTangent(), NoTangent(), map(*, unthunk(Δ), mapev(djs, xs)))
186+
function (Δ)
187+
isa(Δ, NoTangent) && return (NoTangent(), NoTangent(), NoTangent())
188+
(NoTangent(), NoTangent(), map(*, unthunk(Δ), mapev(djs, xs)))
189+
end
181190
end
182191
end
183192

src/stage1/compiler_utils.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# Utilities that should probably go into Core.Compiler
2-
using Core.Compiler: IRCode, CFG, BasicBlock, BBIdxIter
1+
# Utilities that should probably go into CC
2+
using .CC: IRCode, CFG, BasicBlock, BBIdxIter
33

44
function Base.push!(cfg::CFG, bb::BasicBlock)
55
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
@@ -8,38 +8,40 @@ function Base.push!(cfg::CFG, bb::BasicBlock)
88
end
99

1010
if VERSION < v"1.11.0-DEV.258"
11-
Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa)
11+
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
1212
end
1313

14-
Base.copy(ir::IRCode) = Core.Compiler.copy(ir)
14+
Base.copy(ir::IRCode) = CC.copy(ir)
1515

16-
Core.Compiler.NewInstruction(@nospecialize node) =
16+
CC.NewInstruction(@nospecialize node) =
1717
NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED)
1818

19-
Base.setproperty!(x::Core.Compiler.Instruction, f::Symbol, v) =
20-
Core.Compiler.setindex!(x, v, f)
19+
Base.setproperty!(x::CC.Instruction, f::Symbol, v) = CC.setindex!(x, v, f)
2120

22-
Base.getproperty(x::Core.Compiler.Instruction, f::Symbol) =
23-
Core.Compiler.getindex(x, f)
21+
Base.getproperty(x::CC.Instruction, f::Symbol) = CC.getindex(x, f)
2422

2523
function Base.setindex!(ir::IRCode, ni::NewInstruction, i::Int)
2624
stmt = ir.stmts[i]
2725
stmt.inst = ni.stmt
2826
stmt.type = ni.type
2927
stmt.flag = something(ni.flag, 0) # fixes 1.9?
30-
stmt.line = something(ni.line, 0)
28+
@static if VERSION v"1.12.0-DEV.173"
29+
stmt.line = something(ni.line, CC.NoLineUpdate)
30+
else
31+
stmt.line = something(ni.line, 0)
32+
end
3133
return ni
3234
end
3335

3436
function Base.push!(ir::IRCode, ni::NewInstruction)
3537
# TODO: This should be a check in insert_node!
3638
@assert length(ir.new_nodes.stmts) == 0
37-
@static if isdefined(Core.Compiler, :add!)
39+
@static if isdefined(CC, :add!)
3840
# Julia 1.7 & 1.8
39-
ir[Core.Compiler.add!(ir.stmts)] = ni
41+
ir[CC.add!(ir.stmts)] = ni
4042
else
4143
# Re-named in https://github.com/JuliaLang/julia/pull/47051
42-
ir[Core.Compiler.add_new_idx!(ir.stmts)] = ni
44+
ir[CC.add_new_idx!(ir.stmts)] = ni
4345
end
4446
ir
4547
end
@@ -54,8 +56,8 @@ function Base.iterate(it::Iterators.Reverse{BBIdxIter},
5456
return (bb, idx - 1), (bb, idx - 1)
5557
end
5658

57-
Base.lastindex(x::Core.Compiler.InstructionStream) =
58-
Core.Compiler.length(x)
59+
Base.lastindex(x::CC.InstructionStream) =
60+
CC.length(x)
5961

6062
"""
6163
find_end_of_phi_block(ir::IRCode, start_search_idx::Int)

src/stage1/forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ function _frule(::NTuple{<:Any, AbstractZero}, f, primal_args...)
126126
end
127127

128128
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
129-
bundles = map(bundle, partials, args)
129+
bundles = map(bundle, args, partials)
130130
result = ∂☆internal{1}()(bundles...)
131131
primal(result), first_partial(result)
132132
end

0 commit comments

Comments
 (0)