@@ -1258,6 +1258,88 @@ function DiffEqBase._concrete_solve_adjoint(
1258
1258
p)
1259
1259
end
1260
1260
1261
+ function DiffEqBase. _concrete_solve_adjoint (
1262
+ prob:: Union {SciMLBase. AbstractDiscreteProblem,
1263
+ SciMLBase. AbstractODEProblem,
1264
+ SciMLBase. AbstractDAEProblem,
1265
+ SciMLBase. AbstractDDEProblem,
1266
+ SciMLBase. AbstractSDEProblem,
1267
+ SciMLBase. AbstractSDDEProblem,
1268
+ SciMLBase. AbstractRODEProblem
1269
+ },
1270
+ alg, sensealg:: EnzymeAdjoint ,
1271
+ u0, p, originator:: SciMLBase.ADOriginator ,
1272
+ args... ; kwargs... )
1273
+ kwargs_filtered = NamedTuple (filter (x -> x[1 ] != :sensealg , kwargs))
1274
+ du0 = make_zero (u0)
1275
+ dp = make_zero (p)
1276
+ mode = sensealg. mode
1277
+
1278
+ f = (u0, p) -> solve (prob, alg, args... ; u0 = u0, p = p,
1279
+ sensealg = SensitivityADPassThrough (),
1280
+ kwargs_filtered... )
1281
+
1282
+ splitmode = if mode isa Forward
1283
+ error (" EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary." )
1284
+ elseif mode === nothing || mode === Reverse
1285
+ ReverseSplitWithPrimal
1286
+ end
1287
+
1288
+ forward, reverse = autodiff_thunk (splitmode, Const{typeof (f)}, Duplicated, Duplicated{typeof (u0)}, Duplicated{typeof (p)})
1289
+ tape, result, shadow_result = forward (Const (f), Duplicated (u0, du0), Duplicated (p, dp))
1290
+
1291
+ function enzyme_sensitivity_backpass (Δ)
1292
+ reverse (Const (f), Duplicated (u0, du0), Duplicated (p, dp), Δ, tape)
1293
+ if originator isa SciMLBase. TrackerOriginator ||
1294
+ originator isa SciMLBase. ReverseDiffOriginator
1295
+ (NoTangent (), NoTangent (), du0, dp, NoTangent (),
1296
+ ntuple (_ -> NoTangent (), length (args))... )
1297
+ else
1298
+ (NoTangent (), NoTangent (), NoTangent (), du0, dp, NoTangent (),
1299
+ ntuple (_ -> NoTangent (), length (args))... )
1300
+ end
1301
+ end
1302
+ sol, enzyme_sensitivity_backpass
1303
+ end
1304
+
1305
+ # NOTE: This is needed to prevent a method ambiguity error
1306
+ function DiffEqBase. _concrete_solve_adjoint (
1307
+ prob:: AbstractNonlinearProblem , alg, sensealg:: EnzymeAdjoint ,
1308
+ u0, p, originator:: SciMLBase.ADOriginator ,
1309
+ args... ; kwargs... )
1310
+ kwargs_filtered = NamedTuple (filter (x -> x[1 ] != :sensealg , kwargs))
1311
+
1312
+ du0 = make_zero (u0)
1313
+ dp = make_zero (p)
1314
+ mode = sensealg. mode
1315
+
1316
+ f = (u0, p) -> solve (prob, alg, args... ; u0 = u0, p = p,
1317
+ sensealg = SensitivityADPassThrough (),
1318
+ kwargs_filtered... )
1319
+
1320
+ splitmode = if mode isa Forward
1321
+ error (" EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary." )
1322
+ elseif mode === nothing || mode === Reverse
1323
+ ReverseSplitWithPrimal
1324
+ end
1325
+
1326
+ forward, reverse = autodiff_thunk (splitmode, Const{typeof (f)}, Duplicated, Duplicated{typeof (u0)}, Duplicated{typeof (p)})
1327
+ tape, result, shadow_result = forward (Const (f), Duplicated (u0, du0), Duplicated (p, dp))
1328
+
1329
+ function enzyme_sensitivity_backpass (Δ)
1330
+ reverse (Const (f), Duplicated (u0, du0), Duplicated (p, dp), Δ, tape)
1331
+ if originator isa SciMLBase. TrackerOriginator ||
1332
+ originator isa SciMLBase. ReverseDiffOriginator
1333
+ (NoTangent (), NoTangent (), du0, dp, NoTangent (),
1334
+ ntuple (_ -> NoTangent (), length (args))... )
1335
+ else
1336
+ (NoTangent (), NoTangent (), NoTangent (), du0, dp, NoTangent (),
1337
+ ntuple (_ -> NoTangent (), length (args))... )
1338
+ end
1339
+ end
1340
+ sol, enzyme_sensitivity_backpass
1341
+ end
1342
+
1261
1343
const ENZYME_TRACKED_REAL_ERROR_MESSAGE = """
1262
1344
`Enzyme` is not compatible with `ReverseDiffAdjoint` nor with `TrackerAdjoint`.
1263
1345
Either choose a different adjoint method like `GaussAdjoint`,
0 commit comments