Skip to content

Commit 02ae459

Browse files
Merge pull request #3764 from oameye/complex
fix: enable support for complex ODEProblem again
2 parents b892604 + d90ad90 commit 02ae459

File tree

3 files changed

+31
-3
lines changed

3 files changed

+31
-3
lines changed

src/systems/index_cache.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ function IndexCache(sys::AbstractSystem)
388388
observed_syms_to_timeseries,
389389
dependent_pars_to_timeseries,
390390
disc_buffer_templates,
391-
BufferTemplate(Real, tunable_buffer_size),
392-
BufferTemplate(Real, initials_buffer_size),
391+
BufferTemplate(Number, tunable_buffer_size),
392+
BufferTemplate(Number, initials_buffer_size),
393393
const_buffer_sizes,
394394
nonnumeric_buffer_sizes,
395395
symbol_to_variable

src/systems/problem_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1185,7 +1185,7 @@ function float_type_from_varmap(varmap, floatT = Bool)
11851185

11861186
if v isa AbstractArray
11871187
floatT = promote_type(floatT, eltype(v))
1188-
elseif v isa Real
1188+
elseif v isa Number
11891189
floatT = promote_type(floatT, typeof(v))
11901190
end
11911191
end

test/complex.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,31 @@ using Test
1414
end
1515
@named mixed = ComplexModel()
1616
@test length(equations(mixed)) == 2
17+
18+
@testset "Complex ODEProblem" begin
19+
using ModelingToolkit: t_nounits as t, D_nounits as D
20+
21+
vars = @variables x(t) y(t) z(t)
22+
pars = @parameters a b
23+
24+
eqs = [
25+
D(x) ~ y - x,
26+
D(y) ~ -x * z + b * abs(z),
27+
D(z) ~ x * y - a
28+
]
29+
@named modlorenz = System(eqs, t)
30+
sys = mtkcompile(modlorenz)
31+
32+
ic = ModelingToolkit.get_index_cache(sys)
33+
@test ic.tunable_buffer_size.type == Number
34+
35+
u0 = ComplexF64[-4.0, 5.0, 0.0] .+ randn(ComplexF64, 3)
36+
p = ComplexF64[5.0, 0.1]
37+
dict = merge(Dict(unknowns(sys) .=> u0), Dict(parameters(sys) .=> p))
38+
prob = ODEProblem(sys, dict, (0.0, 1.0))
39+
40+
using OrdinaryDiffEq
41+
sol = solve(prob, Tsit5(), saveat = 0.1)
42+
43+
@test sol.u[1] isa Vector{ComplexF64}
44+
end

0 commit comments

Comments
 (0)