ForwardDiff overloads #607


Open
ChrisRackauckas opened this issue May 12, 2025 · 8 comments

@ChrisRackauckas
Member

Solving A(p)x = b(p) does not require differentiating through the solver w.r.t. p: differentiating both sides gives A dx/dp = db/dp - (dA/dp) x, which is just another linear system with the same matrix, so it can be solved directly. Thus the duals should be split off and the derivative system solved using the same solver (and the same factorization).

The implementation can follow the same setup. You find cases with Dual numbers:

https://github.com/SciML/NonlinearSolve.jl/blob/v4.8.0/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl#L22-L32

If you catch that, then you remake using ForwardDiff.value to de-dual and solve the non-dual part:

https://github.com/SciML/NonlinearSolve.jl/blob/v4.8.0/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl#L43-L59

and then you push forward the partials:

https://github.com/SciML/NonlinearSolve.jl/blob/v4.8.0/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl#L64-L78

It's effectively the same thing as this code: when x(p) is the solution of f(x, p) = 0, differentiating gives f_x x_p + f_p = 0, so x_p = -f_x \ f_p, which is this line: https://github.com/SciML/NonlinearSolve.jl/blob/v4.8.0/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl#L66. For the linear problem f(x, p) = A(p) x - b(p), that formula becomes x_p = A \ (b_p - A_p x), i.e. you just grab the partials out of A and b and do one more solve.
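A minimal sketch of those three steps (detect duals, de-dual and solve, push the partials forward), assuming a plain dense A and b; the helper name solve_with_duals is made up for illustration and is not the actual LinearSolve extension:

```julia
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, value, partials

function solve_with_duals(A::AbstractMatrix{<:Dual}, b::AbstractVector{<:Dual})
    Av = value.(A)              # de-dual: primal matrix
    bv = value.(b)              # de-dual: primal right-hand side
    F = lu(Av)                  # factor once, reuse for every solve
    x = F \ bv                  # primal solution
    N = ForwardDiff.npartials(eltype(A))
    T = ForwardDiff.tagtype(eltype(A))
    # push each partial forward by solving A * x_p = b_p - A_p * x
    xps = ntuple(i -> F \ (partials.(b, i) .- partials.(A, i) * x), N)
    # re-assemble duals: entry j carries the j-th component of each partial
    return [Dual{T}(x[j], ntuple(i -> xps[i][j], N)...) for j in eachindex(x)]
end
```

The point of the sketch is that the duals never enter the factorization: one `lu` of the primal matrix serves the primal solve and every partial.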

This is the core of the downstream issue SciML/ModelingToolkit.jl#3589

@jClugstor
Member

Essentially, this is a way of avoiding having the Dual numbers go through the actual solver?

If I'm thinking about this correctly, we only need to catch the case where typeof(A) <: AbstractSciMLOperator and typeof(p) <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, right?

@oscardssmith
Member

We want to do this for all A.

@jClugstor
Member

But p is only ever used if A is a SciMLOperator, right? So if it's not, A_p should just be 0?

Also, can b actually depend on p? I don't see anything like that in the LinearSolve docs.

@ChrisRackauckas
Member Author

Also, can b actually depend on p? I don't see anything like that in the LinearSolve docs.

Right, it can't depend on p directly. But with autodiff, you're instead calculating dx/dA and dx/db.
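For intuition, here is a small self-check (the numbers are made up, this is not from LinearSolve itself): with A held fixed, x(b) = A \ b is linear in b, so dx/db is exactly inv(A), and ForwardDiff recovers that:

```julia
using ForwardDiff, LinearAlgebra

A = [2.0 1.0; 1.0 3.0]
# dx/db for the linear map x(b) = A \ b is the constant matrix inv(A)
J = ForwardDiff.jacobian(b -> A \ b, [1.0, 2.0])
@assert J ≈ inv(A)
```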

@jClugstor
Member

But with autodiff, you're instead calculating dx/dA and dx/db.

Only if A and b are Duals / arrays of Duals when using ForwardDiff, right?

Oh, I think I get it, for some reason I was thinking that this was only for when p was Dual.

So the point is that it should also go through the overload if we're differentiating w.r.t. A or b?

If we have

A = rand(3,3)
g(b) = solve(LinearProblem(A, b)).u
g(rand(3))
ForwardDiff.jacobian(g, rand(3))  # jacobian, not derivative, since g takes a vector

then that should also go through the overload?

@oscardssmith
Member

yeah

@jClugstor
Member

I think we actually have to use the product rule here. Differentiating A(p) x(p) = b(p) with respect to p gives

A_p x + A x_p = b_p,  i.e.  A x_p = b_p - A_p x,

where the second form is a linear solve we can do to find dx/dp.

I tried it out and this way agrees with ForwardDiff.derivative

using ForwardDiff, LinearSolve
using ForwardDiff: Dual

partial_vals(x::Dual) = ForwardDiff.partials(x)
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)

nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)


function g(p) 
    A = [p p+1 p^3; 
         3*p p+5 p-4; 
         p^2 9*p p]

    b = [p+1,p*2,p^2]
    prob = LinearProblem(A, b)
    sol = solve(prob)
    sol.u
end

g(5.0)

ForwardDiff.derivative(g, 5.0)

3-element Vector{Float64}:
 -0.03270492074729986
  0.07804258495644212
 -0.009624980644422148

#---------------------------------------------------------------------------------------------------
h(p) = (A = [p p+1 p^3;
    3*p p+5 p-4;
    p^2 9*p p],
b = [p + 1, p * 2, p^2])

A, b = h(ForwardDiff.Dual(5.0, 1.0))

new_A = nodual_value(A)
new_b = nodual_value(b)

sol = solve(LinearProblem(new_A, new_b))

uu = sol.u

∂_A = partial_vals(A)
∂_b = partial_vals(b)

# the naive ∂_A \ ∂_b ignores the product rule and gives the wrong answer;
# the correct push-forward solves A x_p = b_p - A_p x:
x_p = new_A \ [x.values[1] for x in (∂_b - ∂_A*uu)]

3-element Vector{Float64}:
 -0.03270492074729983
  0.07804258495644215
 -0.00962498064442215

Of course we'll have to make sure that it works when there's more than one partial, but I think that's just more linear solves?

And of course this doesn't work as-is when A is a SciMLOperator, but I guess in that case A_p will just be a matter of pushing the duals through the operator with ForwardDiff.

@ChrisRackauckas
Copy link
Member Author

but I think that's just more linear solves?

Yes, because those are just columns. You can think of a multi-partial dual as a matrix, so it's the same as A \ B, where B is the matrix whose columns are the partial right-hand sides b_p - A_p x.
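As a sketch of that (the matrices and partials below are made up for illustration), the per-partial solves collapse into one solve against a matrix of right-hand sides:

```julia
using LinearAlgebra

A  = [2.0 1.0; 1.0 3.0]
b  = [1.0, 2.0]
x  = A \ b                                       # primal solution
Ap = ([0.5 0.0; 0.0 1.0], [0.0 1.0; 1.0 0.0])    # ∂A/∂pᵢ for two parameters
bp = ([1.0, 0.0], [0.0, 1.0])                    # ∂b/∂pᵢ
B  = reduce(hcat, (bp[i] .- Ap[i] * x for i in 1:2))  # one column per partial
Xp = A \ B                                       # column i is ∂x/∂pᵢ
```

Factoring A once and reusing it across all columns is what makes this cheaper than pushing the duals through the solver itself.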
