diff --git a/Project.toml b/Project.toml index 96e7145af..f4fcc2604 100644 --- a/Project.toml +++ b/Project.toml @@ -35,6 +35,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastAlmostBandedMatrices = "9d29842c-ecb8-4973-b1e9-a27b1157504e" FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" @@ -53,6 +54,7 @@ LinearSolveCUDSSExt = "CUDSS" LinearSolveEnzymeExt = "EnzymeCore" LinearSolveFastAlmostBandedMatricesExt = "FastAlmostBandedMatrices" LinearSolveFastLapackInterfaceExt = "FastLapackInterface" +LinearSolveForwardDiffExt = "ForwardDiff" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl new file mode 100644 index 000000000..f2137eccb --- /dev/null +++ b/ext/LinearSolveForwardDiffExt.jl @@ -0,0 +1,241 @@ +module LinearSolveForwardDiffExt + +using LinearSolve +using LinearAlgebra +using ForwardDiff +using ForwardDiff: Dual, Partials +using SciMLBase +using RecursiveArrayTools + +const DualLinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray, Nothing}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Any +} where {iip, T, V, P} + +const DualALinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray, Nothing}, + iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Union{Number, <:AbstractArray}, + <:Any +} where {iip, T, V, P} + +const DualBLinearProblem = LinearProblem{ + <:Union{Number, <:AbstractArray, Nothing}, + iip, + <:Union{Number, <:AbstractArray}, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}, + <:Any +} where {iip, T, V, P} + +const DualAbstractLinearProblem = Union{ + DualLinearProblem, DualALinearProblem, DualBLinearProblem} + +LinearSolve.@concrete mutable struct DualLinearCache + linear_cache + dual_type + partials_A + partials_b +end + +function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...) + # Solve the primal problem + dual_u0 = copy(cache.linear_cache.u) + sol = solve!(cache.linear_cache, alg, args...; kwargs...) + primal_b = copy(cache.linear_cache.b) + uu = sol.u + + primal_sol = deepcopy(sol) + + # Solves Dual partials separately + ∂_A = cache.partials_A + ∂_b = cache.partials_b + + rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b) + + partial_cache = cache.linear_cache + partial_cache.u = dual_u0 + + for i in eachindex(rhs_list) + partial_cache.b = rhs_list[i] + rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u) + end + + # Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to + partial_cache.b = primal_b + + partial_sols = rhs_list + + primal_sol, partial_sols +end + +function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, + ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) + A_list = partials_to_list(∂_A) + b_list = partials_to_list(∂_b) + + Auu = [A * uu for A in A_list] + + return b_list .- Auu +end + +function xp_linsolve_rhs( + uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}}, ∂_b::Nothing) + A_list = partials_to_list(∂_A) + + Auu = [A * uu for A in A_list] + + return -Auu +end + +function xp_linsolve_rhs( + uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}) + b_list = partials_to_list(∂_b) + b_list +end + +function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...) + return solve(prob, nothing, args...; kwargs...) +end + +function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...; + assump = OperatorAssumptions(issquare(prob.A)), kwargs...) + return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...) +end + +function SciMLBase.solve(prob::DualAbstractLinearProblem, + alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...) + solve!(init(prob, alg, args...; kwargs...)) +end + +function linearsolve_dual_solution( + u::Number, partials, dual_type) + return dual_type(u, partials) +end + +function linearsolve_dual_solution( + u::AbstractArray, partials, dual_type) + partials_list = RecursiveArrayTools.VectorOfArray(partials) + return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), + zip(u, partials_list[i, :] for i in 1:length(partials_list[1]))) +end + +function SciMLBase.init( + prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm, + args...; + alias = LinearAliasSpecifier(), + abstol = LinearSolve.default_tol(real(eltype(prob.b))), + reltol = LinearSolve.default_tol(real(eltype(prob.b))), + maxiters::Int = length(prob.b), + verbose::Bool = false, + Pl = nothing, + Pr = nothing, + assumptions = OperatorAssumptions(issquare(prob.A)), + sensealg = LinearSolveAdjoint(), + kwargs...) + + (; A, b, u0, p) = prob + new_A = nodual_value(A) + new_b = nodual_value(b) + new_u0 = nodual_value(u0) + + ∂_A = partial_vals(A) + ∂_b = partial_vals(b) + + #primal_prob = LinearProblem(new_A, new_b, u0 = new_u0) + primal_prob = remake(prob; A = new_A, b = new_b, u0 = new_u0) + + if get_dual_type(prob.A) !== nothing + dual_type = get_dual_type(prob.A) + elseif get_dual_type(prob.b) !== nothing + dual_type = get_dual_type(prob.b) + end + + non_partial_cache = init( + primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol, + maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, + sensealg = sensealg, u0 = new_u0, kwargs...) + return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b) +end + +function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) + sol, + partials = linearsolve_forwarddiff_solve( + cache::DualLinearCache, cache.alg, args...; kwargs...) + + dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) + return SciMLBase.build_linear_solution( + cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats + ) +end + +# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache +# Also "forwards" setproperty so that +function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val) + # If the property is A or b, also update it in the LinearCache + if sym === :A || sym === :b || sym === :u + setproperty!(dc.linear_cache, sym, nodual_value(val)) + elseif hasfield(LinearSolve.LinearCache, sym) + setproperty!(dc.linear_cache, sym, val) + end + + # Update the partials if setting A or b + if sym === :A + setfield!(dc, :partials_A, partial_vals(val)) + elseif sym === :b + setfield!(dc, :partials_b, partial_vals(val)) + else + setfield!(dc, sym, val) + end +end + +# "Forwards" getproperty to LinearCache if necessary +function Base.getproperty(dc::DualLinearCache, sym::Symbol) + if hasfield(LinearSolve.LinearCache, sym) + return getproperty(dc.linear_cache, sym) + else + return getfield(dc, sym) + end +end + + + +# Helper functions for Dual numbers +get_dual_type(x::Dual) = typeof(x) +get_dual_type(x::AbstractArray{<:Dual}) = eltype(x) +get_dual_type(x) = nothing + +partial_vals(x::Dual) = ForwardDiff.partials(x) +partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x) +partial_vals(x) = nothing + +nodual_value(x) = x +nodual_value(x::Dual) = ForwardDiff.value(x) +nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) + + +function partials_to_list(partial_matrix::Vector) + p = eachindex(first(partial_matrix)) + [[partial[i] for partial in partial_matrix] for i in p] +end + +function partials_to_list(partial_matrix) + p = length(first(partial_matrix)) + m, n = size(partial_matrix) + res_list = fill(zeros(m, n), p) + for k in 1:p + res = zeros(m, n) + for i in 1:m + for j in 1:n + res[i, j] = partial_matrix[i, j][k] + end + end + res_list[k] = res + end + return res_list +end + + +end diff --git a/test/forwarddiff_overloads.jl b/test/forwarddiff_overloads.jl new file mode 100644 index 000000000..eb66c64dc --- /dev/null +++ b/test/forwarddiff_overloads.jl @@ -0,0 +1,82 @@ +using LinearSolve +using ForwardDiff +using Test + +function h(p) + (A = [p[1] p[2]+1 p[2]^3; + 3*p[1] p[1]+5 p[2] * p[1]-4; + p[2]^2 9*p[1] p[2]], + b = [p[1] + 1, p[2] * 2, p[1]^2]) +end + +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +overload_x_p = solve(prob) +backslash_x_p = A \ b +krylov_overload_x_p = solve(prob, KrylovJL_GMRES()) +@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9) +@test ≈(krylov_overload_x_p, backslash_x_p, rtol = 1e-9) + +krylov_prob = LinearProblem(A, b, u0 = rand(3)) +krylov_u0_sol = solve(krylov_prob, KrylovJL_GMRES()) + +@test ≈(krylov_u0_sol, backslash_x_p, rtol = 1e-9) + + +A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +backslash_x_p = A \ [6.0, 10.0, 25.0] +prob = LinearProblem(A, [6.0, 10.0, 25.0]) + +@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) +@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9) + +_, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +A = [5.0 6.0 125.0; 15.0 10.0 21.0; 25.0 45.0 5.0] +backslash_x_p = A \ b +prob = LinearProblem(A, b) + +@test ≈(solve(prob).u, backslash_x_p, rtol = 1e-9) +@test ≈(solve(prob, KrylovJL_GMRES()).u, backslash_x_p, rtol = 1e-9) + +A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +cache = init(prob) + +new_A, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +cache.A = new_A +cache.b = new_b + +x_p = solve!(cache) +backslash_x_p = new_A \ new_b + +@test ≈(x_p, backslash_x_p, rtol = 1e-9) + +# Just update A +A, b = h([ForwardDiff.Dual(10.0, 1.0, 0.0), ForwardDiff.Dual(10.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +cache = init(prob) + +new_A, _ = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +cache.A = new_A + +x_p = solve!(cache) +backslash_x_p = new_A \ b + +@test ≈(x_p, backslash_x_p, rtol = 1e-9) + +# Just update b +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +prob = LinearProblem(A, b) +cache = init(prob) + +_, new_b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) +cache.b = new_b + +x_p = solve!(cache) +backslash_x_p = A \ new_b + +@test ≈(x_p, backslash_x_p, rtol = 1e-9) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0d994f787..2133bcd4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ if GROUP == "All" || GROUP == "Core" @time @safetestset "SparseVector b Tests" include("sparse_vector.jl") @time @safetestset "Default Alg Tests" include("default_algs.jl") @time @safetestset "Adjoint Sensitivity" include("adjoint.jl") + @time @safetestset "ForwardDiff Overloads" include("forwarddiff_overloads.jl") @time @safetestset "Traits" include("traits.jl") @time @safetestset "BandedMatrices" include("banded.jl") end