Skip to content

Commit e2ef9ca

Browse files
Merge pull request #3751 from AayushSabharwal/as/initsys-optimization
refactor: significantly improve performance of `*Problem` generation
2 parents 96602e3 + 2efd558 commit e2ef9ca

16 files changed

+346
-100
lines changed

benchmark/benchmarks.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ModelingToolkitStandardLibrary.Electrical
44
using ModelingToolkitStandardLibrary.Mechanical.Rotational
55
using ModelingToolkitStandardLibrary.Blocks
66
using OrdinaryDiffEqDefault
7+
using ModelingToolkit: t_nounits as t, D_nounits as D
78

89
const SUITE = BenchmarkGroup()
910

@@ -45,12 +46,33 @@ end
4546

4647
@named model = DCMotor()
4748

49+
# first call
50+
mtkcompile(model)
4851
SUITE["mtkcompile"] = @benchmarkable mtkcompile($model)
4952

5053
model = mtkcompile(model)
5154
u0 = unknowns(model) .=> 0.0
5255
tspan = (0.0, 6.0)
53-
SUITE["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)
5456

5557
prob = ODEProblem(model, u0, tspan)
58+
SUITE["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)
59+
60+
# first call
61+
init(prob)
5662
SUITE["init"] = @benchmarkable init($prob)
63+
64+
large_param_init = SUITE["large_parameter_init"] = BenchmarkGroup()
65+
66+
N = 25
67+
@variables x(t)[1:N]
68+
@parameters A[1:N, 1:N]
69+
70+
defval = collect(x) * collect(x)'
71+
@mtkcompile model = System(
72+
[D(x) ~ x], t, [x], [A]; defaults = [A => defval], guesses = [A => fill(NaN, N, N)])
73+
74+
u0 = [x => rand(N)]
75+
prob = ODEProblem(model, u0, tspan)
76+
large_param_init["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)
77+
78+
large_param_init["init"] = @benchmarkable init($prob)

src/bipartite_graph.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,13 +535,39 @@ function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors)
535535
end
536536
end
537537

538-
function delete_srcs!(g::BipartiteGraph, srcs)
538+
function delete_srcs!(g::BipartiteGraph{I}, srcs; rm_verts = false) where {I}
539539
for s in srcs
540540
set_neighbors!(g, s, ())
541541
end
542+
if rm_verts
543+
old_to_new_idxs = collect(one(I):I(nsrcs(g)))
544+
for s in srcs
545+
old_to_new_idxs[s] = zero(I)
546+
end
547+
offset = zero(I)
548+
for i in eachindex(old_to_new_idxs)
549+
if iszero(old_to_new_idxs[i])
550+
offset += one(I)
551+
continue
552+
end
553+
old_to_new_idxs[i] -= offset
554+
end
555+
556+
if g.badjlist isa AbstractVector
557+
for i in 1:ndsts(g)
558+
for j in eachindex(g.badjlist[i])
559+
g.badjlist[i][j] = old_to_new_idxs[g.badjlist[i][j]]
560+
end
561+
filter!(!iszero, g.badjlist[i])
562+
end
563+
end
564+
deleteat!(g.fadjlist, srcs)
565+
end
542566
g
543567
end
544-
delete_dsts!(g::BipartiteGraph, srcs) = delete_srcs!(invview(g), srcs)
568+
function delete_dsts!(g::BipartiteGraph, srcs; rm_verts = false)
569+
delete_srcs!(invview(g), srcs; rm_verts)
570+
end
545571

546572
###
547573
### Edges iteration

src/problems/initializationproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const
3939
for k in keys(op)
4040
has_u0_ics |= is_variable(sys, k) || isdifferential(k) ||
4141
symbolic_type(k) == ArraySymbolic() &&
42-
is_sized_array_symbolic(k) && is_variable(sys, first(collect(k)))
42+
is_sized_array_symbolic(k) && is_variable(sys, unwrap(first(wrap(k))))
4343
end
4444
if !has_u0_ics && get_initializesystem(sys) !== nothing
4545
isys = get_initializesystem(sys; initialization_eqs, check_units)
@@ -79,7 +79,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const
7979
@warn errmsg
8080
end
8181

82-
uninit = setdiff(unknowns(sys), [unknowns(isys); observables(isys)])
82+
uninit = setdiff(unknowns(sys), unknowns(isys), observables(isys))
8383

8484
# TODO: throw on uninitialized arrays
8585
filter!(x -> !(x isa Symbolics.Arr), uninit)

src/structural_transformation/symbolics_tearing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,8 @@ function update_simplified_system!(
960960
obs_sub[eq.lhs] = eq.rhs
961961
end
962962
# TODO: compute the dependency correctly so that we don't have to do this
963-
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
963+
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs;
964+
fast_substitute(state.additional_observed, obs_sub)]
964965

965966
unknown_idxs = filter(
966967
i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars))

src/systems/abstractsystem.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,23 @@ has_equations(::AbstractSystem) = true
805805
806806
Invalidate cached jacobians, etc.
807807
"""
808-
invalidate_cache!(sys::AbstractSystem) = sys
808+
function invalidate_cache!(sys::AbstractSystem)
809+
has_metadata(sys) || return sys
810+
empty!(getmetadata(sys, MutableCacheKey, nothing))
811+
return sys
812+
end
813+
814+
# `::MetadataT` but that is defined later
815+
function refreshed_metadata(meta::Base.ImmutableDict)
816+
newmeta = MetadataT()
817+
for (k, v) in meta
818+
if k === MutableCacheKey
819+
v = MutableCacheT()
820+
end
821+
newmeta = Base.ImmutableDict(newmeta, k => v)
822+
end
823+
return newmeta
824+
end
809825

810826
function Setfield.get(obj::AbstractSystem, ::Setfield.PropertyLens{field}) where {field}
811827
getfield(obj, field)
@@ -815,6 +831,8 @@ end
815831
args = map(fieldnames(obj)) do fn
816832
if fn in fieldnames(patch)
817833
:(patch.$fn)
834+
elseif fn == :metadata
835+
:($refreshed_metadata(getfield(obj, $(Meta.quot(fn)))))
818836
else
819837
:(getfield(obj, $(Meta.quot(fn))))
820838
end
@@ -2507,7 +2525,15 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
25072525
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
25082526
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
25092527
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
2510-
meta = merge(get_metadata(basesys), get_metadata(sys))
2528+
meta = MetadataT()
2529+
for kvp in get_metadata(basesys)
2530+
kvp[1] == MutableCacheKey && continue
2531+
meta = Base.ImmutableDict(meta, kvp)
2532+
end
2533+
for kvp in get_metadata(sys)
2534+
kvp[1] == MutableCacheKey && continue
2535+
meta = Base.ImmutableDict(meta, kvp)
2536+
end
25112537
syss = union(get_systems(basesys), get_systems(sys))
25122538
args = length(ivs) == 0 ? (eqs, sts, ps) : (eqs, ivs[1], sts, ps)
25132539
kwargs = (observed = obs, continuous_events = cevs,
@@ -2705,7 +2731,9 @@ function process_parameter_equations(sys::AbstractSystem)
27052731
is_sized_array_symbolic(sym) &&
27062732
all(Base.Fix1(is_parameter, sys), collect(sym))
27072733
end
2708-
if !isparameter(eq.lhs)
2734+
# Everything in `varsbuf` is a parameter, so this is a cheap `is_parameter`
2735+
# check.
2736+
if !(eq.lhs in varsbuf)
27092737
throw(ArgumentError("""
27102738
LHS of parameter dependency equation must be a single parameter. Found \
27112739
$(eq.lhs).

src/systems/connectors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,8 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10)
872872
eqs = [equations(sys); ceqs; stream_eqs]
873873
# substitute `instream(..)` expressions with their new values
874874
for i in eachindex(eqs)
875-
eqs[i] = fixpoint_sub(eqs[i], instream_subs; maxiters = length(instream_subs))
875+
eqs[i] = fixpoint_sub(
876+
eqs[i], instream_subs; maxiters = max(length(instream_subs), 10))
876877
end
877878
# get the defaults for domain networks
878879
d_defs = domain_defaults(sys, domain_csets)

