Skip to content

Commit fafb47b

Browse files
Merge pull request #3620 from aml5600/andrew/fix-nested-cov
fix change of var for sys with no equations
2 parents c0a61ab + 834154c commit fafb47b

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/systems/diffeqs/basic_transformations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ function change_independent_variable(
155155

156156
# Create a utility that performs the chain rule on an expression, followed by insertion of the new independent variable:
157157
# e.g. (d/dt)(f(t)) -> (d/dt)(f(u(t))) -> df(u(t))/du(t) * du(t)/dt -> df(u)/du * uˍt(u)
158-
function transform(ex)
158+
function transform(ex::T) where {T}
159159
# 1) Replace the argument of every function; e.g. f(t) -> f(u(t))
160160
for var in vars(ex; op = Nothing) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable)
161161
is_function_of_iv1 = iscall(var) && isequal(only(arguments(var)), iv1) # of the form f(t)?
@@ -175,7 +175,7 @@ function change_independent_variable(
175175
# 3) Set new independent variable
176176
ex = substitute(ex, iv2_of_iv1 => iv2; fold) # set e.g. u(t) -> u everywhere
177177
ex = substitute(ex, iv1 => iv1_of_iv2; fold) # set e.g. t -> t(u) everywhere
178-
return ex
178+
return ex::T
179179
end
180180

181181
# Use the utility function to transform everything in the system!

test/basic_transformations.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,31 @@ end
231231
# compare to analytical solution (x(t) = v*t, y(t) = v*t - g*t^2/2)
232232
@test all(isapprox.(sol[Mx.y], sol[Mx.x - g * (Mx.t_units)^2 / 2]; atol = 1e-10))
233233
end
234+
235+
@testset "Change independent variable, no equations" begin
236+
# make this "look" like the standard library RealInput
237+
@mtkmodel Input begin
238+
@variables begin
239+
u(t)
240+
end
241+
end
242+
@named input_sys = Input()
243+
input_sys = complete(input_sys)
244+
# test no failures
245+
@test change_independent_variable(input_sys, input_sys.u) isa ODESystem
246+
247+
@mtkmodel NestedInput begin
248+
@components begin
249+
in = Input()
250+
end
251+
@variables begin
252+
x(t)
253+
end
254+
@equations begin
255+
D(x) ~ in.u
256+
end
257+
end
258+
@named nested_input_sys = NestedInput()
259+
nested_input_sys = complete(nested_input_sys; flatten = false)
260+
@test change_independent_variable(nested_input_sys, nested_input_sys.x) isa ODESystem
261+
end

0 commit comments

Comments
 (0)