ForwardDiff overloads #607
Comments
Essentially, this is a way of avoiding having the Dual numbers go through the actual solver? If I'm thinking about this correctly, we only need to catch the case where |
We want to do this for all |
But Also, can |
Yes it cannot. But with autodiff, you're instead calculating |
Only if Oh, I think I get it, for some reason I was thinking that this was only for when So the point is that it should also go through the overload if we're differentiating wrt. If we have
A = rand(3,3)
g(b) = solve(LinearProblem(A, b)).u
g(rand(3))
ForwardDiff.jacobian(g, rand(3))
then that should also go through the overload? |
yeah |
I think we actually have to use the product rule here (excuse my handwriting): where the bottom term is where we can do a linear solve to find I tried it out and this way agrees with
using ForwardDiff, LinearSolve
using ForwardDiff: Dual
partial_vals(x::Dual) = ForwardDiff.partials(x)
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
function g(p)
    A = [p p+1 p^3;
         3*p p+5 p-4;
         p^2 9*p p]
    b = [p+1, p*2, p^2]
    prob = LinearProblem(A, b)
    sol = solve(prob)
    sol.u
end
g(5.0)
ForwardDiff.derivative(g, 5.0)
julia> 3-element Vector{Float64}:
-0.03270492074729986
0.07804258495644212
-0.009624980644422148
#---------------------------------------------------------------------------------------------------
h(p) = (A = [p p+1 p^3;
             3*p p+5 p-4;
             p^2 9*p p],
        b = [p + 1, p * 2, p^2])
A, b = h(ForwardDiff.Dual(5.0, 1.0))
new_A = nodual_value(A)
new_b = nodual_value(b)
sol = solve(LinearProblem(new_A, new_b))
uu = sol.u
∂_A = partial_vals(A)
∂_b = partial_vals(b)
[x.values[1] for x in ∂_A] \ [x.values[1] for x in ∂_b]
x_p = new_A \ [x.values[1] for x in (∂_b - ∂_A*uu)]
julia> 3-element Vector{Float64}:
-0.03270492074729983
0.07804258495644215
-0.00962498064442215
Of course we'll have to make sure that it works when there's more than one partial, but I think that's just more linear solves? And of course this doesn't work when A is a SciMLOperator, but I guess in that case dA will just be a matter of ForwardDiff. |
yes because those are just columns. You can think of a multi partial dual as just a matrix, so it's the same as A' \ B where B is the matrix of partials. |
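For reference, the product-rule step and the multi-partial case from the comments above can be written out as follows. This is only a sketch of the algebra, assuming A is square and invertible; the symbols u (the primal solution, `uu` in the snippet), A_k and b_k (the k-th partials of A and b), and the matrices U and R are notation introduced here, not names from the thread.

```latex
% Single parameter p: differentiate A(p) u(p) = b(p) with the product rule
\frac{\partial A}{\partial p}\, u + A\, \frac{\partial u}{\partial p} = \frac{\partial b}{\partial p}
\quad\Longrightarrow\quad
\frac{\partial u}{\partial p} = A^{-1}\!\left( \frac{\partial b}{\partial p} - \frac{\partial A}{\partial p}\, u \right)

% N partials: one right-hand side per partial, stacked as columns
R = \begin{bmatrix} b_1 - A_1 u & b_2 - A_2 u & \cdots & b_N - A_N u \end{bmatrix},
\qquad
A\, U = R,
\qquad
U = \begin{bmatrix} u_1 & u_2 & \cdots & u_N \end{bmatrix}
```

The single-parameter line corresponds to `x_p = new_A \ (∂_b - ∂_A*uu)` in the snippet above. With several partials the same matrix A is solved against every column of R, so one factorization of A can be reused.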
A(p)x = b(p) does not need to differentiate w.r.t. p, since dA/dp dx/dp = db/dp is just another linear equation, so that can be solved. Thus the duals should just be split and it should solve using the same solver.
The implementation can follow the same setup. You find cases with Dual numbers:
https://github.com/SciML/NonlinearSolve.jl/blob/v4.8.0/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl#L22-L32
If you catch that, then you remake using ForwardDiff.value to de-dual and solve the non-dual part:
https://github.com/SciML/NonlinearSolve.jl/blob/v4.8.0/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl#L43-L59
and then you push forward the partials:
https://github.com/SciML/NonlinearSolve.jl/blob/v4.8.0/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl#L64-L78
It's actually the same thing as this code, since f(x,p) = 0 when x(p) is the solution has the derivative f_x x_p + f_p = 0, which means x_p = - f_x \ f_p, which is this line https://github.com/SciML/NonlinearSolve.jl/blob/v4.8.0/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl#L66. So it's effectively the same thing as this code right here is just turned into x_p = A_p \ b_p instead, i.e. just grabbing the partials out of A and b.
This is the core of the downstream issue SciML/ModelingToolkit.jl#3589
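A minimal sketch of the split, solve, and push-forward steps for a dense linear system, using the product-rule right-hand side from the comments (b_p - A_p u) rather than A_p \ b_p. The function name `dual_linear_solve` is hypothetical and not part of LinearSolve.jl, and a plain `lu` factorization stands in for whatever solver the real overload would reuse. It assumes both A and b hold Duals with the same tag and number of partials; the mixed cases (only A or only b dual) would need extra methods.

```julia
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Partials, value, partials

# Hypothetical helper, not LinearSolve.jl API: solve A*u = b where A and b
# both carry Duals with tag T and N partials.
function dual_linear_solve(A::AbstractMatrix{Dual{T,V,N}},
                           b::AbstractVector{Dual{T,V,N}}) where {T,V,N}
    A0 = value.(A)          # de-dual: primal matrix
    b0 = value.(b)          # de-dual: primal right-hand side
    F = lu(A0)              # factorize once, reuse for every partial
    u = F \ b0              # primal solution

    # One right-hand side per partial: b_k - A_k * u (product rule).
    rhs = hcat((partials.(b, k) .- partials.(A, k) * u for k in 1:N)...)
    du = F \ rhs            # column k holds the k-th partial of u

    # Reassemble Duals with the original tag so ForwardDiff can extract them.
    return [Dual{T}(u[i], Partials(ntuple(k -> du[i, k], N))) for i in eachindex(u)]
end

# Same test problem as in the comments; p is a Dual inside ForwardDiff.derivative.
function g(p)
    A = [p p+1 p^3;
         3*p p+5 p-4;
         p^2 9*p p]
    b = [p + 1, p * 2, p^2]
    return dual_linear_solve(A, b)
end

ForwardDiff.derivative(g, 5.0)
```

The result should agree, up to rounding, with differentiating a generic `A \ b` directly through ForwardDiff. The difference is that the factorization only ever sees plain floating-point numbers.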