EnzymeAdjoint #1148

Open
ChrisRackauckas opened this issue Nov 12, 2024 · 5 comments · May be fixed by #1209

@ChrisRackauckas (Member)

Now that direct adjoints are starting to work with Enzyme over OrdinaryDiffEq.jl, it would make sense to add this to the SciMLSensitivity.jl system.

using Enzyme, OrdinaryDiffEq, StaticArrays

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;
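u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)   # shadow of u0: Enzyme accumulates the gradient here
y  = zeros(13)
dy = ones(13)     # shadow of y: the reverse-mode seed; with zeros(13), d_u0 would stay all zeros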

Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));

That's a working demonstration. Now what we need is just an EnzymeAdjoint struct which then does exactly that internally: https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/concrete_solve.jl#L1222-L1405.
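What such a rule has to do can be sketched in plain Julia. The names below (`EnzymeAdjointSketch`, `make_pullback`, `run_reverse`) are hypothetical and not SciMLSensitivity API, and the reverse pass is passed in as a function so the sketch runs without Enzyme. In a real rule, `run_reverse` would be the `Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0))` call from the demonstration above, and the returned closure would play the role of the pullback that the adjoint code in concrete_solve.jl hands back to the outer AD system.

```julia
# Hypothetical stand-in for the proposed EnzymeAdjoint sensealg.
# `run_reverse(y, dy, u0, du0)` performs one reverse pass: it computes the primal
# into `y` and accumulates J' * dy into `du0` (the Duplicated convention).
struct EnzymeAdjointSketch{R}
    run_reverse::R
end

# Build a pullback: given the cotangent `ybar` of the saved trajectory,
# return the vector-Jacobian product ubar0 = J' * ybar.
function make_pullback(alg::EnzymeAdjointSketch, u0::Vector{Float64}, ny::Int)
    function pullback(ybar::AbstractVector{<:Real})
        length(ybar) == ny || throw(DimensionMismatch("cotangent must have length $ny"))
        y = zeros(ny)                  # primal output buffer
        dy = collect(Float64, ybar)    # output shadow is seeded with the cotangent
        du0 = zeros(length(u0))        # input shadow starts at zero because Enzyme accumulates into it
        alg.run_reverse(y, dy, copy(u0), du0)
        return du0
    end
    return pullback
end
```

With a linear stand-in `y = A * u0` whose reverse pass does `du0 .+= A' * dy`, the pullback returns `A' * ybar`, which is the same kind of vector-Jacobian product Enzyme computes for the ODE solve.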

Better Support for EnzymeAdjoint inside an Enzyme Diff

That version is great for a user who defines a loss function with Zygote but passes sensealg=EnzymeAdjoint(), so that we take care of the hard ODE part. But if the user also uses Enzyme for the loss function and differentiates through the ODE solve, we should detect this case and keep it from hitting the SciMLSensitivity path in DiffEqBase at all. Basically, if sensealg=EnzymeAdjoint() and we are "in an Enzyme environment", solve should just switch to sensealg = DiffEqBase.SensitivityADPassThrough(). That said, I don't know how to detect "in an Enzyme environment", so I don't know how to pull this off. @wsmoses it would be helpful to know how to do this. If this is done, I think we get some extra speed, since no rules are used at all in this case.
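The proposed switch can be sketched as a dispatch on the sensealg type. All names here (`AbstractSensealg`, `PassThrough`, `EnzymeAdjointAlg`, `OtherAdjoint`, `effective_sensealg`) are hypothetical stand-ins for the real DiffEqBase/SciMLSensitivity types. The open question, detecting that Enzyme is differentiating the call, is not solved here: it is simply assumed to arrive as a `Bool`.

```julia
# Hypothetical stand-ins for DiffEqBase.SensitivityADPassThrough, the proposed
# EnzymeAdjoint, and any other adjoint method.
abstract type AbstractSensealg end
struct PassThrough <: AbstractSensealg end
struct EnzymeAdjointAlg <: AbstractSensealg end
struct OtherAdjoint <: AbstractSensealg end

# Default: keep whatever sensealg the user asked for.
effective_sensealg(alg::AbstractSensealg, in_enzyme::Bool) = alg

# EnzymeAdjoint under an outer Enzyme call: bypass the SciMLSensitivity rules
# so Enzyme differentiates the solver directly.
effective_sensealg(alg::EnzymeAdjointAlg, in_enzyme::Bool) =
    in_enzyme ? PassThrough() : alg
```

Outside Enzyme (for example under Zygote), `EnzymeAdjointAlg()` is kept and the rule-based path is used; only the combination "EnzymeAdjoint and inside Enzyme" is rewritten to the pass-through.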

Supporting EnzymeAdjoint for SDEs

It probably takes the same steps that were required for ODEs, which were:

Since both use the same fastpow, that should already be handled. The SDE integrator type does not use FSAL (https://github.com/SciML/StochasticDiffEq.jl/blob/master/src/integrators/type.jl), so that PR isn't needed. That leaves SciML/OrdinaryDiffEq.jl#2390 as the only shared issue.

But SciML/OrdinaryDiffEq.jl#2390 was a workaround for a bug in Enzyme, which may be fixed now (@wsmoses?). So it's worth just giving direct Enzyme a try. To do that, you'd put solve into a mode that forces it to ignore the SciMLSensitivity adjoint rules, which is what the ODE code above does with sensealg = DiffEqBase.SensitivityADPassThrough(). We'd just need an SDE test case like:

using Enzyme, StochasticDiffEq, StaticArrays

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

function lorenz_noise!(du, u, p, t)
  du .= 0.1u
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = SDEProblem{true}(lorenz!, lorenz_noise!, u0, tspan)
    # EM() is a fixed-step method, so a dt is required; 0.01 is an arbitrary choice
    sol = DiffEqBase.solve(prob, EM(), dt = 0.01, saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;
u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y  = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));

I haven't run that to see how it does, but it might just work now.

@wsmoses commented Nov 13, 2024

@ChrisRackauckas (Member, Author)

Trying to make this work for some stiff ODE solvers now:

Enzyme: Non-constant keyword argument found for Tuple{UInt64, typeof(Core.kwcall), Duplicated{@NamedTuple{alias_A::Bool, alias_b::Bool, Pl::LinearSolve.InvPreconditioner{LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, Pr::LinearAlgebra.Diagonal{Float64, Vector{Float64}}, assumptions::LinearSolve.OperatorAssumptions{Bool}}}, typeof(EnzymeCore.EnzymeRules.augmented_primal), EnzymeCore.EnzymeRules.RevConfigWidth{1, true, true, (false, false, false), false}, Const{typeof(init)}, Type{Duplicated{LinearSolve.LinearCache{Matrix{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, LinearSolve.DefaultLinearSolver, LinearSolve.DefaultLinearSolverInit{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}, Matrix{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Int64}}, Nothing, Nothing, Nothing, LinearAlgebra.SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int32}}, Base.RefValue{Int32}}, Tuple{LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Base.RefValue{Int64}}, LinearAlgebra.QRPivoted{Float64, Matrix{Float64}, Vector{Float64}, Vector{Int64}}, Nothing, Nothing}, LinearSolve.InvPreconditioner{LinearAlgebra.Diagonal{Float64, Vector{Float64}}}, LinearAlgebra.Diagonal{Float64, Vector{Float64}}, Float64, Bool, LinearSolve.LinearSolveAdjoint{Missing}}}}, Duplicated{LinearProblem{Vector{Float64}, true, Matrix{Float64}, Vector{Float64}, SciMLBase.NullParameters, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}}, Const{Nothing}}

on this MWE:

using OrdinaryDiffEq, Enzyme, StaticArrays

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Rodas5P(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;
u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y  = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));

The issue is that Pl::LinearSolve.InvPreconditioner{LinearAlgebra.Diagonal{Float64, Vector{Float64}}} and Pr::LinearAlgebra.Diagonal{Float64, Vector{Float64}} contain mutable data, but the solution of the linear system is independent of these pieces. How do I declare that Enzyme should treat them as const?
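A quick numeric check of why Pl and Pr can safely be treated as constants: for any invertible Pl and Pr, x = Pr (Pl A Pr)⁻¹ Pl b = A⁻¹ b, so the solution (and therefore its derivative) does not depend on them. The matrices below are arbitrary, and the sketch applies Pl and Pr by plain multiplication rather than asserting LinearSolve's exact convention for them. For iterative solvers this holds only up to the solver tolerance.

```julia
using LinearAlgebra

# Arbitrary nonsingular system and arbitrary invertible diagonal preconditioners.
A  = [4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 2.0]
b  = [1.0, 2.0, 3.0]
Pl = Diagonal([0.5, 2.0, 1.5])   # left preconditioner
Pr = Diagonal([1.0, 0.25, 3.0])  # right preconditioner

x_direct = A \ b

# Solve the preconditioned system (Pl*A*Pr) z = Pl*b, then map back with x = Pr*z.
z = (Pl * A * Pr) \ (Pl * b)
x_prec = Pr * z

# Same solution, so dx/dPl = dx/dPr = 0 in exact arithmetic.
@assert x_prec ≈ x_direct
```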

@wsmoses commented Nov 13, 2024

I think at this point it's mostly a matter of syntax/macro design. Basically, we need a registration system that tells Enzyme that a given argument of a method is inactive (we have the backend infra all set up, but need a user-accessible way to pass the info).
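Enzyme's existing `EnzymeRules.inactive` and `EnzymeRules.inactive_type` hooks mark a whole call or a whole type as inactive; what is being asked for here is finer-grained, one argument of one method. A minimal plain-Julia sketch of what such a per-argument registration could look like: `inactive_arg`, `solve_sketch`, and `activity_of` are hypothetical names, not Enzyme API, and the symbols `:const`/`:duplicated` stand in for Enzyme's `Const`/`Duplicated` annotations.

```julia
# Hypothetical registration hook: is argument i of method f inactive?
# Default: every argument may carry derivatives.
inactive_arg(::Any, ::Val) = false

# Stand-in for a solver entry point whose `precond` argument never affects
# the mathematical result, so it is registered as inactive.
solve_sketch(A, b, precond) = A \ b
inactive_arg(::typeof(solve_sketch), ::Val{3}) = true

# What a frontend could derive from the registration when building annotations.
activity_of(f, i::Int) = inactive_arg(f, Val(i)) ? :const : :duplicated
```

Because the registration is keyed on the method and argument position rather than on the type, a `Diagonal` passed as a preconditioner could be constant while a `Diagonal` passed as the operator stays differentiable.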

@ChrisRackauckas (Member, Author)

SciML/LinearSolve.jl#382 is a closely related issue: sometimes an algorithm just happens to take some of its arguments as arrays. Enzyme treats any Array as differentiable, while the adjoint rule knows to ignore it, and those two facts clash.

@ChrisRackauckas (Member, Author) commented May 21, 2025

@wsmoses the starter code at the top that was working now seems to be failing (on v1.10):

using Enzyme, OrdinaryDiffEq, StaticArrays

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;
u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y  = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));
julia> Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));
ERROR: Error handling recursive stores for String which has a fieldcount of 0
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] create_recursive_stores(B::LLVM.IRBuilder, Ty::DataType, prev::LLVM.Value)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:622
  [3] create_recursive_stores(B::LLVM.IRBuilder, Ty::DataType, prev::LLVM.Value) (repeats 2 times)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:667
  [4] shadow_alloc_rewrite(V::Ptr{…}, gutils::Ptr{…}, Orig::Ptr{…}, idx::UInt64, prev::Ptr{…}, used::UInt8)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:739
  [5] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\api.jl:269
  [6] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:1754
  [7] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:4669
  [8] codegen
    @ C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:3455 [inlined]
  [9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:5533
 [10] _thunk
    @ C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:5533 [inlined]
 [11] cached_compilation
    @ C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:5585 [inlined]
 [12] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})       
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:5696
 [13] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\compiler.jl:5881
 [14] autodiff
    @ C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\Enzyme.jl:486 [inlined]
 [15] autodiff
    @ C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\Enzyme.jl:545 [inlined]
 [16] autodiff(::ReverseMode{…}, ::typeof(f), ::Duplicated{…}, ::Duplicated{…})
    @ Enzyme C:\Users\accou\.julia\packages\Enzyme\hu9gq\src\Enzyme.jl:517
 [17] top-level scope
    @ REPL[15]:1
Some type information was truncated. Use `show(err)` to see complete types.

The core code being differentiated hasn't changed in 9 months (https://github.com/SciML/OrdinaryDiffEq.jl/tree/master/lib/OrdinaryDiffEqTsit5/src), so it looks like an Enzyme regression. However, I can't get any line numbers out of this. How do I get a usable stack trace that tells me which String it is complaining about? I'm sure I could just mark that String as something not to differentiate, but without a line number I'm stuck.
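This does not recover the missing stack trace, but as a stopgap the candidates can be narrowed down by searching the types involved for String fields. A minimal plain-Julia sketch; `field_paths` is a hypothetical helper that only inspects declared field types and stops at abstract or Union-typed fields, so it can miss Strings stored behind them.

```julia
"""
    field_paths(T, target; maxdepth = 8) -> Vector{String}

List the paths (e.g. ".a.label") of fields nested inside the concrete type `T`
whose declared type is a subtype of `target`.
"""
function field_paths(T::Type, target::Type; maxdepth::Int = 8)
    paths = String[]
    _walk!(paths, T, target, "", maxdepth, Set{Any}())
    return paths
end

function _walk!(paths, T, target, prefix, depth, onpath)
    # Stop at the depth limit, at non-concrete types, and on recursive types.
    (depth <= 0 || !isconcretetype(T) || T in onpath) && return
    push!(onpath, T)
    for i in 1:fieldcount(T)
        FT = fieldtype(T, i)
        path = string(prefix, ".", fieldname(T, i))
        if FT !== Union{} && FT <: target
            push!(paths, path)       # found a field declared as `target`
        else
            _walk!(paths, FT, target, path, depth - 1, onpath)
        end
    end
    delete!(onpath, T)
    return
end
```

For a `prob` built as in `f`, calls such as `field_paths(typeof(prob), String)` or `field_paths(typeof(init(prob, Tsit5(); saveat = _saveat)), String)` list the String fields reachable from the problem or the integrator. The error does not say which object Enzyme is shadowing, so this only yields candidates to inspect.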

ChrisRackauckas added a commit that referenced this issue May 21, 2025
Fixes #1148. Needs tests, and it needs the actual direct adjoints of Enzyme to work again
ChrisRackauckas linked a pull request May 21, 2025 that will close this issue
ChrisRackauckas added a commit that referenced this issue May 23, 2025
Fixes #1148. Needs tests, and it needs the actual direct adjoints of Enzyme to work again