Skip to content

Overloads for LinearProblems with ForwardDiff Dual numbers #621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e"
FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand All @@ -53,6 +54,7 @@ LinearSolveCUDSSExt = "CUDSS"
LinearSolveEnzymeExt = "EnzymeCore"
LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices"
LinearSolveFastLapackInterfaceExt = "FastLapackInterface"
LinearSolveForwardDiffExt = "ForwardDiff"
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
Expand Down
240 changes: 240 additions & 0 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
module LinearSolveForwardDiffExt

using LinearSolve
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Partials
using SciMLBase
using RecursiveArrayTools

# Type aliases selecting the `LinearProblem`s that this extension intercepts.
# Each alias constrains the third and/or fourth `LinearProblem` type parameter to a
# `Dual` or an array of `Dual`s; judging by the alias names these are the types of
# `A` and `b` (NOTE(review): confirm against SciMLBase's `LinearProblem` parameter
# order). `T`, `V`, `P` are the `Dual` type parameters; `iip` is left unconstrained.

# Both the third and fourth parameters hold `Dual`s with the same `T`, `V`, `P`.
const DualLinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters}
} where {iip, T, V, P}

# Only the third parameter is required to hold `Dual`s; the fourth may be any
# `Number` or `AbstractArray`.
const DualALinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing},
iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{Number, <:AbstractArray},
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters}
} where {iip, T, V, P}

# Only the fourth parameter is required to hold `Dual`s; the third may be any
# `Number` or `AbstractArray`.
const DualBLinearProblem = LinearProblem{
<:Union{Number, <:AbstractArray, Nothing},
iip,
<:Union{Number, <:AbstractArray},
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
<:Union{Number, <:AbstractArray, SciMLBase.NullParameters}
} where {iip, T, V, P}

# Union of the three aliases above; used as the dispatch type for the `solve`/`init`
# overloads in this module.
const DualAbstractLinearProblem = Union{
DualLinearProblem, DualALinearProblem, DualBLinearProblem}

# Cache built by `init` for a `DualAbstractLinearProblem`. It pairs a cache for the
# dual-stripped problem with the partials that were removed from the inputs.
LinearSolve.@concrete mutable struct DualLinearCache
# Cache returned by `init` on the problem after `nodual_value` was applied to
# `A`, `b` and `u0`; reads of `cache.A` / `cache.b` are forwarded to it.
linear_cache
# The problem originally passed to `init`, still containing its `Dual` entries.
# `solve!` reads the `Dual` type to rebuild from `prob.A` / `prob.b`.
prob
# Algorithm passed to `init`; `solve!` hands it to the partial solves.
alg
# `partial_vals(u0)`: `nothing` when `u0` holds no `Dual`s.
dual_u0
# `partial_vals(A)`: `nothing` when `A` holds no `Dual`s. Refreshed by `cache.A = ...`.
partials_A
# `partial_vals(b)`: `nothing` when `b` holds no `Dual`s. Refreshed by `cache.b = ...`.
partials_b
end

"""
    linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)

Solve the primal system held by `cache.linear_cache`, then solve one further system
per partial direction to obtain the partials of the solution.

The right-hand side for direction `k` is assembled by `xp_linsolve_rhs` as
`∂b_k - ∂A_k * u` (terms whose partials are `nothing` are omitted), and each system is
solved with the same wrapped cache so the primal `A` is reused.

Returns `(primal_sol, partial_sols)`: a deep copy of the primal solution and a list
with one solution per partial direction. On return the wrapped cache holds the
original primal `b` again.
"""
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
    linear_cache = cache.linear_cache

    # Snapshot the initial guess and right-hand side before any solve overwrites them.
    # NOTE(review): assumes the wrapped cache exposes its solution buffer as `u`.
    u_guess = copy(linear_cache.u)
    primal_b = copy(linear_cache.b)

    sol = solve!(linear_cache, alg, args...; kwargs...)

    # The partial solves below reuse `linear_cache`, so detach the primal solution
    # from the cache before it is overwritten.
    primal_sol = deepcopy(sol)

    rhs_list = xp_linsolve_rhs(sol.u, cache.partials_A, cache.partials_b)

    # Start the partial solves from the same initial guess as the primal solve.
    linear_cache.u = u_guess
    for i in eachindex(rhs_list)
        linear_cache.b = rhs_list[i]
        rhs_list[i] = copy(solve!(linear_cache, alg, args...; kwargs...).u)
    end

    # Callers expect `b` to stay what they set; undo the temporary right-hand sides so
    # a later `solve!` on this cache does not solve against the last partial system.
    linear_cache.b = primal_b

    return primal_sol, rhs_list
