Skip to content

Commit 8216332

Browse files
feat: preemptively tear some trivial equations in mtkcompile
1 parent 39c62de commit 8216332

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,8 @@ function update_simplified_system!(
740740
obs_sub[eq.lhs] = eq.rhs
741741
end
742742
# TODO: compute the dependency correctly so that we don't have to do this
743-
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
743+
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs;
744+
fast_substitute(state.additional_observed, obs_sub)]
744745

745746
unknowns = Any[v
746747
for (i, v) in enumerate(state.fullvars)

src/systems/systemstructure.jl

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,19 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
208208
structure::SystemStructure
209209
extra_eqs::Vector
210210
param_derivative_map::Dict{BasicSymbolic, Any}
211+
original_eqs::Vector{Equation}
212+
"""
213+
Additional user-provided observed equations. The variables calculated here
214+
are not used in the rest of the system.
215+
"""
216+
additional_observed::Vector{Equation}
211217
end
212218

213219
TransformationState(sys::AbstractSystem) = TearingState(sys)
214220
function system_subset(ts::TearingState, ieqs::Vector{Int})
215221
eqs = equations(ts)
216222
@set! ts.sys.eqs = eqs[ieqs]
223+
@set! ts.original_eqs = ts.original_eqs[ieqs]
217224
@set! ts.structure = system_subset(ts.structure, ieqs)
218225
ts
219226
end
@@ -266,6 +273,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
266273
iv = length(ivs) == 1 ? ivs[1] : nothing
267274
# scalarize array equations, without scalarizing arguments to registered functions
268275
eqs = flatten_equations(copy(equations(sys)))
276+
original_eqs = copy(eqs)
269277
neqs = length(eqs)
270278
dervaridxs = OrderedSet{Int}()
271279
var2idx = Dict{Any, Int}()
@@ -378,6 +386,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
378386
end
379387
end
380388
eqs = eqs[eqs_to_retain]
389+
original_eqs = original_eqs[eqs_to_retain]
381390
neqs = length(eqs)
382391
symbolic_incidence = symbolic_incidence[eqs_to_retain]
383392

@@ -386,6 +395,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
386395
# depending on order due to NP-completeness of tearing.
387396
sortidxs = Base.sortperm(eqs, by = string)
388397
eqs = eqs[sortidxs]
398+
original_eqs = original_eqs[sortidxs]
389399
symbolic_incidence = symbolic_incidence[sortidxs]
390400
end
391401

@@ -475,13 +485,116 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
475485
ts = TearingState(sys, fullvars,
476486
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
477487
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
478-
Any[], param_derivative_map)
488+
Any[], param_derivative_map, original_eqs, Equation[])
479489
if sys isa DiscreteSystem
480490
ts = shift_discrete_system(ts)
481491
end
482492
return ts
483493
end
484494

