Skip to content

Commit 4a3d069

Browse files
Merge pull request #3771 from AayushSabharwal/as/v9-perf
[v9] feat: significantly improve performance of `*Problem` generation
2 parents 4578b66 + a2deaa4 commit 4a3d069

File tree

9 files changed

+234
-49
lines changed

9 files changed

+234
-49
lines changed

src/bipartite_graph.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,13 +535,39 @@ function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors)
535535
end
536536
end
537537

538-
function delete_srcs!(g::BipartiteGraph, srcs)
538+
function delete_srcs!(g::BipartiteGraph{I}, srcs; rm_verts = false) where {I}
539539
for s in srcs
540540
set_neighbors!(g, s, ())
541541
end
542+
if rm_verts
543+
old_to_new_idxs = collect(one(I):I(nsrcs(g)))
544+
for s in srcs
545+
old_to_new_idxs[s] = zero(I)
546+
end
547+
offset = zero(I)
548+
for i in eachindex(old_to_new_idxs)
549+
if iszero(old_to_new_idxs[i])
550+
offset += one(I)
551+
continue
552+
end
553+
old_to_new_idxs[i] -= offset
554+
end
555+
556+
if g.badjlist isa AbstractVector
557+
for i in 1:ndsts(g)
558+
for j in eachindex(g.badjlist[i])
559+
g.badjlist[i][j] = old_to_new_idxs[g.badjlist[i][j]]
560+
end
561+
filter!(!iszero, g.badjlist[i])
562+
end
563+
end
564+
deleteat!(g.fadjlist, srcs)
565+
end
542566
g
543567
end
544-
delete_dsts!(g::BipartiteGraph, srcs) = delete_srcs!(invview(g), srcs)
568+
function delete_dsts!(g::BipartiteGraph, srcs; rm_verts = false)
569+
delete_srcs!(invview(g), srcs; rm_verts)
570+
end
545571

546572
###
547573
### Edges iteration

src/structural_transformation/symbolics_tearing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,8 @@ function update_simplified_system!(
740740
obs_sub[eq.lhs] = eq.rhs
741741
end
742742
# TODO: compute the dependency correctly so that we don't have to do this
743-
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
743+
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs;
744+
fast_substitute(state.additional_observed, obs_sub)]
744745

