Skip to content

Commit 8029d21

Browse files
Merge pull request #3754 from AayushSabharwal/as/v9-fix-g-u-s-p
[v9] fix: fix `get_updated_symbolic_problem`
2 parents 897524e + a91cb22 commit 8029d21

File tree

3 files changed

+50
-5
lines changed

3 files changed

+50
-5
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -685,18 +685,18 @@ function SciMLBase.late_binding_update_u0_p(
685685
return newu0, newp
686686
end
687687

688-
function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw...)
688+
function DiffEqBase.get_updated_symbolic_problem(
689+
sys::AbstractSystem, prob; u0 = state_values(prob),
690+
p = parameter_values(prob), kw...)
689691
supports_initialization(sys) || return prob
690692
initdata = prob.f.initialization_data
691693
initdata isa SciMLBase.OverrideInitData || return prob
692694
meta = initdata.metadata
693695
meta isa InitializationMetadata || return prob
694696
meta.get_updated_u0 === nothing && return prob
695697

696-
u0 = state_values(prob)
697-
u0 === nothing && return prob
698+
u0 === nothing && return remake(prob; p)
698699

699-
p = parameter_values(prob)
700700
t0 = is_time_dependent(prob) ? current_time(prob) : nothing
701701

702702
if p isa MTKParameters
@@ -713,7 +713,7 @@ function DiffEqBase.get_updated_symbolic_problem(sys::AbstractSystem, prob; kw..
713713
T = StaticArrays.similar_type(u0)
714714
end
715715

716-
return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)))
716+
return remake(prob; u0 = T(meta.get_updated_u0(prob, initdata.initializeprob)), p)
717717
end
718718

719719
"""

test/extensions/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ OrdinaryDiffEqVerner = "79d7bb75-1356-48c1-b8c0-6832512096c2"
2323
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2424
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
2525
SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd"
26+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2627
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2728
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2829
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/extensions/ad.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using OrdinaryDiffEqNonlinearSolve
88
using NonlinearSolve
99
using SciMLSensitivity
1010
using ForwardDiff
11+
using StableRNGs
1112
using ChainRulesCore
1213
using ChainRulesCore: NoTangent
1314
using ChainRulesTestUtils: test_rrule, rand_tangent
@@ -135,3 +136,46 @@ end
135136
prob[sys.x]
136137
end
137138
end
139+
140+
@testset "`p` provided to `solve` is respected" begin
141+
@mtkmodel Linear begin
142+
@variables begin
143+
x(t) = 1.0, [description = "Prey"]
144+
end
145+
@parameters begin
146+
α = 1.5
147+
end
148+
@equations begin
149+
D(x) ~ -α * x
150+
end
151+
end
152+
153+
@mtkbuild linear = Linear()
154+
problem = ODEProblem(linear, [], (0.0, 1.0))
155+
solution = solve(problem, Tsit5(), saveat = 0.1)
156+
rng = StableRNG(42)
157+
data = (;
158+
t = solution.t,
159+
# [[y, x], :]
160+
measurements = Array(solution)
161+
)
162+
data.measurements .+= 0.05 * randn(rng, size(data.measurements))
163+
164+
p0, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), problem.p)
165+
166+
objective = let repack = repack, problem = problem
167+
(p, data) -> begin
168+
pnew = repack(p)
169+
sol = solve(problem, Tsit5(), p = pnew, saveat = data.t)
170+
sum(abs2, sol .- data.measurements) / size(data.t, 1)
171+
end
172+
end
173+
174+
# Check 0.0031677344878386607
175+
@test_nowarn objective(p0, data)
176+
177+
fd = ForwardDiff.gradient(Base.Fix2(objective, data), p0)
178+
zg = Zygote.gradient(Base.Fix2(objective, data), p0)
179+
180+
@test fdzg[1] atol=1e-6
181+
end

0 commit comments

Comments
 (0)