495+
"""
496+
$(TYPEDSIGNATURES)
497+
498+
Preemptively identify observed equations in the system and tear them. This happens before
499+
any simplification. The equations torn by this process are ones that are already given in
500+
an explicit form in the system and where the LHS is not present in any other equation of
501+
the system except for other such preempitvely torn equations.
502+
"""
503+
function trivial_tearing!(ts::TearingState)
504+
@assert length(ts.original_eqs) == length(equations(ts))
505+
# equations that can be trivially torn an observed equations
506+
trivial_idxs = BitSet()
507+
# equations to never check
508+
blacklist = BitSet()
509+
torn_eqs = Equation[]
510+
# variables that have been matched to trivially torn equations
511+
matched_vars = BitSet()
512+
# variable to index in fullvars
513+
var_to_idx = Dict{Any, Int}(ts.fullvars .=> eachindex(ts.fullvars))
514+
515+
complete!(ts.structure)
516+
var_to_diff = ts.structure.var_to_diff
517+
graph = ts.structure.graph
518+
while true
519+
# track whether we added an equation to the trivial list this iteration
520+
added_equation = false
521+
for (i, eq) in enumerate(ts.original_eqs)
522+
# don't check already torn equations
523+
i in trivial_idxs && continue
524+
i in blacklist && continue
525+
# ensure it is an observed equation matched to a variable in fullvars
526+
vari = get(var_to_idx, eq.lhs, 0)
527+
iszero(vari) && continue
528+
# don't tear irreducible variables
529+
if isirreducible(eq.lhs)
530+
push!(blacklist, i)
531+
continue
532+
end
533+
# if a variable was the LHS of two trivial observed equations, we wouldn't have
534+
# included it in the list. Error if somehow it made it through.
535+
@assert !(vari in matched_vars)
536+
# don't tear differential/shift equations (or differentiated/shifted variables)
537+
var_to_diff[vari] === nothing || continue
538+
invview(var_to_diff)[vari] === nothing || continue
539+
# get the equations that the candidate matched variable is present in, except
540+
# those equations which have already been torn as observed
541+
eqidxs = setdiff(𝑑neighbors(graph, vari), trivial_idxs)
542+
# it should only be present in this equation
543+
length(eqidxs) == 1 || continue
544+
eqi = only(eqidxs)
545+
@assert eqi == i
546+
547+
# for every variable present in this equation, make sure it isn't _only_
548+
# present in trivial equations
549+
isvalid = true
550+
for v in 𝑠neighbors(graph, eqi)
551+
v == vari && continue
552+
v in matched_vars && continue
553+
# `> 1` and not `0` because one entry will be this equation (`eqi`)
554+
isvalid &= count(!in(trivial_idxs), 𝑑neighbors(graph, v)) > 1
555+
isvalid || break
556+
end
557+
isvalid || continue
558+
# skip if the LHS is present in the RHS, since then this isn't explicit
559+
if occursin(eq.lhs, eq.rhs)
560+
push!(blacklist, i)
561+
continue
562+
end
563+
564+
added_equation = true
565+
push!(trivial_idxs, eqi)
566+
push!(torn_eqs, eq)
567+
push!(matched_vars, vari)
568+
end
569+
570+
# if we didn't add an equation this iteration, we won't add one next iteration
571+
added_equation || break
572+
end
573+
574+
deleteat!(var_to_diff.primal_to_diff, matched_vars)
575+
deleteat!(var_to_diff.diff_to_primal, matched_vars)
576+
deleteat!(ts.structure.eq_to_diff.primal_to_diff, trivial_idxs)
577+
deleteat!(ts.structure.eq_to_diff.diff_to_primal, trivial_idxs)
578+
delete_srcs!(ts.structure.graph, trivial_idxs; rm_verts = true)
579+
delete_dsts!(ts.structure.graph, matched_vars; rm_verts = true)
580+
if ts.structure.solvable_graph !== nothing
581+
delete_srcs!(ts.structure.solvable_graph, trivial_idxs; rm_verts = true)
582+
delete_dsts!(ts.structure.solvable_graph, matched_vars; rm_verts = true)
583+
end
584+
if ts.structure.var_types !== nothing
585+
deleteat!(ts.structure.var_types, matched_vars)
586+
end
587+
deleteat!(ts.fullvars, matched_vars)
588+
deleteat!(ts.original_eqs, trivial_idxs)
589+
ts.additional_observed = torn_eqs
590+
sys = ts.sys
591+
eqs = copy(get_eqs(sys))
592+
deleteat!(eqs, trivial_idxs)
593+
@set! sys.eqs = eqs
594+
ts.sys = sys
595+
return ts
596+
end
597+
485598
function lower_order_var(dervar, t)
486599
if isdifferential(dervar)
487600
diffvar = arguments(dervar)[1]
@@ -739,6 +852,7 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
739852
else
740853
input_idxs = 0:-1 # Empty range
741854
end
855+
trivial_tearing!(state)
742856
sys, mm = ModelingToolkit.alias_elimination!(state; kwargs...)
743857
if check_consistency
744858
fully_determined = ModelingToolkit.check_consistency(

0 commit comments

Comments
 (0)