745746
unknowns = Any[v
746747
for (i, v) in enumerate(state.fullvars)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1495,7 +1495,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
14951495
@warn errmsg
14961496
end
14971497

1498-
uninit = setdiff(unknowns(sys), [unknowns(isys); observables(isys)])
1498+
uninit = setdiff(unknowns(sys), unknowns(isys), observables(isys))
14991499

15001500
# TODO: throw on uninitialized arrays
15011501
filter!(x -> !(x isa Symbolics.Arr), uninit)

src/systems/nonlinear/initializesystem.jl

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,26 @@ function SciMLBase.late_binding_update_u0_p(
648648
newu0, newp = promote_u0_p(newu0, newp, t0)
649649

650650
# non-symbolic u0 updates initials...
651-
if !(eltype(u0) <: Pair)
651+
if eltype(u0) <: Pair
652+
syms = []
653+
vals = []
654+
allsyms = all_symbols(sys)
655+
for (k, v) in u0
656+
v === nothing && continue
657+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
658+
if k isa Symbol
659+
k2 = symbol_to_symbolic(sys, k; allsyms)
660+
# if it is returned as-is, there is no match so skip it
661+
k2 === k && continue
662+
k = k2
663+
end
664+
is_parameter(sys, Initial(k)) || continue
665+
push!(syms, Initial(k))
666+
push!(vals, v)
667+
end
668+
newp = setp_oop(sys, syms)(newp, vals)
669+
else
670+
allsyms = nothing
652671
# if `p` is not provided or is symbolic
653672
p === missing || eltype(p) <: Pair || return newu0, newp
654673
(newu0 === nothing || isempty(newu0)) && return newu0, newp
@@ -661,27 +680,30 @@ function SciMLBase.late_binding_update_u0_p(
661680
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
662681
end
663682
newp = meta.set_initial_unknowns!(newp, newu0)
664-
return newu0, newp
665-
end
666-
667-
syms = []
668-
vals = []
669-
allsyms = all_symbols(sys)
670-
for (k, v) in u0
671-
v === nothing && continue
672-
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
673-
if k isa Symbol
674-
k2 = symbol_to_symbolic(sys, k; allsyms)
675-
# if it is returned as-is, there is no match so skip it
676-
k2 === k && continue
677-
k = k2
683+
end
684+
685+
if eltype(p) <: Pair
686+
syms = []
687+
vals = []
688+
if allsyms === nothing
689+
allsyms = all_symbols(sys)
690+
end
691+
for (k, v) in p
692+
v === nothing && continue
693+
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
694+
if k isa Symbol
695+
k2 = symbol_to_symbolic(sys, k; allsyms)
696+
# if it is returned as-is, there is no match so skip it
697+
k2 === k && continue
698+
k = k2
699+
end
700+
is_parameter(sys, Initial(k)) || continue
701+
push!(syms, Initial(k))
702+
push!(vals, v)
678703
end
679-
is_parameter(sys, Initial(k)) || continue
680-
push!(syms, Initial(k))
681-
push!(vals, v)
704+
newp = setp_oop(sys, syms)(newp, vals)
682705
end
683706

684-
newp = setp_oop(sys, syms)(newp, vals)
685707
return newu0, newp
686708
end
687709

src/systems/optimization/optimizationsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,8 +741,9 @@ function structural_simplify(sys::OptimizationSystem; split = true, kwargs...)
741741
nlsys = NonlinearSystem(econs, unknowns(sys), parameters(sys); name = :___tmp_nlsystem)
742742
snlsys = structural_simplify(nlsys; fully_determined = false, kwargs...)
743743
obs = observed(snlsys)
744-
subs = Dict(eq.lhs => eq.rhs for eq in observed(snlsys))
745744
seqs = equations(snlsys)
745+
trueobs, _ = unhack_observed(obs, seqs)
746+
subs = Dict(eq.lhs => eq.rhs for eq in trueobs)
746747
cons_simplified = similar(cons, length(icons) + length(seqs))
747748
for (i, eq) in enumerate(Iterators.flatten((seqs, icons)))
748749
cons_simplified[i] = fixpoint_sub(eq, subs)

src/systems/parameter_buffer.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ the default behavior).
2929
function MTKParameters(
3030
sys::AbstractSystem, p, u0 = Dict(); tofloat = false,
3131
t0 = nothing, substitution_limit = 1000, floatT = nothing,
32-
p_constructor = identity)
32+
p_constructor = identity, fast_path = false)
3333
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
3434
get_index_cache(sys)
3535
else
@@ -50,9 +50,15 @@ function MTKParameters(
5050
is_time_dependent(sys) && add_observed!(sys, u0)
5151
add_parameter_dependencies!(sys, p)
5252

53-
op, missing_unknowns, missing_pars = build_operating_point!(sys,
54-
u0, p, defs, cmap, dvs, ps)
55-
53+
u0map = anydict()
54+
pmap = anydict()
55+
if fast_path
56+
missing_pars = missingvars(p, ps)
57+
op = p
58+
else
59+
op, _, missing_pars = build_operating_point!(sys,
60+
u0, p, defs, cmap, dvs, ps)
61+
end
5662
if t0 !== nothing
5763
op[get_iv(sys)] = t0
5864
end

src/systems/problem_utils.jl

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,11 @@ in `varmap`, it is ignored.
450450
"""
451451
function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
452452
for k in vars
453+
v = get(varmap, k, nothing)
454+
v === nothing && continue
455+
symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue
453456
haskey(varmap, k) || continue
454-
varmap[k] = fixpoint_sub(varmap[k], varmap; maxiters = limit)
457+
varmap[k] = fixpoint_sub(v, varmap; maxiters = limit)
455458
end
456459
end
457460

@@ -580,15 +583,19 @@ function build_operating_point!(sys::AbstractSystem,
580583
end
581584
end
582585

583-
for k in keys(u0map)
584-
v = fixpoint_sub(u0map[k], neithermap; operator = Symbolics.Operator)
585-
isequal(k, v) && continue
586-
u0map[k] = v
587-
end
588-
for k in keys(pmap)
589-
v = fixpoint_sub(pmap[k], neithermap; operator = Symbolics.Operator)
590-
isequal(k, v) && continue
591-
pmap[k] = v
586+
if !isempty(neithermap)
587+
for (k, v) in u0map
588+
symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue
589+
v = fixpoint_sub(v, neithermap; operator = Symbolics.Operator)
590+
isequal(k, v) && continue
591+
u0map[k] = v
592+
end
593+
for (k, v) in pmap
594+
symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue
595+
v = fixpoint_sub(v, neithermap; operator = Symbolics.Operator)
596+
isequal(k, v) && continue
597+
pmap[k] = v
598+
end
592599
end
593600

594601
return op, missing_unknowns, missing_pars
@@ -1036,6 +1043,9 @@ function (siu::SetInitialUnknowns)(p::AbstractVector, u0)
10361043
return p
10371044
end
10381045

1046+
safe_float(x) = x
1047+
safe_float(x::AbstractArray) = isempty(x) ? x : float(x)
1048+
10391049
"""
10401050
$(TYPEDSIGNATURES)
10411051
@@ -1100,7 +1110,8 @@ function maybe_build_initialization_problem(
11001110
if is_time_dependent(sys)
11011111
all_init_syms = Set(all_symbols(initializeprob))
11021112
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
1103-
initializeprobmap = u0_constructor getu(initializeprob, solved_unknowns)
1113+
initializeprobmap = u0_constructor safe_float
1114+
getu(initializeprob, solved_unknowns)
11041115
else
11051116
initializeprobmap = nothing
11061117
end
@@ -1123,20 +1134,24 @@ function maybe_build_initialization_problem(
11231134
update_initializeprob! = ModelingToolkit.update_initializeprob!
11241135
end
11251136

1126-
for p in punknowns
1127-
is_parameter_solvable(p, pmap, defs, guesses) || continue
1128-
get(op, p, missing) === missing || continue
1137+
filter!(punknowns) do p
1138+
is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing
1139+
end
1140+
pvals = getu(initializeprob, punknowns)(initializeprob)
1141+
for (p, pval) in zip(punknowns, pvals)
11291142
p = unwrap(p)
1130-
op[p] = getu(initializeprob, p)(initializeprob)
1143+
op[p] = pval
11311144
if iscall(p) && operation(p) === getindex
11321145
arrp = arguments(p)[1]
1146+
get(op, arrp, nothing) !== missing && continue
11331147
op[arrp] = collect(arrp)
11341148
end
11351149
end
11361150

11371151
if is_time_dependent(sys)
1138-
for v in missing_unknowns
1139-
op[v] = getu(initializeprob, v)(initializeprob)
1152+
uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob)
1153+
for (v, val) in zip(missing_unknowns, uvals)
1154+
op[v] = val
11401155
end
11411156
empty!(missing_unknowns)
11421157
end
@@ -1371,7 +1386,7 @@ function process_SciMLProblem(
13711386
if !(pType <: AbstractArray)
13721387
pType = Array
13731388
end
1374-
p = MTKParameters(sys, op; floatT = floatT, p_constructor)
1389+
p = MTKParameters(sys, op; floatT = floatT, p_constructor, fast_path = true)
13751390
else
13761391
p = p_constructor(better_varmap_to_vars(op, ps; tofloat, container_type = pType))
13771392
end

0 commit comments

Comments
 (0)