Skip to content

Commit 2c04b7b

Browse files
feat: preemptively tear some trivial equations in mtkcompile
1 parent a501375 commit 2c04b7b

File tree

2 files changed

+118
-2
lines changed

2 files changed

+118
-2
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,8 @@ function update_simplified_system!(
960960
obs_sub[eq.lhs] = eq.rhs
961961
end
962962
# TODO: compute the dependency correctly so that we don't have to do this
963-
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
963+
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs;
964+
fast_substitute(state.additional_observed, obs_sub)]
964965

965966
unknown_idxs = filter(
966967
i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars))

src/systems/systemstructure.jl

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,19 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
208208
structure::SystemStructure
209209
extra_eqs::Vector
210210
param_derivative_map::Dict{BasicSymbolic, Any}
211+
original_eqs::Vector{Equation}
212+
"""
213+
Additional user-provided observed equations. The variables calculated here
214+
are not used in the rest of the system.
215+
"""
216+
additional_observed::Vector{Equation}
211217
end
212218

213219
TransformationState(sys::AbstractSystem) = TearingState(sys)
214220
function system_subset(ts::TearingState, ieqs::Vector{Int})
215221
eqs = equations(ts)
216222
@set! ts.sys.eqs = eqs[ieqs]
223+
@set! ts.original_eqs = ts.original_eqs[ieqs]
217224
@set! ts.structure = system_subset(ts.structure, ieqs)
218225
ts
219226
end
@@ -276,6 +283,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
276283
iv = length(ivs) == 1 ? ivs[1] : nothing
277284
# flatten array equations
278285
eqs = flatten_equations(equations(sys))
286+
original_eqs = copy(eqs)
279287
neqs = length(eqs)
280288
param_derivative_map = Dict{BasicSymbolic, Any}()
281289
# * Scalarize unknowns
@@ -320,6 +328,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
320328
varsbuf = Set()
321329
eqs_to_retain = trues(length(eqs))
322330
for (i, eq) in enumerate(eqs)
331+
_eq = eq
323332
if iscall(eq.lhs) && (op = operation(eq.lhs)) isa Differential &&
324333
isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq.lhs)), ps, iv)
325334
# parameter derivatives are opted out by specifying `D(p) ~ missing`, but
@@ -415,6 +424,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
415424
end
416425
end
417426
eqs = eqs[eqs_to_retain]
427+
original_eqs = original_eqs[eqs_to_retain]
418428
neqs = length(eqs)
419429
symbolic_incidence = symbolic_incidence[eqs_to_retain]
420430

@@ -423,6 +433,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
423433
# depending on order due to NP-completeness of tearing.
424434
sortidxs = Base.sortperm(eqs, by = string)
425435
eqs = eqs[sortidxs]
436+
original_eqs = original_eqs[sortidxs]
426437
symbolic_incidence = symbolic_incidence[sortidxs]
427438
end
428439

@@ -516,11 +527,114 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
516527
ts = TearingState(sys, fullvars,
517528
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
518529
complete(graph), nothing, var_types, false),
519-
Any[], param_derivative_map)
530+
Any[], param_derivative_map, original_eqs, Equation[])
520531

521532
return ts
522533
end
523534

