Skip to content

[v9] feat: significantly improve performance of *Problem generation #3771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions src/bipartite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,39 @@ function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors)
end
end

function delete_srcs!(g::BipartiteGraph, srcs)
function delete_srcs!(g::BipartiteGraph{I}, srcs; rm_verts = false) where {I}
for s in srcs
set_neighbors!(g, s, ())
end
if rm_verts
old_to_new_idxs = collect(one(I):I(nsrcs(g)))
for s in srcs
old_to_new_idxs[s] = zero(I)
end
offset = zero(I)
for i in eachindex(old_to_new_idxs)
if iszero(old_to_new_idxs[i])
offset += one(I)
continue
end
old_to_new_idxs[i] -= offset
end

if g.badjlist isa AbstractVector
for i in 1:ndsts(g)
for j in eachindex(g.badjlist[i])
g.badjlist[i][j] = old_to_new_idxs[g.badjlist[i][j]]
end
filter!(!iszero, g.badjlist[i])
end
end
deleteat!(g.fadjlist, srcs)
end
g
end
delete_dsts!(g::BipartiteGraph, srcs) = delete_srcs!(invview(g), srcs)
function delete_dsts!(g::BipartiteGraph, srcs; rm_verts = false)
delete_srcs!(invview(g), srcs; rm_verts)
end

###
### Edges iteration
Expand Down
3 changes: 2 additions & 1 deletion src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,8 @@ function update_simplified_system!(
obs_sub[eq.lhs] = eq.rhs
end
# TODO: compute the dependency correctly so that we don't have to do this
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs;
fast_substitute(state.additional_observed, obs_sub)]

