Skip to content

Commit 80660ad

Browse files
authored
Merge pull request #263 from JuliaDiff/ox/eras2
Eras mode
2 parents a68e3f3 + 4b1f94a commit 80660ad

File tree

11 files changed

+367
-160
lines changed

11 files changed

+367
-160
lines changed

src/codegen/forward_demand.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpre
4040
end
4141
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
4242
val, order::Int;
43-
custom_diff!, diff_cache)
43+
custom_diff!, diff_cache, eras_mode)
4444
return ChainRulesCore.zero_tangent(val)
4545
end
4646
function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
4747
arg::Argument, order::Int;
48-
custom_diff!, diff_cache)
49-
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache)
48+
custom_diff!, diff_cache, eras_mode)
49+
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache, eras_mode)
5050
val = custom_diff!(ir, SSAValue(0), arg, recurse)
5151
if val !== nothing
5252
return val
@@ -56,9 +56,9 @@ end
5656

5757
function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpretationState,
5858
ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int;
59-
custom_diff!, diff_cache)
59+
custom_diff!, diff_cache, eras_mode)
6060
stmt = inst[:inst]
61-
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache)
61+
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache, eras_mode)
6262
if (val = custom_diff!(ir, ssa, stmt, recurse)) !== nothing
6363
return val
6464
elseif isa(stmt, PiNode)
@@ -212,8 +212,10 @@ Internal method which generates the code for forward mode diffentiation
212212
decides if the custom `transform!` should be applied to a `stmt` or not
213213
Default: `false` for all statements
214214
- `transform!(ir::IRCode, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
215+
- `eras_mode`: determines if to error if not all derivatives are taylor
215216
"""
216217
function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
218+
eras_mode = false,
217219
visit_custom! = (@nospecialize args...)->false,
218220
transform! = (@nospecialize args...)->error())
219221
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
@@ -286,12 +288,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
286288
newargs = map(stmt.args[2:end]) do @nospecialize arg
287289
maparg(arg, SSAValue(ssa), order)
288290
end
289-
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order}(), newargs...))
291+
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order, eras_mode}(), newargs...))
290292
elseif isexpr(stmt, :call) || isexpr(stmt, :new)
291293
newargs = map(stmt.args) do @nospecialize arg
292294
maparg(arg, SSAValue(ssa), order)
293295
end
294-
f = isexpr(stmt, :call) ? ∂☆{order}() : ∂☆new{order}()
296+
f = isexpr(stmt, :call) ? ∂☆{order, eras_mode}() : ∂☆new{order}()
295297
replace_call!(ir, SSAValue(ssa), Expr(:call, f, newargs...))
296298
elseif isa(stmt, PiNode)
297299
# TODO: New PiNode that discriminates based on primal?

src/extra_rules.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,16 @@ function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setproperty!), obj::M
287287
= setproperty!(ȯbj, field, ẋ)
288288
return y, ẏ
289289
end
290+
291+
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/607
292+
Base.:(==)(x::Number, ::ZeroTangent) = iszero(x)
293+
Base.:(==)(::ZeroTangent, x::Number) = iszero(x)
294+
Base.hash(x::ZeroTangent, h::UInt64) = hash(0, h)
295+
296+
# should this be in ChainRules/ChainRulesCore?
297+
# Avoid making nested backings, a Tangent is already a valid Tangent for a Tangent,
298+
# or a valid second order Tangent for the primal
299+
function ChainRulesCore.frule((_, ẋ), T::Type{<:Tangent}, x)
300+
::Tangent
301+
return T(x), ẋ
302+
end

src/interface.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,28 @@ const ∂⃖¹ = ∂⃖{1}()
2525
(::Type{∂⃖})(args...) = ∂⃖¹(args...)
2626

2727
"""
28-
∂☆{N}
28+
∂☆{N,E}
2929
30-
∂☆{N} is the forward-mode AD functor of order `N`. A call
30+
∂☆{N} is the forward-mode AD functor of order `N` (An integer). A call
3131
`(::∂☆{N})(f, args...)` evaluating a function `f: A -> B` is lifted to its
3232
pushforward on the N-th order tangent bundle `f⋆: Tⁿ A -> Tⁿ B`.
33+
34+
35+
!!!advanced "Eras Mode"
36+
E (a bool, default false) is for Eras mode. In Eras mode, we are Taylor or bust.
37+
Normally if a particular derivative can not be represented as a `TaylorBundle`
38+
we fall back and represent it as a `ExplictTangentBundle`.
39+
However, in Eras mode we error if it can't be represented as a TaylorBundle.
40+
In general, this is not wanted since it often will break nested AD.
41+
But in the cases it doesn't its really fast, since it means we can rewrite nested AD
42+
as Taylor-mode AD (plus its more type stable).
43+
To be safe in Eras mode, it is sufficient, but not necessary, to be doing nested AD with
44+
respect to the same variable. It also works in other cases where (likely by problem construction)
45+
ADing with respect to a second variable happens to result in something that can be represented
46+
with a `TaylorBundle` also. (You need your different partials to happen to be exactly equal).
3347
"""
34-
struct ∂☆{N}; end
48+
struct ∂☆{N, E}; end
49+
∂☆{N}() where N = ∂☆{N,false}() # default to not using Era mode
3550
const ∂☆¹ = ∂☆{1}()
3651

3752
"""

src/stage1/broadcast.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@ using Base.Broadcast
22
using Base.Broadcast: broadcasted, Broadcasted
33

44
# Forward mode broadcast rule
5-
struct FwdBroadcast{N, T<:AbstractTangentBundle{N}}
5+
struct FwdBroadcast{N, E, T<:AbstractTangentBundle{N}}
66
f::T
77
end
8-
(f::FwdBroadcast{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
8+
FwdBroadcast{E}(f::T) where {N, E, T<:AbstractTangentBundle{N}} = FwdBroadcast{N,E,T}(f)
9+
10+
(f::FwdBroadcast{N,E})(args::AbstractTangentBundle{N}...) where {N,E} = ∂☆{N,E}()(f.f, args...)
911

1012
n_getfield(∂ₙ::∂☆{N}, b::ATB{N}, x::Union{Symbol, Int}) where {N} = ∂ₙ(ZeroBundle{N}(getfield), b, ZeroBundle{N}(x))
1113

12-
function (∂ₙ::∂☆{N})(zc::AbstractZeroBundle{N, typeof(copy)},
13-
bc::ATB{N, <:Broadcasted}) where {N}
14+
function (∂ₙ::∂☆{N,E})(zc::AbstractZeroBundle{N, typeof(copy)},
15+
bc::ATB{N, <:Broadcasted}) where {N,E}
1416
bc = ∂ₙ(ZeroBundle{N}(Broadcast.flatten), bc)
1517
args = n_getfield(∂ₙ, bc, :args)
1618
r = copy(Broadcasted(
17-
FwdMap(n_getfield(∂ₙ, bc, :f)),
19+
FwdMap{E}(n_getfield(∂ₙ, bc, :f)),
1820
ntuple(length(primal(args))) do i
1921
val = n_getfield(∂ₙ, args, i)
2022
if ndims(primal(val)) == 0

0 commit comments

Comments
 (0)