diff --git a/src/bipartite_graph.jl b/src/bipartite_graph.jl index b6665646c9..8cdb76cca0 100644 --- a/src/bipartite_graph.jl +++ b/src/bipartite_graph.jl @@ -535,13 +535,39 @@ function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors) end end -function delete_srcs!(g::BipartiteGraph, srcs) +function delete_srcs!(g::BipartiteGraph{I}, srcs; rm_verts = false) where {I} for s in srcs set_neighbors!(g, s, ()) end + if rm_verts + old_to_new_idxs = collect(one(I):I(nsrcs(g))) + for s in srcs + old_to_new_idxs[s] = zero(I) + end + offset = zero(I) + for i in eachindex(old_to_new_idxs) + if iszero(old_to_new_idxs[i]) + offset += one(I) + continue + end + old_to_new_idxs[i] -= offset + end + + if g.badjlist isa AbstractVector + for i in 1:ndsts(g) + for j in eachindex(g.badjlist[i]) + g.badjlist[i][j] = old_to_new_idxs[g.badjlist[i][j]] + end + filter!(!iszero, g.badjlist[i]) + end + end + deleteat!(g.fadjlist, srcs) + end g end -delete_dsts!(g::BipartiteGraph, srcs) = delete_srcs!(invview(g), srcs) +function delete_dsts!(g::BipartiteGraph, srcs; rm_verts = false) + delete_srcs!(invview(g), srcs; rm_verts) +end ### ### Edges iteration diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index cc582cd473..244a737705 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -740,7 +740,8 @@ function update_simplified_system!( obs_sub[eq.lhs] = eq.rhs end # TODO: compute the dependency correctly so that we don't have to do this - obs = [fast_substitute(observed(sys), obs_sub); solved_eqs] + obs = [fast_substitute(observed(sys), obs_sub); solved_eqs; + fast_substitute(state.additional_observed, obs_sub)] unknowns = Any[v for (i, v) in enumerate(state.fullvars) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index d802f49fee..f87c4d0071 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1495,7 +1495,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, @warn errmsg end - uninit = setdiff(unknowns(sys), [unknowns(isys); observables(isys)]) + uninit = setdiff(unknowns(sys), unknowns(isys), observables(isys)) # TODO: throw on uninitialized arrays filter!(x -> !(x isa Symbolics.Arr), uninit) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 5a8687984f..bd2c7de753 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -648,7 +648,26 @@ function SciMLBase.late_binding_update_u0_p( newu0, newp = promote_u0_p(newu0, newp, t0) # non-symbolic u0 updates initials... - if !(eltype(u0) <: Pair) + if eltype(u0) <: Pair + syms = [] + vals = [] + allsyms = all_symbols(sys) + for (k, v) in u0 + v === nothing && continue + (symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue + if k isa Symbol + k2 = symbol_to_symbolic(sys, k; allsyms) + # if it is returned as-is, there is no match so skip it + k2 === k && continue + k = k2 + end + is_parameter(sys, Initial(k)) || continue + push!(syms, Initial(k)) + push!(vals, v) + end + newp = setp_oop(sys, syms)(newp, vals) + else + allsyms = nothing # if `p` is not provided or is symbolic p === missing || eltype(p) <: Pair || return newu0, newp (newu0 === nothing || isempty(newu0)) && return newu0, newp @@ -661,27 +680,30 @@ function SciMLBase.late_binding_update_u0_p( throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(prob.u0))). Got $(typeof(newu0)) of length $(length(newu0))")) end newp = meta.set_initial_unknowns!(newp, newu0) - return newu0, newp - end - - syms = [] - vals = [] - allsyms = all_symbols(sys) - for (k, v) in u0 - v === nothing && continue - (symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue - if k isa Symbol - k2 = symbol_to_symbolic(sys, k; allsyms) - # if it is returned as-is, there is no match so skip it - k2 === k && continue - k = k2 + end + + if eltype(p) <: Pair + syms = [] + vals = [] + if allsyms === nothing + allsyms = all_symbols(sys) + end + for (k, v) in p + v === nothing && continue + (symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)) || continue + if k isa Symbol + k2 = symbol_to_symbolic(sys, k; allsyms) + # if it is returned as-is, there is no match so skip it + k2 === k && continue + k = k2 + end + is_parameter(sys, Initial(k)) || continue + push!(syms, Initial(k)) + push!(vals, v) end - is_parameter(sys, Initial(k)) || continue - push!(syms, Initial(k)) - push!(vals, v) + newp = setp_oop(sys, syms)(newp, vals) end - newp = setp_oop(sys, syms)(newp, vals) return newu0, newp end diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index bfe15b62d7..fcb502efdf 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -741,8 +741,9 @@ function structural_simplify(sys::OptimizationSystem; split = true, kwargs...) nlsys = NonlinearSystem(econs, unknowns(sys), parameters(sys); name = :___tmp_nlsystem) snlsys = structural_simplify(nlsys; fully_determined = false, kwargs...) obs = observed(snlsys) - subs = Dict(eq.lhs => eq.rhs for eq in observed(snlsys)) seqs = equations(snlsys) + trueobs, _ = unhack_observed(obs, seqs) + subs = Dict(eq.lhs => eq.rhs for eq in trueobs) cons_simplified = similar(cons, length(icons) + length(seqs)) for (i, eq) in enumerate(Iterators.flatten((seqs, icons))) cons_simplified[i] = fixpoint_sub(eq, subs) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index c3d2a0e831..00742f028c 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -29,7 +29,7 @@ the default behavior). function MTKParameters( sys::AbstractSystem, p, u0 = Dict(); tofloat = false, t0 = nothing, substitution_limit = 1000, floatT = nothing, - p_constructor = identity) + p_constructor = identity, fast_path = false) ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing get_index_cache(sys) else @@ -50,9 +50,15 @@ function MTKParameters( is_time_dependent(sys) && add_observed!(sys, u0) add_parameter_dependencies!(sys, p) - op, missing_unknowns, missing_pars = build_operating_point!(sys, - u0, p, defs, cmap, dvs, ps) - + u0map = anydict() + pmap = anydict() + if fast_path + missing_pars = missingvars(p, ps) + op = p + else + op, _, missing_pars = build_operating_point!(sys, + u0, p, defs, cmap, dvs, ps) + end if t0 !== nothing op[get_iv(sys)] = t0 end diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index b893d9ffc6..e4d043d334 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -450,8 +450,11 @@ in `varmap`, it is ignored. """ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100) for k in vars + v = get(varmap, k, nothing) + v === nothing && continue + symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue haskey(varmap, k) || continue - varmap[k] = fixpoint_sub(varmap[k], varmap; maxiters = limit) + varmap[k] = fixpoint_sub(v, varmap; maxiters = limit) end end @@ -580,15 +583,19 @@ function build_operating_point!(sys::AbstractSystem, end end - for k in keys(u0map) - v = fixpoint_sub(u0map[k], neithermap; operator = Symbolics.Operator) - isequal(k, v) && continue - u0map[k] = v - end - for k in keys(pmap) - v = fixpoint_sub(pmap[k], neithermap; operator = Symbolics.Operator) - isequal(k, v) && continue - pmap[k] = v + if !isempty(neithermap) + for (k, v) in u0map + symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue + v = fixpoint_sub(v, neithermap; operator = Symbolics.Operator) + isequal(k, v) && continue + u0map[k] = v + end + for (k, v) in pmap + symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue + v = fixpoint_sub(v, neithermap; operator = Symbolics.Operator) + isequal(k, v) && continue + pmap[k] = v + end end return op, missing_unknowns, missing_pars @@ -1036,6 +1043,9 @@ function (siu::SetInitialUnknowns)(p::AbstractVector, u0) return p end +safe_float(x) = x +safe_float(x::AbstractArray) = isempty(x) ? x : float(x) + """ $(TYPEDSIGNATURES) @@ -1100,7 +1110,8 @@ function maybe_build_initialization_problem( if is_time_dependent(sys) all_init_syms = Set(all_symbols(initializeprob)) solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys)) - initializeprobmap = u0_constructor ∘ getu(initializeprob, solved_unknowns) + initializeprobmap = u0_constructor ∘ safe_float ∘ + getu(initializeprob, solved_unknowns) else initializeprobmap = nothing end @@ -1123,20 +1134,24 @@ function maybe_build_initialization_problem( update_initializeprob! = ModelingToolkit.update_initializeprob! end - for p in punknowns - is_parameter_solvable(p, pmap, defs, guesses) || continue - get(op, p, missing) === missing || continue + filter!(punknowns) do p + is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing + end + pvals = getu(initializeprob, punknowns)(initializeprob) + for (p, pval) in zip(punknowns, pvals) p = unwrap(p) - op[p] = getu(initializeprob, p)(initializeprob) + op[p] = pval if iscall(p) && operation(p) === getindex arrp = arguments(p)[1] + get(op, arrp, nothing) !== missing && continue op[arrp] = collect(arrp) end end if is_time_dependent(sys) - for v in missing_unknowns - op[v] = getu(initializeprob, v)(initializeprob) + uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob) + for (v, val) in zip(missing_unknowns, uvals) + op[v] = val end empty!(missing_unknowns) end @@ -1371,7 +1386,7 @@ function process_SciMLProblem( if !(pType <: AbstractArray) pType = Array end - p = MTKParameters(sys, op; floatT = floatT, p_constructor) + p = MTKParameters(sys, op; floatT = floatT, p_constructor, fast_path = true) else p = p_constructor(better_varmap_to_vars(op, ps; tofloat, container_type = pType)) end diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index e0feb0d34d..48214e01a4 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -208,12 +208,19 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} structure::SystemStructure extra_eqs::Vector param_derivative_map::Dict{BasicSymbolic, Any} + original_eqs::Vector{Equation} + """ + Additional user-provided observed equations. The variables calculated here + are not used in the rest of the system. + """ + additional_observed::Vector{Equation} end TransformationState(sys::AbstractSystem) = TearingState(sys) function system_subset(ts::TearingState, ieqs::Vector{Int}) eqs = equations(ts) @set! ts.sys.eqs = eqs[ieqs] + @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) ts end @@ -266,6 +273,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) iv = length(ivs) == 1 ? ivs[1] : nothing # scalarize array equations, without scalarizing arguments to registered functions eqs = flatten_equations(copy(equations(sys))) + original_eqs = copy(eqs) neqs = length(eqs) dervaridxs = OrderedSet{Int}() var2idx = Dict{Any, Int}() @@ -378,6 +386,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) end end eqs = eqs[eqs_to_retain] + original_eqs = original_eqs[eqs_to_retain] neqs = length(eqs) symbolic_incidence = symbolic_incidence[eqs_to_retain] @@ -386,6 +395,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) # depending on order due to NP-completeness of tearing. sortidxs = Base.sortperm(eqs, by = string) eqs = eqs[sortidxs] + original_eqs = original_eqs[sortidxs] symbolic_incidence = symbolic_incidence[sortidxs] end @@ -475,13 +485,116 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) ts = TearingState(sys, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem), - Any[], param_derivative_map) + Any[], param_derivative_map, original_eqs, Equation[]) if sys isa DiscreteSystem ts = shift_discrete_system(ts) end return ts end +""" + $(TYPEDSIGNATURES) + +Preemptively identify observed equations in the system and tear them. This happens before +any simplification. The equations torn by this process are ones that are already given in +an explicit form in the system and where the LHS is not present in any other equation of +the system except for other such preempitvely torn equations. +""" +function trivial_tearing!(ts::TearingState) + @assert length(ts.original_eqs) == length(equations(ts)) + # equations that can be trivially torn an observed equations + trivial_idxs = BitSet() + # equations to never check + blacklist = BitSet() + torn_eqs = Equation[] + # variables that have been matched to trivially torn equations + matched_vars = BitSet() + # variable to index in fullvars + var_to_idx = Dict{Any, Int}(ts.fullvars .=> eachindex(ts.fullvars)) + + complete!(ts.structure) + var_to_diff = ts.structure.var_to_diff + graph = ts.structure.graph + while true + # track whether we added an equation to the trivial list this iteration + added_equation = false + for (i, eq) in enumerate(ts.original_eqs) + # don't check already torn equations + i in trivial_idxs && continue + i in blacklist && continue + # ensure it is an observed equation matched to a variable in fullvars + vari = get(var_to_idx, eq.lhs, 0) + iszero(vari) && continue + # don't tear irreducible variables + if isirreducible(eq.lhs) + push!(blacklist, i) + continue + end + # if a variable was the LHS of two trivial observed equations, we wouldn't have + # included it in the list. Error if somehow it made it through. + @assert !(vari in matched_vars) + # don't tear differential/shift equations (or differentiated/shifted variables) + var_to_diff[vari] === nothing || continue + invview(var_to_diff)[vari] === nothing || continue + # get the equations that the candidate matched variable is present in, except + # those equations which have already been torn as observed + eqidxs = setdiff(𝑑neighbors(graph, vari), trivial_idxs) + # it should only be present in this equation + length(eqidxs) == 1 || continue + eqi = only(eqidxs) + @assert eqi == i + + # for every variable present in this equation, make sure it isn't _only_ + # present in trivial equations + isvalid = true + for v in 𝑠neighbors(graph, eqi) + v == vari && continue + v in matched_vars && continue + # `> 1` and not `0` because one entry will be this equation (`eqi`) + isvalid &= count(!in(trivial_idxs), 𝑑neighbors(graph, v)) > 1 + isvalid || break + end + isvalid || continue + # skip if the LHS is present in the RHS, since then this isn't explicit + if occursin(eq.lhs, eq.rhs) + push!(blacklist, i) + continue + end + + added_equation = true + push!(trivial_idxs, eqi) + push!(torn_eqs, eq) + push!(matched_vars, vari) + end + + # if we didn't add an equation this iteration, we won't add one next iteration + added_equation || break + end + + deleteat!(var_to_diff.primal_to_diff, matched_vars) + deleteat!(var_to_diff.diff_to_primal, matched_vars) + deleteat!(ts.structure.eq_to_diff.primal_to_diff, trivial_idxs) + deleteat!(ts.structure.eq_to_diff.diff_to_primal, trivial_idxs) + delete_srcs!(ts.structure.graph, trivial_idxs; rm_verts = true) + delete_dsts!(ts.structure.graph, matched_vars; rm_verts = true) + if ts.structure.solvable_graph !== nothing + delete_srcs!(ts.structure.solvable_graph, trivial_idxs; rm_verts = true) + delete_dsts!(ts.structure.solvable_graph, matched_vars; rm_verts = true) + end + if ts.structure.var_types !== nothing + deleteat!(ts.structure.var_types, matched_vars) + end + deleteat!(ts.fullvars, matched_vars) + deleteat!(ts.original_eqs, trivial_idxs) + ts.additional_observed = torn_eqs + sys = ts.sys + eqs = copy(get_eqs(sys)) + deleteat!(eqs, trivial_idxs) + @set! sys.eqs = eqs + ts.sys = sys + return ts +end + function lower_order_var(dervar, t) if isdifferential(dervar) diffvar = arguments(dervar)[1] @@ -739,6 +852,7 @@ function _structural_simplify!(state::TearingState, io; simplify = false, else input_idxs = 0:-1 # Empty range end + trivial_tearing!(state) sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...) if check_consistency fully_determined = ModelingToolkit.check_consistency( diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index fd84c346f9..a7c6f4a27e 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1253,9 +1253,9 @@ end @test init(prob3)[x] ≈ 1.0 prob4 = remake(prob; p = [p => 1.0]) test_dummy_initialization_equation(prob4, x) - prob5 = remake(prob; p = [p => missing, q => 2.0]) + prob5 = remake(prob; p = [p => missing, q => 4.0]) @test prob5.f.initialization_data !== nothing - @test init(prob5).ps[p] ≈ 1.0 + @test init(prob5).ps[p] ≈ 2.0 end @testset "Variables provided as symbols" begin