diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index a5736ba7e..c14bca747 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.49" +version = "0.6.50" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index b6e6d8aed..a3c55dd5c 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -70,7 +70,7 @@ Moreover, each context type is supported by a specific subset of backends: | `AutoMooncake` | ✅ | ✅ | | `AutoPolyesterForwardDiff` | ✅ | ✅ | | `AutoReverseDiff` | ✅ | ❌ | -| `AutoSymbolics` | ✅ | ❌ | +| `AutoSymbolics` | ✅ | ✅ | | `AutoTracker` | ✅ | ❌ | | `AutoZygote` | ✅ | 🔀 | diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 2dc8d0018..1179861e6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -28,11 +28,23 @@ variablize(::Number, name::Symbol) = variable(name) variablize(x::AbstractArray, name::Symbol) = variables(name, axes(x)...) function variablize(contexts::NTuple{C,DI.Context}) where {C} - map(enumerate(contexts)) do (k, c) + return ntuple(Val(C)) do k + c = contexts[k] variablize(DI.unwrap(c), Symbol("context$k")) end end +function erase_cache_vars!( + context_vars::NTuple{C}, contexts::NTuple{C,DI.Context} +) where {C} + # erase the active data from caches before building function + for (v, c) in zip(context_vars, contexts) + if c isa DI.Cache + fill!(v, zero(eltype(v))) + end + end +end + include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 96450be9a..30c73ce23 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -18,6 +18,7 @@ function DI.prepare_pushforward_nokwarg( step_der_var = derivative(f(x_var + t_var * dx_var, context_vars...), t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) + erase_cache_vars!(context_vars, contexts) res = build_function( pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true ) @@ -104,6 +105,7 @@ function DI.prepare_derivative_nokwarg( context_vars = variablize(contexts) der_var = derivative(f(x_var, context_vars...), x_var) + erase_cache_vars!(context_vars, contexts) res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true) (der_exe, der_exe!) = if res isa Tuple res @@ -179,6 +181,7 @@ function DI.prepare_gradient_nokwarg( # Symbolic.gradient only accepts vectors grad_var = gradient(f(x_var, context_vars...), vec(x_var)) + erase_cache_vars!(context_vars, contexts) res = build_function( grad_var, vec(x_var), context_vars...; expression=Val(false), cse=true ) @@ -258,6 +261,7 @@ function DI.prepare_jacobian_nokwarg( jacobian(f(x_var, context_vars...), x_var) end + erase_cache_vars!(context_vars, contexts) res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true) (jac_exe, jac_exe!) = res return SymbolicsOneArgJacobianPrep(_sig, jac_exe, jac_exe!) @@ -337,6 +341,7 @@ function DI.prepare_hessian_nokwarg( hessian(f(x_var, context_vars...), vec(x_var)) end + erase_cache_vars!(context_vars, contexts) res = build_function( hess_var, vec(x_var), context_vars...; expression=Val(false), cse=true ) @@ -425,6 +430,7 @@ function DI.prepare_hvp_nokwarg( hess_var = hessian(f(x_var, context_vars...), vec(x_var)) hvp_vec_var = hess_var * vec(dx_var) + erase_cache_vars!(context_vars, contexts) res = build_function( hvp_vec_var, vec(x_var), @@ -519,6 +525,7 @@ function DI.prepare_second_derivative_nokwarg( der_var = derivative(f(x_var, context_vars...), x_var) der2_var = derivative(der_var, x_var) + erase_cache_vars!(context_vars, contexts) res = build_function(der2_var, x_var, context_vars...; expression=Val(false), cse=true) (der2_exe, der2_exe!) = if res isa Tuple res diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 6597af01c..41c80e95b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -26,6 +26,7 @@ function DI.prepare_pushforward_nokwarg( step_der_var = derivative(y_var, t_var) pf_var = substitute(step_der_var, Dict(t_var => zero(eltype(x)))) + erase_cache_vars!(context_vars, contexts) res = build_function( pf_var, x_var, dx_var, context_vars...; expression=Val(false), cse=true ) @@ -116,6 +117,7 @@ function DI.prepare_derivative_nokwarg( f!(y_var, x_var, context_vars...) der_var = derivative(y_var, x_var) + erase_cache_vars!(context_vars, contexts) res = build_function(der_var, x_var, context_vars...; expression=Val(false), cse=true) (der_exe, der_exe!) = res return SymbolicsTwoArgDerivativePrep(_sig, der_exe, der_exe!) @@ -203,6 +205,7 @@ function DI.prepare_jacobian_nokwarg( jacobian(y_var, x_var) end + erase_cache_vars!(context_vars, contexts) res = build_function(jac_var, x_var, context_vars...; expression=Val(false), cse=true) (jac_exe, jac_exe!) = res return SymbolicsTwoArgJacobianPrep(_sig, jac_exe, jac_exe!) diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl index 91625b700..1ad15e8f1 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl @@ -22,7 +22,6 @@ test_differentiation( test_differentiation( AutoSymbolics(), default_scenarios(; include_normal=false, include_cachified=true, use_tuples=false); - excluded=[:jacobian], # TODO: figure out why this fails logging=LOGGING, );