Skip to content

Commit f6ded5b

Browse files
WIP add EnzymeAdjoint
Fixes #1148. Needs tests, and it needs the actual direct adjoints of Enzyme to work again
1 parent 0722333 commit f6ded5b

File tree

3 files changed

+111
-1
lines changed

3 files changed

+111
-1
lines changed

src/SciMLSensitivity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ export ODEForwardSensitivityFunction, ODEForwardSensitivityProblem, SensitivityF
9595
shadow_forward, shadow_adjoint
9696

9797
export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, InterpolatingAdjoint,
98-
TrackerAdjoint, ZygoteAdjoint, ReverseDiffAdjoint,
98+
TrackerAdjoint, ZygoteAdjoint, EnzymeAdjoint, ReverseDiffAdjoint,
9999
ForwardSensitivity, ForwardDiffSensitivity,
100100
ForwardDiffOverAdjoint,
101101
SteadyStateAdjoint,

src/concrete_solve.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,88 @@ function DiffEqBase._concrete_solve_adjoint(
12581258
p)
12591259
end
12601260

1261+
function DiffEqBase._concrete_solve_adjoint(
1262+
prob::Union{SciMLBase.AbstractDiscreteProblem,
1263+
SciMLBase.AbstractODEProblem,
1264+
SciMLBase.AbstractDAEProblem,
1265+
SciMLBase.AbstractDDEProblem,
1266+
SciMLBase.AbstractSDEProblem,
1267+
SciMLBase.AbstractSDDEProblem,
1268+
SciMLBase.AbstractRODEProblem
1269+
},
1270+
alg, sensealg::EnzymeAdjoint,
1271+
u0, p, originator::SciMLBase.ADOriginator,
1272+
args...; kwargs...)
1273+
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
1274+
du0 = make_zero(u0)
1275+
dp = make_zero(p)
1276+
mode = sensealg.mode
1277+
1278+
f = (u0, p) -> solve(prob, alg, args...; u0 = u0, p = p,
1279+
sensealg = SensitivityADPassThrough(),
1280+
kwargs_filtered...)
1281+
1282+
splitmode = if mode isa Forward
1283+
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
1284+
elseif mode === nothing || mode === Reverse
1285+
ReverseSplitWithPrimal
1286+
end
1287+
1288+
forward, reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated, Duplicated{typeof(u0)}, Duplicated{typeof(p)})
1289+
tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp))
1290+
1291+
function enzyme_sensitivity_backpass(Δ)
1292+
reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
1293+
if originator isa SciMLBase.TrackerOriginator ||
1294+
originator isa SciMLBase.ReverseDiffOriginator
1295+
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
1296+
ntuple(_ -> NoTangent(), length(args))...)
1297+
else
1298+
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
1299+
ntuple(_ -> NoTangent(), length(args))...)
1300+
end
1301+
end
1302+
sol, enzyme_sensitivity_backpass
1303+
end
1304+
1305+
# NOTE: This is needed to prevent a method ambiguity error
1306+
function DiffEqBase._concrete_solve_adjoint(
1307+
prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint,
1308+
u0, p, originator::SciMLBase.ADOriginator,
1309+
args...; kwargs...)
1310+
kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
1311+
1312+
du0 = make_zero(u0)
1313+
dp = make_zero(p)
1314+
mode = sensealg.mode
1315+
1316+
f = (u0, p) -> solve(prob, alg, args...; u0 = u0, p = p,
1317+
sensealg = SensitivityADPassThrough(),
1318+
kwargs_filtered...)
1319+
1320+
splitmode = if mode isa Forward
1321+
error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
1322+
elseif mode === nothing || mode === Reverse
1323+
ReverseSplitWithPrimal
1324+
end
1325+
1326+
forward, reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated, Duplicated{typeof(u0)}, Duplicated{typeof(p)})
1327+
tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp))
1328+
1329+
function enzyme_sensitivity_backpass(Δ)
1330+
reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
1331+
if originator isa SciMLBase.TrackerOriginator ||
1332+
originator isa SciMLBase.ReverseDiffOriginator
1333+
(NoTangent(), NoTangent(), du0, dp, NoTangent(),
1334+
ntuple(_ -> NoTangent(), length(args))...)
1335+
else
1336+
(NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
1337+
ntuple(_ -> NoTangent(), length(args))...)
1338+
end
1339+
end
1340+
sol, enzyme_sensitivity_backpass
1341+
end
1342+
12611343
const ENZYME_TRACKED_REAL_ERROR_MESSAGE = """
12621344
`Enzyme` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`.
12631345
Either choose a different adjoint method like `GaussAdjoint`,

src/sensitivity_algorithms.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,34 @@ Currently fails on almost every solver.
648648
"""
649649
struct ZygoteAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing} end
650650

651+
"""
652+
EnzymeAdjoint <: AbstractAdjointSensitivityAlgorithm{nothing,true,nothing}
653+
654+
An implementation of discrete adjoint sensitivity analysis
655+
using the Enzyme.jl source-to-source AD directly on the differential equation
656+
solver.
657+
658+
## Constructor
659+
660+
```julia
661+
EnzymeAdjoint(mode = nothing)
662+
```
663+
664+
## Arugments
665+
666+
* `mode::M` determines the autodiff mode (forward or reverse). It can be:
667+
+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
668+
+ `nothing` to choose the best mode automatically
669+
670+
## SciMLProblem Support
671+
672+
Currently fails on almost every solver.
673+
"""
674+
struct EnzymeAdjoint{M <: Union{Nothing,EnzymeCore.Mode}} <: AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
675+
mode::M
676+
EnzymeAdjoint(mode = nothing) = new(mode)
677+
end
678+
651679
"""
652680
```julia
653681
ForwardLSS{CS, AD, FDT, RType, gType} <: AbstractShadowingSensitivityAlgorithm{CS, AD, FDT}

0 commit comments

Comments
 (0)