@@ -36,12 +36,13 @@ const DualAbstractLinearProblem = Union{
36
36
LinearSolve. @concrete mutable struct DualLinearCache
37
37
linear_cache
38
38
dual_type
39
- dual_u0
40
39
partials_A
41
40
partials_b
42
41
end
43
42
44
43
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
44
+ # Solve the primal problem
45
+ dual_u0 = copy (cache. linear_cache. u)
45
46
sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
46
47
primal_b = copy (cache. linear_cache. b)
47
48
uu = sol. u
@@ -51,7 +52,6 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
51
52
# Solves Dual partials separately
52
53
∂_A = cache. partials_A
53
54
∂_b = cache. partials_b
54
- dual_u0 = ! isnothing (cache. dual_u0) ? only (partials_to_list (cache. dual_u0)) : cache. linear_cache. u
55
55
56
56
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
57
57
@@ -137,14 +137,12 @@ function SciMLBase.init(
137
137
kwargs... )
138
138
139
139
(; A, b, u0, p) = prob
140
-
141
140
new_A = nodual_value (A)
142
141
new_b = nodual_value (b)
143
142
new_u0 = nodual_value (u0)
144
143
145
144
∂_A = partial_vals (A)
146
145
∂_b = partial_vals (b)
147
- dual_u0 = partial_vals (u0)
148
146
149
147
primal_prob = LinearProblem (new_A, new_b, u0 = new_u0)
150
148
# remake(prob; A = new_A, b = new_b, u0 = new_u0)
@@ -159,7 +157,7 @@ function SciMLBase.init(
159
157
primal_prob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
160
158
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161
159
sensealg = sensealg, u0 = new_u0, kwargs... )
162
- return DualLinearCache (non_partial_cache, dual_type, dual_u0, ∂_A, ∂_b)
160
+ return DualLinearCache (non_partial_cache, dual_type, ∂_A, ∂_b)
163
161
end
164
162
165
163
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
@@ -168,9 +166,8 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
168
166
cache:: DualLinearCache , cache. alg, args... ; kwargs... )
169
167
170
168
dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
171
-
172
169
return SciMLBase. build_linear_solution (
173
- cache. alg, dual_sol, sol. resid, sol . cache; sol. retcode, sol. iters, sol. stats
170
+ cache. alg, dual_sol, sol. resid, cache; sol. retcode, sol. iters, sol. stats
174
171
)
175
172
end
176
173
0 commit comments