Skip to content

Commit a5ce361

Browse files
fix: properly handle values given to parameter dependencies in late_binding_update_u0_p
1 parent 592f3f3 commit a5ce361

File tree

2 files changed

+39
-21
lines changed

2 files changed

+39
-21
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,25 @@ function SciMLBase.late_binding_update_u0_p(
648648
newu0, newp = promote_u0_p(newu0, newp, t0)
649649

650650
# non-symbolic u0 updates initials...
651-
if !(eltype(u0) <: Pair)
651+
if eltype(u0) <: Pair
652+
syms = []
653+
vals = []
654+
allsyms = all_symbols(sys)
655+
for (k, v) in u0
656+
v === nothing && continue
657+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
658+
if k isa Symbol
659+
k2 = symbol_to_symbolic(sys, k; allsyms)
660+
# if it is returned as-is, there is no match so skip it
661+
k2 === k && continue
662+
k = k2
663+
end
664+
is_parameter(sys, Initial(k)) || continue
665+
push!(syms, Initial(k))
666+
push!(vals, v)
667+
end
668+
newp = setp_oop(sys, syms)(newp, vals)
669+
else
652670
# if `p` is not provided or is symbolic
653671
p === missing || eltype(p) <: Pair || return newu0, newp
654672
(newu0 === nothing || isempty(newu0)) && return newu0, newp
@@ -661,27 +679,27 @@ function SciMLBase.late_binding_update_u0_p(
661679
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
662680
end
663681
newp = meta.set_initial_unknowns!(newp, newu0)
664-
return newu0, newp
665-
end
666-
667-
syms = []
668-
vals = []
669-
allsyms = all_symbols(sys)
670-
for (k, v) in u0
671-
v === nothing && continue
672-
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
673-
if k isa Symbol
674-
k2 = symbol_to_symbolic(sys, k; allsyms)
675-
# if it is returned as-is, there is no match so skip it
676-
k2 === k && continue
677-
k = k2
682+
end
683+
684+
if eltype(p) <: Pair
685+
syms = []
686+
vals = []
687+
for (k, v) in p
688+
v === nothing && continue
689+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
690+
if k isa Symbol
691+
k2 = symbol_to_symbolic(sys, k; allsyms)
692+
# if it is returned as-is, there is no match so skip it
693+
k2 === k && continue
694+
k = k2
695+
end
696+
is_parameter(sys, Initial(k)) || continue
697+
push!(syms, Initial(k))
698+
push!(vals, v)
678699
end
679-
is_parameter(sys, Initial(k)) || continue
680-
push!(syms, Initial(k))
681-
push!(vals, v)
700+
newp = setp_oop(sys, syms)(newp, vals)
682701
end
683702

684-
newp = setp_oop(sys, syms)(newp, vals)
685703
return newu0, newp
686704
end
687705

test/initializationsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,9 +1253,9 @@ end
12531253
@test init(prob3)[x] 1.0
12541254
prob4 = remake(prob; p = [p => 1.0])
12551255
test_dummy_initialization_equation(prob4, x)
1256-
prob5 = remake(prob; p = [p => missing, q => 2.0])
1256+
prob5 = remake(prob; p = [p => missing, q => 4.0])
12571257
@test prob5.f.initialization_data !== nothing
1258-
@test init(prob5).ps[p] 1.0
1258+
@test init(prob5).ps[p] 2.0
12591259
end
12601260

12611261
@testset "Variables provided as symbols" begin

0 commit comments

Comments
 (0)