end

"""
    xp_linsolve_rhs(uu, ∂_A, ∂_b)

Build one right-hand side per partial direction for the case where both `A` and `b`
carry partials: the `k`-th entry is `∂b_k - ∂A_k * uu`.
"""
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
        ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
    ∂A_per_direction = partials_to_list(∂_A)
    ∂b_per_direction = partials_to_list(∂_b)

    # Broadcasting pairs the k-th `∂b` with the k-th `∂A * uu`.
    return ∂b_per_direction .- map(∂A -> ∂A * uu, ∂A_per_direction)
end

"""
    xp_linsolve_rhs(uu, ∂_A, ∂_b::Nothing)

Build one right-hand side per partial direction when only `A` carries partials:
the `k`-th entry is `-(∂A_k * uu)`.
"""
function xp_linsolve_rhs(
        uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing)
    products = map(∂A -> ∂A * uu, partials_to_list(∂_A))
    return -products
end

"""
    xp_linsolve_rhs(uu, ∂_A::Nothing, ∂_b)

Build one right-hand side per partial direction when only `b` carries partials:
the list is just the per-direction split of `∂_b`; `uu` is not used.
"""
function xp_linsolve_rhs(
        uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
    return partials_to_list(∂_b)
end

# No algorithm given: forward to the `::Nothing` method, which chooses one.
SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) =
    solve(prob, nothing, args...; kwargs...)

# `nothing` as the algorithm: pick `LinearSolve.defaultalg` for `prob.A` / `prob.b`
# under `assump` and solve with it. `assump` is consumed here and not forwarded.
function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
        assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
    default_alg = LinearSolve.defaultalg(prob.A, prob.b, assump)
    return solve(prob, default_alg, args...; kwargs...)
end

