Skip to content

Commit f215ec0

Browse files
fix: fix infinite recursion in full_equations
1 parent 2035e73 commit f215ec0

File tree

3 files changed

+35
-18
lines changed

3 files changed

+35
-18
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,8 @@ function update_simplified_system!(
687687
unknowns = [unknowns; extra_unknowns]
688688
@set! sys.unknowns = unknowns
689689

690-
obs, subeqs, deps = cse_and_array_hacks(
691-
sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
690+
obs = cse_and_array_hacks(
691+
sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
692692

693693
@set! sys.eqs = neweqs
694694
@set! sys.observed = obs
@@ -790,7 +790,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
790790
not) we first count the number of times the scalarized form of each observed variable
791791
occurs in observed equations (and unknowns if it's split).
792792
"""
793-
function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, array = true)
793+
function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = true)
794794
# HACK 1
795795
# mapping of rhs to temporary CSE variable
796796
# `f(...) => tmpvar` in above example
@@ -818,7 +818,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
818818
tempeq = tempvar ~ rhs_arr
819819
rhs_to_tempvar[rhs_arr] = tempvar
820820
push!(obs, tempeq)
821-
push!(subeqs, tempeq)
822821
end
823822

824823
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
@@ -827,10 +826,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
827826
neweq = lhs ~ getindex_wrapper(
828827
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
829828
obs[i] = neweq
830-
subeqi = findfirst(isequal(eq), subeqs)
831-
if subeqi !== nothing
832-
subeqs[subeqi] = neweq
833-
end
834829
end
835830
# end HACK 1
836831

@@ -860,7 +855,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
860855
tempeq = tempvar ~ rhs_arr
861856
rhs_to_tempvar[rhs_arr] = tempvar
862857
push!(obs, tempeq)
863-
push!(subeqs, tempeq)
864858
end
865859
# don't need getindex_wrapper, but do it anyway to know that this
866860
# hack took place
@@ -900,15 +894,8 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
900894
push!(obs_arr_eqs, arrvar ~ rhs)
901895
end
902896
append!(obs, obs_arr_eqs)
903-
append!(subeqs, obs_arr_eqs)
904-
905-
# need to re-sort subeqs
906-
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
907-
908-
deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
909-
for i in 1:length(subeqs)]
910897

911-
return obs, subeqs, deps
898+
return obs
912899
end
913900

914901
function is_getindexed_array(rhs)

src/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,8 @@ Get a dictionary mapping variables eliminated from the system during `mtkcompile
700700
expressions used to calculate them.
701701
"""
702702
function get_substitutions(sys)
703-
Dict([eq.lhs => eq.rhs for eq in observed(sys)])
703+
obs, _ = unhack_observed(observed(sys), equations(sys))
704+
Dict([eq.lhs => eq.rhs for eq in obs])
704705
end
705706

706707
@noinline function throw_missingvars_in_sys(vars)

test/odesystem.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,3 +1560,32 @@ end
15601560
@mtkcompile sys = SysC()
15611561
@test length(unknowns(sys)) == 3
15621562
end
1563+
1564+
@testset "`full_equations` doesn't recurse infinitely" begin
1565+
code = """
1566+
using ModelingToolkit
1567+
using ModelingToolkit: t_nounits as t, D_nounits as D
1568+
@variables x(t)[1:3]=[0,0,1]
1569+
@variables u1(t)=0 u2(t)=0
1570+
y₁, y₂, y₃ = x
1571+
k₁, k₂, k₃ = 1,1,1
1572+
eqs = [
1573+
D(y₁) ~ -k₁*y₁ + k₃*y₂*y₃ + u1
1574+
D(y₂) ~ k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2 + u2
1575+
y₁ + y₂ + y₃ ~ 1
1576+
]
1577+
1578+
@named sys = System(eqs, t)
1579+
1580+
inputs = [u1, u2]
1581+
outputs = [y₁, y₂, y₃]
1582+
ss = mtkcompile(sys; inputs)
1583+
full_equations(ss)
1584+
"""
1585+
1586+
cmd = `$(Base.julia_cmd()) --project=$(@__DIR__) -e $code`
1587+
proc = run(cmd, stdin, stdout, stderr; wait = false)
1588+
sleep(120)
1589+
@test !process_running(proc)
1590+
kill(proc, Base.SIGKILL)
1591+
end

0 commit comments

Comments
 (0)