@@ -1258,6 +1258,88 @@ function DiffEqBase._concrete_solve_adjoint(
12581258 p)
12591259end
12601260
1261+ function DiffEqBase. _concrete_solve_adjoint (
1262+ prob:: Union {SciMLBase. AbstractDiscreteProblem,
1263+ SciMLBase. AbstractODEProblem,
1264+ SciMLBase. AbstractDAEProblem,
1265+ SciMLBase. AbstractDDEProblem,
1266+ SciMLBase. AbstractSDEProblem,
1267+ SciMLBase. AbstractSDDEProblem,
1268+ SciMLBase. AbstractRODEProblem
1269+ },
1270+ alg, sensealg:: EnzymeAdjoint ,
1271+ u0, p, originator:: SciMLBase.ADOriginator ,
1272+ args... ; kwargs... )
1273+ kwargs_filtered = NamedTuple (filter (x -> x[1 ] != :sensealg , kwargs))
1274+ du0 = make_zero (u0)
1275+ dp = make_zero (p)
1276+ mode = sensealg. mode
1277+
1278+ f = (u0, p) -> solve (prob, alg, args... ; u0 = u0, p = p,
1279+ sensealg = SensitivityADPassThrough (),
1280+ kwargs_filtered... )
1281+
1282+ splitmode = if mode isa Forward
1283+ error (" EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary." )
1284+ elseif mode === nothing || mode === Reverse
1285+ ReverseSplitWithPrimal
1286+ end
1287+
1288+ forward, reverse = autodiff_thunk (splitmode, Const{typeof (f)}, Duplicated, Duplicated{typeof (u0)}, Duplicated{typeof (p)})
1289+ tape, result, shadow_result = forward (Const (f), Duplicated (u0, du0), Duplicated (p, dp))
1290+
1291+ function enzyme_sensitivity_backpass (Δ)
1292+ reverse (Const (f), Duplicated (u0, du0), Duplicated (p, dp), Δ, tape)
1293+ if originator isa SciMLBase. TrackerOriginator ||
1294+ originator isa SciMLBase. ReverseDiffOriginator
1295+ (NoTangent (), NoTangent (), du0, dp, NoTangent (),
1296+ ntuple (_ -> NoTangent (), length (args))... )
1297+ else
1298+ (NoTangent (), NoTangent (), NoTangent (), du0, dp, NoTangent (),
1299+ ntuple (_ -> NoTangent (), length (args))... )
1300+ end
1301+ end
1302+ sol, enzyme_sensitivity_backpass
1303+ end
1304+
1305+ # NOTE: This is needed to prevent a method ambiguity error
1306+ function DiffEqBase. _concrete_solve_adjoint (
1307+ prob:: AbstractNonlinearProblem , alg, sensealg:: EnzymeAdjoint ,
1308+ u0, p, originator:: SciMLBase.ADOriginator ,
1309+ args... ; kwargs... )
1310+ kwargs_filtered = NamedTuple (filter (x -> x[1 ] != :sensealg , kwargs))
1311+
1312+ du0 = make_zero (u0)
1313+ dp = make_zero (p)
1314+ mode = sensealg. mode
1315+
1316+ f = (u0, p) -> solve (prob, alg, args... ; u0 = u0, p = p,
1317+ sensealg = SensitivityADPassThrough (),
1318+ kwargs_filtered... )
1319+
1320+ splitmode = if mode isa Forward
1321+ error (" EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary." )
1322+ elseif mode === nothing || mode === Reverse
1323+ ReverseSplitWithPrimal
1324+ end
1325+
1326+ forward, reverse = autodiff_thunk (splitmode, Const{typeof (f)}, Duplicated, Duplicated{typeof (u0)}, Duplicated{typeof (p)})
1327+ tape, result, shadow_result = forward (Const (f), Duplicated (u0, du0), Duplicated (p, dp))
1328+
1329+ function enzyme_sensitivity_backpass (Δ)
1330+ reverse (Const (f), Duplicated (u0, du0), Duplicated (p, dp), Δ, tape)
1331+ if originator isa SciMLBase. TrackerOriginator ||
1332+ originator isa SciMLBase. ReverseDiffOriginator
1333+ (NoTangent (), NoTangent (), du0, dp, NoTangent (),
1334+ ntuple (_ -> NoTangent (), length (args))... )
1335+ else
1336+ (NoTangent (), NoTangent (), NoTangent (), du0, dp, NoTangent (),
1337+ ntuple (_ -> NoTangent (), length (args))... )
1338+ end
1339+ end
1340+ sol, enzyme_sensitivity_backpass
1341+ end
1342+
12611343const ENZYME_TRACKED_REAL_ERROR_MESSAGE = """
12621344 `Enzyme` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`.
12631345 Either choose a different adjoint method like `GaussAdjoint`,
0 commit comments