Skip to content

Commit a082d8a

Browse files
Merge pull request #3626 from hersle/change_ivar_array
Change independent variable of ODEs with array variables
2 parents 29ae2a7 + 308c39c commit a082d8a

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

src/systems/diffeqs/basic_transformations.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,25 @@ function change_independent_variable(
153153
@set! sys.eqs = [get_eqs(sys); eqs] # add extra equations we derived
154154
@set! sys.unknowns = [get_unknowns(sys); [iv1, div2_of_iv1]] # add new variables, will be transformed to e.g. t(u) and uˍt(u)
155155

156+
# A utility function that returns whether var (e.g. f(t)) is a function of iv (e.g. t)
157+
function is_function_of(var, iv)
158+
# Peel off outer calls to find the argument of the function of
159+
if iscall(var) && operation(var) === getindex # handle array variables
160+
var = arguments(var)[1] # (f(t))[1] -> f(t)
161+
end
162+
if iscall(var)
163+
var = only(arguments(var)) # e.g. f(t) -> t
164+
return isequal(var, iv)
165+
end
166+
return false
167+
end
168+
156169
# Create a utility that performs the chain rule on an expression, followed by insertion of the new independent variable:
157170
# e.g. (d/dt)(f(t)) -> (d/dt)(f(u(t))) -> df(u(t))/du(t) * du(t)/dt -> df(u)/du * uˍt(u)
158171
function transform(ex::T) where {T}
159172
# 1) Replace the argument of every function; e.g. f(t) -> f(u(t))
160173
for var in vars(ex; op = Nothing) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable)
161-
is_function_of_iv1 = iscall(var) && isequal(only(arguments(var)), iv1) # of the form f(t)?
162-
if is_function_of_iv1 && !isequal(var, iv2_of_iv1) # prevent e.g. u(t) -> u(u(t))
174+
if is_function_of(var, iv1) && !isequal(var, iv2_of_iv1) # of the form f(t)? but prevent e.g. u(t) -> u(u(t))
163175
var_of_iv1 = var # e.g. f(t)
164176
var_of_iv2_of_iv1 = substitute(var_of_iv1, iv1 => iv2_of_iv1) # e.g. f(u(t))
165177
ex = substitute(ex, var_of_iv1 => var_of_iv2_of_iv1; fold)
@@ -207,15 +219,15 @@ function change_independent_variable(
207219
connector_type = get_connector_type(sys)
208220
assertions = Dict(transform(ass) => msg for (ass, msg) in get_assertions(sys))
209221
wascomplete = iscomplete(sys) # save before reconstructing system
222+
wassplit = is_split(sys)
223+
wasflat = isempty(systems)
210224
sys = typeof(sys)( # recreate system with transformed fields
211225
eqs, iv2, unknowns, ps; observed, initialization_eqs,
212226
parameter_dependencies, defaults, guesses, connector_type,
213227
assertions, name = nameof(sys), description = description(sys)
214228
)
215229
sys = compose(sys, systems) # rebuild hierarchical system
216230
if wascomplete
217-
wasflat = isempty(systems)
218-
wassplit = is_split(sys)
219231
sys = complete(sys; split = wassplit, flatten = wasflat) # complete output if input was complete
220232
end
221233
return sys

test/basic_transformations.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,20 @@ end
287287
@test all(isapprox.(sol[ss.t], sol[ss.y]; atol = 1e-10))
288288
@test all(sol[ss.x][2:end] .< sol[ss.x][1])
289289
end
290+
291+
@testset "Change independent variable with array variables" begin
292+
@variables x(t) y(t) z(t)[1:2]
293+
eqs = [
294+
D(x) ~ 2,
295+
z ~ ModelingToolkit.scalarize.([sin(y), cos(y)]),
296+
D(y) ~ z[1]^2 + z[2]^2
297+
]
298+
@named sys = ODESystem(eqs, t)
299+
sys = complete(sys)
300+
new_sys = change_independent_variable(sys, sys.x; add_old_diff = true)
301+
ss_new_sys = structural_simplify(new_sys; allow_symbolic = true)
302+
u0 = [new_sys.y => 0.5, new_sys.t => 0.0]
303+
prob = ODEProblem(ss_new_sys, u0, (0.0, 0.5), [])
304+
sol = solve(prob, Tsit5(); reltol = 1e-5)
305+
@test sol[new_sys.y][end] 0.75
306+
end

0 commit comments

Comments
 (0)