src/systems/nonlinear/initializesystem.jl

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -173,20 +173,20 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
173173
end
174174

175175
# 5) process parameters as initialization unknowns
176-
paramsubs = setup_parameter_initialization!(
176+
solved_params = setup_parameter_initialization!(
177177
sys, pmap, defs, guesses, eqs_ics; check_defguess)
178178

179179
# 6) parameter dependencies become equations, their LHS become unknowns
180180
# non-numeric dependent parameters stay as parameter dependencies
181181
new_parameter_deps = solve_parameter_dependencies!(
182-
sys, paramsubs, eqs_ics, defs, guesses)
182+
sys, solved_params, eqs_ics, defs, guesses)
183183

184184
# 7) handle values provided for dependent parameters similar to values for observed variables
185-
handle_dependent_parameter_constraints!(sys, pmap, eqs_ics, paramsubs)
185+
handle_dependent_parameter_constraints!(sys, pmap, eqs_ics)
186186

187187
# parameters do not include ones that became initialization unknowns
188188
pars = Vector{SymbolicParam}(filter(
189-
p -> !haskey(paramsubs, p), parameters(sys; initial_parameters = true)))
189+
!in(solved_params), parameters(sys; initial_parameters = true)))
190190
push!(pars, get_iv(sys))
191191