unknowns = Any[v
for (i, v) in enumerate(state.fullvars)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
@warn errmsg
end

uninit = setdiff(unknowns(sys), [unknowns(isys); observables(isys)])
uninit = setdiff(unknowns(sys), unknowns(isys), observables(isys))

# TODO: throw on uninitialized arrays
filter!(x -> !(x isa Symbolics.Arr), uninit)
Expand Down
60 changes: 41 additions & 19 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,26 @@ function SciMLBase.late_binding_update_u0_p(
newu0, newp = promote_u0_p(newu0, newp, t0)

# non-symbolic u0 updates initials...
if !(eltype(u0) <: Pair)
if eltype(u0) <: Pair
syms = []
vals = []
allsyms = all_symbols(sys)
for (k, v) in u0
v === nothing && continue
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
if k isa Symbol
k2 = symbol_to_symbolic(sys, k; allsyms)
# if it is returned as-is, there is no match so skip it
k2 === k && continue
k = k2
end
is_parameter(sys, Initial(k)) || continue
push!(syms, Initial(k))
push!(vals, v)
end
newp = setp_oop(sys, syms)(newp, vals)
else
allsyms = nothing
# if `p` is not provided or is symbolic
p === missing || eltype(p) <: Pair || return newu0, newp
(newu0 === nothing || isempty(newu0)) && return newu0, newp
Expand All @@ -661,27 +680,30 @@ function SciMLBase.late_binding_update_u0_p(
throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))"))
end
newp = meta.set_initial_unknowns!(newp, newu0)
return newu0, newp
end

syms = []
vals = []
allsyms = all_symbols(sys)
for (k, v) in u0
v === nothing && continue
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
if k isa Symbol
k2 = symbol_to_symbolic(sys, k; allsyms)
# if it is returned as-is, there is no match so skip it
k2 === k && continue
k = k2
end

if eltype(p) <: Pair
syms = []
vals = []
if allsyms === nothing
allsyms = all_symbols(sys)
end
for (k, v) in p
v === nothing && continue
(symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue
if k isa Symbol
k2 = symbol_to_symbolic(sys, k; allsyms)
# if it is returned as-is, there is no match so skip it
k2 === k && continue
k = k2
end
is_parameter(sys, Initial(k)) || continue
push!(syms, Initial(k))
push!(vals, v)
end
is_parameter(sys, Initial(k)) || continue
push!(syms, Initial(k))
push!(vals, v)
newp = setp_oop(sys, syms)(newp, vals)
end

newp = setp_oop(sys, syms)(newp, vals)
return newu0, newp
end

Expand Down
3 changes: 2 additions & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -741,8 +741,9 @@ function structural_simplify(sys::OptimizationSystem; split = true, kwargs...)
nlsys = NonlinearSystem(econs, unknowns(sys), parameters(sys); name = :___tmp_nlsystem)
snlsys = structural_simplify(nlsys; fully_determined = false, kwargs...)
obs = observed(snlsys)
subs = Dict(eq.lhs => eq.rhs for eq in observed(snlsys))
seqs = equations(snlsys)
trueobs, _ = unhack_observed(obs, seqs)
subs = Dict(eq.lhs => eq.rhs for eq in trueobs)
cons_simplified = similar(cons, length(icons) + length(seqs))
for (i, eq) in enumerate(Iterators.flatten((seqs, icons)))
cons_simplified[i] = fixpoint_sub(eq, subs)
Expand Down
14 changes: 10 additions & 4 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ the default behavior).
function MTKParameters(
sys::AbstractSystem, p, u0 = Dict(); tofloat = false,
t0 = nothing, substitution_limit = 1000, floatT = nothing,
p_constructor = identity)
p_constructor = identity, fast_path = false)
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
get_index_cache(sys)
else
Expand All @@ -50,9 +50,15 @@ function MTKParameters(
is_time_dependent(sys) && add_observed!(sys, u0)
add_parameter_dependencies!(sys, p)

op, missing_unknowns, missing_pars = build_operating_point!(sys,
u0, p, defs, cmap, dvs, ps)

u0map = anydict()
pmap = anydict()
if fast_path
missing_pars = missingvars(p, ps)
op = p
else
op, _, missing_pars = build_operating_point!(sys,
u0, p, defs, cmap, dvs, ps)
end
if t0 !== nothing
op[get_iv(sys)] = t0
end
Expand Down
51 changes: 33 additions & 18 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,11 @@ in `varmap`, it is ignored.
"""
function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
for k in vars
v = get(varmap, k, nothing)
v === nothing && continue
symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue
haskey(varmap, k) || continue
varmap[k] = fixpoint_sub(varmap[k], varmap; maxiters = limit)
varmap[k] = fixpoint_sub(v, varmap; maxiters = limit)
end
end

Expand Down Expand Up @@ -580,15 +583,19 @@ function build_operating_point!(sys::AbstractSystem,
end
end

for k in keys(u0map)
v = fixpoint_sub(u0map[k], neithermap; operator = Symbolics.Operator)
isequal(k, v) && continue
u0map[k] = v
end
for k in keys(pmap)
v = fixpoint_sub(pmap[k], neithermap; operator = Symbolics.Operator)
isequal(k, v) && continue
pmap[k] = v
if !isempty(neithermap)
for (k, v) in u0map
symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue
v = fixpoint_sub(v, neithermap; operator = Symbolics.Operator)
isequal(k, v) && continue
u0map[k] = v
end
for (k, v) in pmap
symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue
v = fixpoint_sub(v, neithermap; operator = Symbolics.Operator)
isequal(k, v) && continue
pmap[k] = v
end
end

return op, missing_unknowns, missing_pars
Expand Down Expand Up @@ -1036,6 +1043,9 @@ function (siu::SetInitialUnknowns)(p::AbstractVector, u0)
return p
end

safe_float(x) = x
safe_float(x::AbstractArray) = isempty(x) ? x : float(x)

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -1100,7 +1110,8 @@ function maybe_build_initialization_problem(
if is_time_dependent(sys)
all_init_syms = Set(all_symbols(initializeprob))
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
initializeprobmap = u0_constructor ∘ getu(initializeprob, solved_unknowns)
initializeprobmap = u0_constructor ∘ safe_float ∘
getu(initializeprob, solved_unknowns)
else
initializeprobmap = nothing
end
Expand All @@ -1123,20 +1134,24 @@ function maybe_build_initialization_problem(
update_initializeprob! = ModelingToolkit.update_initializeprob!
end

for p in punknowns
is_parameter_solvable(p, pmap, defs, guesses) || continue
get(op, p, missing) === missing || continue
filter!(punknowns) do p
is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing
end
pvals = getu(initializeprob, punknowns)(initializeprob)
for (p, pval) in zip(punknowns, pvals)
p = unwrap(p)
op[p] = getu(initializeprob, p)(initializeprob)
op[p] = pval
if iscall(p) && operation(p) === getindex
arrp = arguments(p)[1]
get(op, arrp, nothing) !== missing && continue
op[arrp] = collect(arrp)
end
end

if is_time_dependent(sys)
for v in missing_unknowns
op[v] = getu(initializeprob, v)(initializeprob)
uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob)
for (v, val) in zip(missing_unknowns, uvals)
op[v] = val
end
empty!(missing_unknowns)
end
Expand Down Expand Up @@ -1371,7 +1386,7 @@ function process_SciMLProblem(
if !(pType <: AbstractArray)
pType = Array
end
p = MTKParameters(sys, op; floatT = floatT, p_constructor)
p = MTKParameters(sys, op; floatT = floatT, p_constructor, fast_path = true)
else
p = p_constructor(better_varmap_to_vars(op, ps; tofloat, container_type = pType))
end
Expand Down
Loading
Loading