@@ -40,13 +40,13 @@ function forward_diff!(ir::IRCode, interp::AbstractInterpreter, irsv::IRInterpre
40
40
end
41
41
function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
42
42
val, order:: Int ;
43
- custom_diff!, diff_cache)
43
+ custom_diff!, diff_cache, eras_mode )
44
44
return ChainRulesCore. zero_tangent (val)
45
45
end
46
46
function forward_diff! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
47
47
arg:: Argument , order:: Int ;
48
- custom_diff!, diff_cache)
49
- recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache)
48
+ custom_diff!, diff_cache, eras_mode )
49
+ recurse (x) = forward_diff! (ir, interp, irsv, x; custom_diff!, diff_cache, eras_mode )
50
50
val = custom_diff! (ir, SSAValue (0 ), arg, recurse)
51
51
if val != = nothing
52
52
return val
56
56
57
57
function forward_diff_uncached! (ir:: IRCode , interp:: AbstractInterpreter , irsv:: IRInterpretationState ,
58
58
ssa:: SSAValue , inst:: Core.Compiler.Instruction , order:: Int ;
59
- custom_diff!, diff_cache)
59
+ custom_diff!, diff_cache, eras_mode )
60
60
stmt = inst[:inst ]
61
- recurse (x) = forward_diff! (ir, interp, irsv, x, order; custom_diff!, diff_cache)
61
+ recurse (x) = forward_diff! (ir, interp, irsv, x, order; custom_diff!, diff_cache, eras_mode )
62
62
if (val = custom_diff! (ir, ssa, stmt, recurse)) != = nothing
63
63
return val
64
64
elseif isa (stmt, PiNode)
@@ -212,8 +212,10 @@ Internal method which generates the code for forward mode diffentiation
212
212
decides if the custom `transform!` should be applied to a `stmt` or not
213
213
Default: `false` for all statements
214
214
- `transform!(ir::IRCode, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
215
+ - `eras_mode`: determines if to error if not all derivatives are taylor
215
216
"""
216
217
function forward_diff_no_inf! (ir:: IRCode , to_diff:: Vector{Pair{SSAValue,Int}} ;
218
+ eras_mode = false ,
217
219
visit_custom! = (@nospecialize args... )-> false ,
218
220
transform! = (@nospecialize args... )-> error ())
219
221
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
@@ -286,12 +288,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
286
288
newargs = map (stmt. args[2 : end ]) do @nospecialize arg
287
289
maparg (arg, SSAValue (ssa), order)
288
290
end
289
- replace_call! (ir, SSAValue (ssa), Expr (:call , ∂☆ {order} (), newargs... ))
291
+ replace_call! (ir, SSAValue (ssa), Expr (:call , ∂☆ {order, eras_mode } (), newargs... ))
290
292
elseif isexpr (stmt, :call ) || isexpr (stmt, :new )
291
293
newargs = map (stmt. args) do @nospecialize arg
292
294
maparg (arg, SSAValue (ssa), order)
293
295
end
294
- f = isexpr (stmt, :call ) ? ∂☆ {order} () : ∂☆new {order} ()
296
+ f = isexpr (stmt, :call ) ? ∂☆ {order, eras_mode } () : ∂☆new {order} ()
295
297
replace_call! (ir, SSAValue (ssa), Expr (:call , f, newargs... ))
296
298
elseif isa (stmt, PiNode)
297
299
# TODO : New PiNode that discriminates based on primal?
0 commit comments