192192
# 8) use observed equations for guesses of observed variables if not provided
@@ -198,16 +198,8 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
198198
end
199199
append!(eqs_ics, trueobs)
200200

201-
vars = [vars; collect(values(paramsubs))]
201+
vars = [vars; collect(solved_params)]
202202

203-
# even if `p => tovar(p)` is in `paramsubs`, `isparameter(p[1]) === true` after substitution
204-
# so add scalarized versions as well
205-
scalarize_varmap!(paramsubs)
206-
207-
eqs_ics = Symbolics.substitute.(eqs_ics, (paramsubs,))
208-
for k in keys(defs)
209-
defs[k] = substitute(defs[k], paramsubs)
210-
end
211203
initials = Dict(k => v for (k, v) in pmap if isinitial(k))
212204
merge!(defs, initials)
213205
isys = System(Vector{Equation}(eqs_ics),
@@ -299,30 +291,22 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
299291
append!(eqs_ics, initialization_eqs)
300292

301293
# process parameters as initialization unknowns
302-
paramsubs = setup_parameter_initialization!(
294+
solved_params = setup_parameter_initialization!(
303295
sys, pmap, defs, guesses, eqs_ics; check_defguess)
304296

305297
# parameter dependencies become equations, their LHS become unknowns
306298
# non-numeric dependent parameters stay as parameter dependencies
307299
new_parameter_deps = solve_parameter_dependencies!(
308-
sys, paramsubs, eqs_ics, defs, guesses)
300+
sys, solved_params, eqs_ics, defs, guesses)
309301

310302
# handle values provided for dependent parameters similar to values for observed variables
311-
handle_dependent_parameter_constraints!(sys, pmap, eqs_ics, paramsubs)
303+
handle_dependent_parameter_constraints!(sys, pmap, eqs_ics)
312304

313305
# parameters do not include ones that became initialization unknowns
314306
pars = Vector{SymbolicParam}(filter(
315-
p -> !haskey(paramsubs, p), parameters(sys; initial_parameters = true)))
316-
vars = collect(values(paramsubs))
317-
318-
# even if `p => tovar(p)` is in `paramsubs`, `isparameter(p[1]) === true` after substitution
319-
# so add scalarized versions as well
320-
scalarize_varmap!(paramsubs)
307+
!in(solved_params), parameters(sys; initial_parameters = true)))
308+
vars = collect(solved_params)
321309