# Explicit algorithm: build the cache with `init`, then run `solve!` on it.
function SciMLBase.solve(prob::DualAbstractLinearProblem,
        alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
    cache = init(prob, alg, args...; kwargs...)
    return solve!(cache)
end

"""
    linearsolve_dual_solution(u::Number, partials, dual_type)

Scalar case: combine the primal value `u` and `partials` by calling
`dual_type(u, partials)`.
"""
linearsolve_dual_solution(u::Number, partials, dual_type) = dual_type(u, partials)

"""
    linearsolve_dual_solution(u::AbstractArray, partials, dual_type)

Array case: rebuild an array of duals from the primal solution `u` and `partials`,
a list holding one solution array per partial direction.

Entry `i` of the result is `dual_type(u[i], Partials(...))`, where the partials tuple
collects `partials[k][i]` over every direction `k`.
"""
function linearsolve_dual_solution(u::AbstractArray, partials, dual_type)
    npartials = length(partials)

    # Index the per-direction solutions directly instead of wrapping them in a
    # `VectorOfArray`: that avoids depending on the wrapper's linear-indexing
    # behaviour and skips allocating a slice for every entry of `u`.
    return [dual_type(u[i], Partials(ntuple(k -> partials[k][i], npartials)))
            for i in eachindex(u)]
end

"""
    init(prob::DualAbstractLinearProblem, alg, args...; kwargs...)

Build a `DualLinearCache` for a problem whose `A`, `b` and/or `u0` contain duals.

The duals are split into primal values (`nodual_value`) and partials
(`partial_vals`). A cache for the primal problem is created by calling `init` on the
remade problem with all keyword options forwarded, and is stored together with the
original `prob`, `alg` and the extracted partials.
"""
function SciMLBase.init(
        prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
        args...;
        alias = LinearAliasSpecifier(),
        abstol = LinearSolve.default_tol(real(eltype(prob.b))),
        reltol = LinearSolve.default_tol(real(eltype(prob.b))),
        maxiters::Int = length(prob.b),
        verbose::Bool = false,
        Pl = nothing,
        Pr = nothing,
        assumptions = OperatorAssumptions(issquare(prob.A)),
        sensealg = LinearSolveAdjoint(),
        kwargs...)
    # Primal (dual-stripped) inputs for the wrapped cache.
    primal_A = nodual_value(prob.A)
    primal_b = nodual_value(prob.b)
    primal_u0 = nodual_value(prob.u0)

    # Partials kept alongside the primal cache; `nothing` where no duals are present.
    A_partials = partial_vals(prob.A)
    b_partials = partial_vals(prob.b)
    u0_partials = partial_vals(prob.u0)

    primal_prob = remake(prob; A = primal_A, b = primal_b, u0 = primal_u0)

    primal_cache = init(primal_prob, alg, args...;
        alias, abstol, reltol, maxiters, verbose, Pl, Pr, assumptions, sensealg,
        u0 = primal_u0, kwargs...)

    return DualLinearCache(primal_cache, prob, alg, u0_partials, A_partials, b_partials)
end

"""
    solve!(cache::DualLinearCache, args...; kwargs...)

Solve the primal and partial systems via `linearsolve_forwarddiff_solve`, recombine
them into dual numbers, and return a linear solution whose `u` holds those duals.

The dual type is taken from `cache.prob.A` when it contains duals, otherwise from
`cache.prob.b`. The remaining solution fields (`resid`, `cache`, `retcode`, `iters`,
`stats`) are copied from the primal solution.
"""
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
    primal_sol, partial_sols = linearsolve_forwarddiff_solve(
        cache, cache.alg, args...; kwargs...)

    prob = cache.prob
    A_dual_type = get_dual_type(prob.A)
    if A_dual_type !== nothing
        dual_type = A_dual_type
    else
        b_dual_type = get_dual_type(prob.b)
        if b_dual_type !== nothing
            dual_type = b_dual_type
        end
    end

    dual_u = linearsolve_dual_solution(primal_sol.u, partial_sols, dual_type)

    return SciMLBase.build_linear_solution(
        cache.alg, dual_u, primal_sol.resid, primal_sol.cache;
        retcode = primal_sol.retcode, iters = primal_sol.iters,
        stats = primal_sol.stats)
end

# If setting A or b for DualLinearCache, also set it for the underlying LinearCache
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
    # `A` and `b` are split: primal values go to the wrapped cache, partials are
    # stored on this cache. Every other name is written straight to the field.
    if sym === :A
        setproperty!(dc.linear_cache, :A, nodual_value(val))
        return setfield!(dc, :partials_A, partial_vals(val))
    elseif sym === :b
        setproperty!(dc.linear_cache, :b, nodual_value(val))
        return setfield!(dc, :partials_b, partial_vals(val))
    end
    return setfield!(dc, sym, val)
end

# Reads of `A` and `b` are forwarded to the wrapped cache (which holds the primal
# values); every other name is read from this cache's own fields.
function Base.getproperty(dc::DualLinearCache, sym::Symbol)
    if sym === :A || sym === :b
        return getproperty(getfield(dc, :linear_cache), sym)
    end
    return getfield(dc, sym)
end



# Helper functions for Dual numbers
# Return the `Dual` type carried by the argument: the type itself for a `Dual`, the
# element type for an array of `Dual`s, and `nothing` for anything else.
function get_dual_type(::Any)
    return nothing
end
function get_dual_type(d::Dual)
    return typeof(d)
end
function get_dual_type(arr::AbstractArray{<:Dual})
    return eltype(arr)
end

# Extract the partials: `ForwardDiff.partials` of a `Dual`, the element-wise partials
# of an array of `Dual`s, and `nothing` for anything else.
function partial_vals(::Any)
    return nothing
end
function partial_vals(d::Dual)
    return ForwardDiff.partials(d)
end
function partial_vals(arr::AbstractArray{<:Dual})
    return map(ForwardDiff.partials, arr)
end

# Strip the dual part: `ForwardDiff.value` of a `Dual`, the element-wise values of an
# array of `Dual`s, and the argument itself (unchanged, not copied) for anything else.
function nodual_value(other)
    return other
end
function nodual_value(d::Dual)
    return ForwardDiff.value(d)
end
function nodual_value(arr::AbstractArray{<:Dual})
    return map(ForwardDiff.value, arr)
end


"""
    partials_to_list(partial_matrix::Vector)

Split a vector of per-entry partials into one vector per partial direction: element
`k` of the result collects component `k` of every entry of `partial_matrix`.

The number of directions is taken from the first entry, so `partial_matrix` must be
non-empty.
"""
function partials_to_list(partial_matrix::Vector)
    directions = eachindex(first(partial_matrix))
    return map(directions) do k
        map(entry -> entry[k], partial_matrix)
    end
end

"""
    partials_to_list(partial_matrix)

Split a two-dimensional array of per-entry partials into one matrix per partial
direction: matrix `k` of the result holds component `k` of every entry, at the same
`(i, j)` position.

The number of directions and the element type of the output matrices are taken from
the first entry, so `partial_matrix` must be non-empty. Using the entries' own element
type (rather than always `Float64`) keeps `Float32`, `BigFloat` or nested-dual
partials intact.
"""
function partials_to_list(partial_matrix)
    first_entry = first(partial_matrix)
    npartials = length(first_entry)
    V = eltype(first_entry)

    # Destructuring enforces that the input is two-dimensional.
    rows, cols = axes(partial_matrix)

    return Matrix{V}[V[partial_matrix[i, j][k] for i in rows, j in cols]
                     for k in 1:npartials]
end


end
41 changes: 41 additions & 0 deletions test/forwarddiff_overloads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using LinearSolve
using ForwardDiff
using Test

# Build a 3×3 test matrix `A` and length-3 vector `b` from the two parameters in `p`.
# Returned as a named tuple so callers can destructure it as `A, b = h(p)`.
function h(p)
    p1, p2 = p[1], p[2]
    A = [p1 (p2 + 1) (p2^3);
         (3 * p1) (p1 + 5) (p2 * p1 - 4);
         (p2^2) (9 * p1) p2]
    b = [p1 + 1, p2 * 2, p1^2]
    return (A = A, b = b)
end

# `≈` on arrays of `Dual`s reduces to an ordering comparison between `Dual`-valued
# norms, and ForwardDiff's ordering comparisons look at primal values only. Compare the
# values and every partial component separately so the derivatives are tested as well.
function dual_isapprox(x, y; rtol)
    npartials = length(ForwardDiff.partials(first(x)))
    values_match = isapprox(ForwardDiff.value.(x), ForwardDiff.value.(y); rtol = rtol)
    partials_match = all(1:npartials) do k
        isapprox(ForwardDiff.partials.(x, k), ForwardDiff.partials.(y, k); rtol = rtol)
    end
    return values_match && partials_match
end

p_at_5 = [ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]
p_at_10 = [ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]

# Duals in both A and b.
A, b = h(p_at_5)

prob = LinearProblem(A, b)
overload_x_p = solve(prob)
original_x_p = A \ b

@test dual_isapprox(overload_x_p.u, original_x_p; rtol = 1e-9)

# Duals in A only.
A, _ = h(p_at_5)
prob = LinearProblem(A, [6.0, 10.0, 25.0])
@test dual_isapprox(solve(prob).u, A \ [6.0, 10.0, 25.0]; rtol = 1e-9)

# Duals in b only.
_, b = h(p_at_5)
A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0]
prob = LinearProblem(A, b)
@test dual_isapprox(solve(prob).u, A \ b; rtol = 1e-9)

# Cache reuse: initialise at one parameter value, then overwrite A and b.
A, b = h(p_at_10)

prob = LinearProblem(A, b)
cache = init(prob)

new_A, new_b = h(p_at_5)
cache.A = new_A
cache.b = new_b

x_p = solve!(cache)
other_x_p = new_A \ new_b

@test dual_isapprox(x_p.u, other_x_p; rtol = 1e-9)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ if GROUP == "All" || GROUP == "Core"
@time @safetestset "SparseVector b Tests" include("sparse_vector.jl")
@time @safetestset "Default Alg Tests" include("default_algs.jl")
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
@time @safetestset "ForwardDiff Overloads" include("forwarddiff_overloads.jl")
@time @safetestset "Traits" include("traits.jl")
@time @safetestset "BandedMatrices" include("banded.jl")
end
Expand Down
Loading