Skip to content

Commit 8899c1d

Browse files
feat: enable OverrideInit to solve for du0 of DAEProblems
1 parent a5ee8e9 commit 8899c1d

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

src/initialization.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
A collection of all the data required for `OverrideInit`.
55
"""
6-
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
6+
struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap, IProbDu0Map}
77
"""
88
The `AbstractNonlinearProblem` to solve for initialization.
99
"""
@@ -29,12 +29,18 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
2929
initialized will be returned as-is.
3030
"""
3131
initializeprobpmap::IProbPmap
32+
"""
33+
A function which takes the solution of `initializeprob` and returns the
34+
`du0` vector of the original problem.
35+
"""
36+
initializeprob_du0map::IProbDu0Map
3237

3338
function OverrideInitData(initprob::I, update_initprob!::J, initprobmap::K,
34-
initprobpmap::L) where {I, J, K, L}
39+
initprobpmap::L, initprob_du0map::M = nothing) where {I, J, K, L, M}
3540
@assert initprob isa
3641
Union{SCCNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem}
37-
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
42+
return new{I, J, K, L, M}(
43+
initprob, update_initprob!, initprobmap, initprobpmap, initprob_du0map)
3844
end
3945
end
4046

@@ -171,9 +177,12 @@ Keyword arguments:
171177
provided to the `OverrideInit` constructor takes priority over this keyword argument.
172178
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
173179
an error will be thrown.
180+
- `return_du0`: Whether to use `initializeprob_du0map` (if present) and return
181+
`du0, u0, p, success`.
174182
"""
175183
function get_initial_values(prob, valp, f, alg::OverrideInit,
176-
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
184+
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing,
185+
reltol = nothing, return_du0 = false, kwargs...)
177186
u0 = state_values(valp)
178187
p = parameter_values(valp)
179188

@@ -214,5 +223,10 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
214223
p = initdata.initializeprobpmap(valp, nlsol)
215224
end
216225

226+
if return_du0
227+
du0 = initdata.initializeprob_du0map === nothing ? nothing : initdata.initializeprob_du0map(nlsol)
228+
return du0, u0, p, SciMLBase.successful_retcode(nlsol)
229+
end
230+
217231
return u0, p, SciMLBase.successful_retcode(nlsol)
218232
end

test/initialization.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,66 @@ end
229229
@test p 0.0
230230
@test success
231231
end
232+
233+
@testset "DAEProblem" begin
234+
function daerhs(du, u, p, t)
235+
return [u[1] * t + p, u[1]^2 - u[2]^2]
236+
end
237+
# unknowns are u[2], p, D(u[1]), D(u[2]). Parameters are u[1], t
238+
initprob = NonlinearProblem([1.0, 1.0, 1.0, 1.0], [1.0, 0.0]) do x, _p
239+
u2, p, du1, du2 = x
240+
u1, t = _p
241+
return [u1^3 - u2^3, p^2 - 2p + 1, du1 - u1 * t - p, 2u1 * du1 - 2u2 * du2]
242+
end
243+
244+
update_initializeprob! = function (iprob, integ)
245+
iprob.p[1] = integ.u[1]
246+
iprob.p[2] = integ.t
247+
end
248+
initprobmap = function (nlsol)
249+
return [parameter_values(nlsol)[1], nlsol.u[1]]
250+
end
251+
initprobpmap = function (_, nlsol)
252+
return nlsol.u[2]
253+
end
254+
initprob_du0map = function (nlsol)
255+
return nlsol.u[3:4]
256+
end
257+
initialization_data = SciMLBase.OverrideInitData(
258+
initprob, update_initializeprob!, initprobmap, initprobpmap, initprob_du0map)
259+
fn = DAEFunction(daerhs; initialization_data)
260+
prob = DAEProblem(fn, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0)
261+
integ = init(prob, DImplicitEuler(); initializealg = NoInit())
262+
263+
initialization_data2 = SciMLBase.OverrideInitData(
264+
initprob, update_initializeprob!, initprobmap, initprobpmap)
265+
fn2 = DAEFunction(daerhs; initialization_data = initialization_data2)
266+
prob2 = DAEProblem(fn2, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0)
267+
integ2 = init(prob2, DImplicitEuler(); initializealg = NoInit())
268+
269+
nlsolve_alg = FastShortcutNonlinearPolyalg()
270+
@testset "Doesn't return `du0` by default" begin
271+
@test length(SciMLBase.get_initial_values(
272+
prob, integ, fn, SciMLBase.OverrideInit(),
273+
Val(false); nlsolve_alg, abstol, reltol)) == 3
274+
end
275+
@testset "`du0 === nothing` if missing `du0map`" begin
276+
du0, u0, p, success = SciMLBase.get_initial_values(
277+
prob2, integ2, fn2, SciMLBase.OverrideInit(), Val(false);
278+
nlsolve_alg, abstol, reltol, return_du0 = true)
279+
@test du0 === nothing
280+
@test u0 [2.0, 2.0]
281+
@test p 1.0
282+
@test success
283+
end
284+
@testset "With `return_du0 = true`" begin
285+
du0, u0, p, success = SciMLBase.get_initial_values(
286+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false);
287+
nlsolve_alg, abstol, reltol, return_du0 = true)
288+
@test du0 [1.0, 1.0]
289+
@test u0 [2.0, 2.0]
290+
@test p 1.0
291+
@test success
292+
end
293+
end
232294
end

0 commit comments

Comments
 (0)