322-
eqs_ics = Vector{Equation}(Symbolics.substitute.(eqs_ics, (paramsubs,)))
323-
for k in keys(defs)
324-
defs[k] = substitute(defs[k], paramsubs)
325-
end
326310
initials = Dict(k => v for (k, v) in pmap if isinitial(k))
327311
merge!(defs, initials)
328312
isys = System(Vector{Equation}(eqs_ics),
@@ -359,7 +343,7 @@ mapping solvable parameters to their `tovar` variants.
359343
function setup_parameter_initialization!(
360344
sys::AbstractSystem, pmap::AbstractDict, defs::AbstractDict,
361345
guesses::AbstractDict, eqs_ics::Vector{Equation}; check_defguess = false)
362-
paramsubs = Dict()
346+
solved_params = Set()
363347
for p in parameters(sys)
364348
if is_parameter_solvable(p, pmap, defs, guesses)
365349
# If either of them are `missing` the parameter is an unknown
@@ -369,7 +353,7 @@ function setup_parameter_initialization!(
369353
_val2 = get_possibly_array_fallback_singletons(defs, p)
370354
_val3 = get_possibly_array_fallback_singletons(guesses, p)
371355
varp = tovar(p)
372-
paramsubs[p] = varp
356+
push!(solved_params, p)
373357
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
374358
if _val2 === missing
375359
if _val1 !== nothing && _val1 !== missing
@@ -409,7 +393,7 @@ function setup_parameter_initialization!(
409393
end
410394
end
411395

412-
return paramsubs
396+
return solved_params
413397
end
414398

415399
"""
@@ -418,7 +402,7 @@ end
418402
Add appropriate parameter dependencies as initialization equations. Return the new list of
419403
parameter dependencies for the initialization system.
420404
"""
421-
function solve_parameter_dependencies!(sys::AbstractSystem, paramsubs::AbstractDict,
405+
function solve_parameter_dependencies!(sys::AbstractSystem, solved_params::AbstractSet,
422406
eqs_ics::Vector{Equation}, defs::AbstractDict, guesses::AbstractDict)
423407
new_parameter_deps = Equation[]
424408
for eq in parameter_dependencies(sys)
@@ -427,7 +411,7 @@ function solve_parameter_dependencies!(sys::AbstractSystem, paramsubs::AbstractD
427411
continue
428412
end
429413
varp = tovar(eq.lhs)
430-
paramsubs[eq.lhs] = varp
414+
push!(solved_params, eq.lhs)
431415
push!(eqs_ics, eq)
432416
guessval = get(guesses, eq.lhs, eq.rhs)
433417
push!(defs, varp => guessval)
@@ -442,10 +426,10 @@ end
442426
Turn values provided for parameter dependencies into initialization equations.
443427
"""
444428
function handle_dependent_parameter_constraints!(sys::AbstractSystem, pmap::AbstractDict,
445-
eqs_ics::Vector{Equation}, paramsubs::AbstractDict)
429+
eqs_ics::Vector{Equation})
446430
for (k, v) in merge(defaults(sys), pmap)
447431
if is_variable_floatingpoint(k) && has_parameter_dependency_with_lhs(sys, k)
448-
push!(eqs_ics, paramsubs[k] ~ v)
432+
push!(eqs_ics, k ~ v)
449433
end
450434
end
451435

@@ -735,7 +719,25 @@ function SciMLBase.late_binding_update_u0_p(
735719
newu0, newp = promote_u0_p(newu0, newp, t0)
736720

737721
# non-symbolic u0 updates initials...
738-
if !(eltype(u0) <: Pair)
722+
if eltype(u0) <: Pair
723+
syms = []
724+
vals = []
725+
allsyms = all_symbols(sys)
726+
for (k, v) in u0
727+
v === nothing && continue
728+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
729+
if k isa Symbol
730+
k2 = symbol_to_symbolic(sys, k; allsyms)
731+
# if it is returned as-is, there is no match so skip it
732+
k2 === k && continue
733+
k = k2
734+
end
735+
is_parameter(sys, Initial(k)) || continue
736+
push!(syms, Initial(k))
737+
push!(vals, v)
738+
end
739+
newp = setp_oop(sys, syms)(newp, vals)
740+
else
739741
# if `p` is not provided or is symbolic
740742
p === missing || eltype(p) <: Pair || return newu0, newp
741743
(newu0 === nothing || isempty(newu0)) && return newu0, newp
@@ -748,27 +750,27 @@ function SciMLBase.late_binding_update_u0_p(
748750
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
749751
end
750752
newp = meta.set_initial_unknowns!(newp, newu0)
751-
return newu0, newp
752-
end
753-
754-
syms = []
755-
vals = []
756-
allsyms = all_symbols(sys)
757-
for (k, v) in u0
758-
v === nothing && continue
759-
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
760-
if k isa Symbol
761-
k2 = symbol_to_symbolic(sys, k; allsyms)
762-
# if it is returned as-is, there is no match so skip it
763-
k2 === k && continue
764-
k = k2
753+
end
754+
755+
if eltype(p) <: Pair
756+
syms = []
757+
vals = []
758+
for (k, v) in p
759+
v === nothing && continue
760+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
761+
if k isa Symbol
762+
k2 = symbol_to_symbolic(sys, k; allsyms)
763+
# if it is returned as-is, there is no match so skip it
764+
k2 === k && continue
765+
k = k2
766+
end
767+
is_parameter(sys, Initial(k)) || continue
768+
push!(syms, Initial(k))
769+
push!(vals, v)
765770
end
766-
is_parameter(sys, Initial(k)) || continue
767-
push!(syms, Initial(k))
768-
push!(vals, v)
771+
newp = setp_oop(sys, syms)(newp, vals)
769772
end
770773

771-
newp = setp_oop(sys, syms)(newp, vals)
772774
return newu0, newp
773775
end
774776

0 commit comments

Comments
 (0)