Skip to content

Commit 2922aa7

Browse files
fix: properly handle values given to parameter dependencies in late_binding_update_u0_p
1 parent 592f3f3 commit 2922aa7

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,26 @@ function SciMLBase.late_binding_update_u0_p(
648648
newu0, newp = promote_u0_p(newu0, newp, t0)
649649

650650
# non-symbolic u0 updates initials...
651-
if !(eltype(u0) <: Pair)
651+
if eltype(u0) <: Pair
652+
syms = []
653+
vals = []
654+
allsyms = all_symbols(sys)
655+
for (k, v) in u0
656+
v === nothing && continue
657+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
658+
if k isa Symbol
659+
k2 = symbol_to_symbolic(sys, k; allsyms)
660+
# if it is returned as-is, there is no match so skip it
661+
k2 === k && continue
662+
k = k2
663+
end
664+
is_parameter(sys, Initial(k)) || continue
665+
push!(syms, Initial(k))
666+
push!(vals, v)
667+
end
668+
newp = setp_oop(sys, syms)(newp, vals)
669+
else
670+
allsyms = nothing
652671
# if `p` is not provided or is symbolic
653672
p === missing || eltype(p) <: Pair || return newu0, newp
654673
(newu0 === nothing || isempty(newu0)) && return newu0, newp
@@ -661,27 +680,30 @@ function SciMLBase.late_binding_update_u0_p(
661680
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
662681
end
663682
newp = meta.set_initial_unknowns!(newp, newu0)
664-
return newu0, newp
665-
end
666-
667-
syms = []
668-
vals = []
669-
allsyms = all_symbols(sys)
670-
for (k, v) in u0
671-
v === nothing && continue
672-
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
673-
if k isa Symbol
674-
k2 = symbol_to_symbolic(sys, k; allsyms)
675-
# if it is returned as-is, there is no match so skip it
676-
k2 === k && continue
677-
k = k2
683+
end
684+
685+
if eltype(p) <: Pair
686+
syms = []
687+
vals = []
688+
if allsyms === nothing
689+
allsyms = all_symbols(sys)
690+
end
691+
for (k, v) in p
692+
v === nothing && continue
693+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
694+
if k isa Symbol
695+
k2 = symbol_to_symbolic(sys, k; allsyms)
696+
# if it is returned as-is, there is no match so skip it
697+
k2 === k && continue
698+
k = k2
699+
end
700+
is_parameter(sys, Initial(k)) || continue
701+
push!(syms, Initial(k))
702+
push!(vals, v)
678703
end
679-
is_parameter(sys, Initial(k)) || continue
680-
push!(syms, Initial(k))
681-
push!(vals, v)
704+
newp = setp_oop(sys, syms)(newp, vals)
682705
end
683706

684-
newp = setp_oop(sys, syms)(newp, vals)
685707
return newu0, newp
686708
end
687709

test/initializationsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,9 +1253,9 @@ end
12531253
@test init(prob3)[x] 1.0
12541254
prob4 = remake(prob; p = [p => 1.0])
12551255
test_dummy_initialization_equation(prob4, x)
1256-
prob5 = remake(prob; p = [p => missing, q => 2.0])
1256+
prob5 = remake(prob; p = [p => missing, q => 4.0])
12571257
@test prob5.f.initialization_data !== nothing
1258-
@test init(prob5).ps[p] 1.0
1258+
@test init(prob5).ps[p] 2.0
12591259
end
12601260

12611261
@testset "Variables provided as symbols" begin

0 commit comments

Comments
 (0)