Skip to content

Commit d05ad09

Browse files
committed
use correct u0
1 parent 6f33486 commit d05ad09

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,13 @@ const DualAbstractLinearProblem = Union{
3636
LinearSolve.@concrete mutable struct DualLinearCache
3737
linear_cache
3838
dual_type
39-
dual_u0
4039
partials_A
4140
partials_b
4241
end
4342

4443
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
44+
# Solve the primal problem
45+
dual_u0 = copy(cache.linear_cache.u)
4546
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
4647
primal_b = copy(cache.linear_cache.b)
4748
uu = sol.u
@@ -51,7 +52,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5152
# Solves Dual partials separately
5253
∂_A = cache.partials_A
5354
∂_b = cache.partials_b
54-
dual_u0 = !isnothing(cache.dual_u0) ? only(partials_to_list(cache.dual_u0)) : cache.linear_cache.u
5555

5656
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
5757

@@ -137,14 +137,12 @@ function SciMLBase.init(
137137
kwargs...)
138138

139139
(; A, b, u0, p) = prob
140-
141140
new_A = nodual_value(A)
142141
new_b = nodual_value(b)
143142
new_u0 = nodual_value(u0)
144143

145144
∂_A = partial_vals(A)
146145
∂_b = partial_vals(b)
147-
dual_u0 = partial_vals(u0)
148146

149147
primal_prob = LinearProblem(new_A, new_b, u0 = new_u0)
150148
#remake(prob; A = new_A, b = new_b, u0 = new_u0)
@@ -159,7 +157,7 @@ function SciMLBase.init(
159157
primal_prob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
160158
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161159
sensealg = sensealg, u0 = new_u0, kwargs...)
162-
return DualLinearCache(non_partial_cache, dual_type, dual_u0, ∂_A, ∂_b)
160+
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b)
163161
end
164162

165163
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
@@ -168,9 +166,8 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
168166
cache::DualLinearCache, cache.alg, args...; kwargs...)
169167

170168
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
171-
172169
return SciMLBase.build_linear_solution(
173-
cache.alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats
170+
cache.alg, dual_sol, sol.resid, cache; sol.retcode, sol.iters, sol.stats
174171
)
175172
end
176173

0 commit comments

Comments
 (0)