535+
"""
536+
$(TYPEDSIGNATURES)
537+
538+
Preemptively identify observed equations in the system and tear them. This happens before
539+
any simplification. The equations torn by this process are ones that are already given in
540+
an explicit form in the system and where the LHS is not present in any other equation of
541+
the system except for other such preempitvely torn equations.
542+
"""
543+
function trivial_tearing!(ts::TearingState)
544+
@assert length(ts.original_eqs) == length(equations(ts))
545+
# equations that can be trivially torn an observed equations
546+
trivial_idxs = BitSet()
547+
# equations to never check
548+
blacklist = BitSet()
549+
torn_eqs = Equation[]
550+
# variables that have been matched to trivially torn equations
551+
matched_vars = BitSet()
552+
# variable to index in fullvars
553+
var_to_idx = Dict{Any, Int}(ts.fullvars .=> eachindex(ts.fullvars))
554+
555+
complete!(ts.structure)
556+
var_to_diff = ts.structure.var_to_diff
557+
graph = ts.structure.graph
558+
while true
559+
# track whether we added an equation to the trivial list this iteration
560+
added_equation = false
561+
for (i, eq) in enumerate(ts.original_eqs)
562+
# don't check already torn equations
563+
i in trivial_idxs && continue
564+
i in blacklist && continue
565+
# ensure it is an observed equation matched to a variable in fullvars
566+
vari = get(var_to_idx, eq.lhs, 0)
567+
iszero(vari) && continue
568+
# don't tear irreducible variables
569+
if isirreducible(eq.lhs)
570+
push!(blacklist, i)
571+
continue
572+
end
573+
# if a variable was the LHS of two trivial observed equations, we wouldn't have
574+
# included it in the list. Error if somehow it made it through.
575+
@assert !(vari in matched_vars)
576+
# don't tear differential/shift equations (or differentiated/shifted variables)
577+
var_to_diff[vari] === nothing || continue
578+
invview(var_to_diff)[vari] === nothing || continue
579+
# get the equations that the candidate matched variable is present in, except
580+
# those equations which have already been torn as observed
581+
eqidxs = setdiff(𝑑neighbors(graph, vari), trivial_idxs)
582+
# it should only be present in this equation
583+
length(eqidxs) == 1 || continue
584+
eqi = only(eqidxs)
585+
@assert eqi == i
586+
587+
# for every variable present in this equation, make sure it isn't _only_
588+
# present in trivial equations
589+
isvalid = true
590+
for v in 𝑠neighbors(graph, eqi)
591+
v == vari && continue
592+
v in matched_vars && continue
593+
# `> 1` and not `0` because one entry will be this equation (`eqi`)
594+
isvalid &= count(!in(trivial_idxs), 𝑑neighbors(graph, v)) > 1
595+
isvalid || break
596+
end
597+
isvalid || continue
598+
# skip if the LHS is present in the RHS, since then this isn't explicit
599+
if occursin(eq.lhs, eq.rhs)
600+
push!(blacklist, i)
601+
continue
602+
end
603+
604+
added_equation = true
605+
push!(trivial_idxs, eqi)
606+
push!(torn_eqs, eq)
607+
push!(matched_vars, vari)
608+
end
609+
610+
# if we didn't add an equation this iteration, we won't add one next iteration
611+
added_equation || break
612+
end
613+
614+
deleteat!(var_to_diff.primal_to_diff, matched_vars)
615+
deleteat!(var_to_diff.diff_to_primal, matched_vars)
616+
deleteat!(ts.structure.eq_to_diff.primal_to_diff, trivial_idxs)
617+
deleteat!(ts.structure.eq_to_diff.diff_to_primal, trivial_idxs)
618+
delete_srcs!(ts.structure.graph, trivial_idxs; rm_verts = true)
619+
delete_dsts!(ts.structure.graph, matched_vars; rm_verts = true)
620+
if ts.structure.solvable_graph !== nothing
621+
delete_srcs!(ts.structure.solvable_graph, trivial_idxs; rm_verts = true)
622+
delete_dsts!(ts.structure.solvable_graph, matched_vars; rm_verts = true)
623+
end
624+
if ts.structure.var_types !== nothing
625+
deleteat!(ts.structure.var_types, matched_vars)
626+
end
627+
deleteat!(ts.fullvars, matched_vars)
628+
deleteat!(ts.original_eqs, trivial_idxs)
629+
ts.additional_observed = torn_eqs
630+
sys = ts.sys
631+
eqs = copy(get_eqs(sys))
632+
deleteat!(eqs, trivial_idxs)
633+
@set! sys.eqs = eqs
634+
ts.sys = sys
635+
return ts
636+
end
637+
524638
function lower_order_var(dervar, t)
525639
if isdifferential(dervar)
526640
diffvar = arguments(dervar)[1]
@@ -753,6 +867,7 @@ function _mtkcompile!(state::TearingState; simplify = false,
753867
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
754868
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
755869
end
870+
trivial_tearing!(state)
756871
sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...)
757872
if check_consistency
758873
fully_determined = ModelingToolkit.check_consistency(

0 commit comments

Comments
 (0)