diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index c66c562e9c..2ce1c7cffa 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -388,8 +388,8 @@ function IndexCache(sys::AbstractSystem) observed_syms_to_timeseries, dependent_pars_to_timeseries, disc_buffer_templates, - BufferTemplate(Real, tunable_buffer_size), - BufferTemplate(Real, initials_buffer_size), + BufferTemplate(Number, tunable_buffer_size), + BufferTemplate(Number, initials_buffer_size), const_buffer_sizes, nonnumeric_buffer_sizes, symbol_to_variable diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index f785de798f..5f5e63aa70 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -985,7 +985,7 @@ end $(TYPEDEF) A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`. -Returns the value to use for the `u0` of the problem. +Returns the value to use for the `u0` of the problem. # Fields @@ -1185,7 +1185,7 @@ function float_type_from_varmap(varmap, floatT = Bool) if v isa AbstractArray floatT = promote_type(floatT, eltype(v)) - elseif v isa Real + elseif v isa Number floatT = promote_type(floatT, typeof(v)) end end @@ -1451,7 +1451,7 @@ function check_inputmap_keys(sys, op) end const BAD_KEY_MESSAGE = """ - Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. + Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned. The following keys are invalid: """ diff --git a/test/complex.jl b/test/complex.jl index 69cc22c985..e30ebb177e 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -14,3 +14,31 @@ using Test end @named mixed = ComplexModel() @test length(equations(mixed)) == 2 + +@testset "Complex ODEProblem" begin + using ModelingToolkit: t_nounits as t, D_nounits as D + + vars = @variables x(t) y(t) z(t) + pars = @parameters a b + + eqs = [ + D(x) ~ y - x, + D(y) ~ -x * z + b * abs(z), + D(z) ~ x * y - a + ] + @named modlorenz = System(eqs, t) + sys = mtkcompile(modlorenz) + + ic = ModelingToolkit.get_index_cache(sys) + @test ic.tunable_buffer_size.type == Number + + u0 = ComplexF64[-4.0, 5.0, 0.0] .+ randn(ComplexF64, 3) + p = ComplexF64[5.0, 0.1] + dict = merge(Dict(unknowns(sys) .=> u0), Dict(parameters(sys) .=> p)) + prob = ODEProblem(sys, dict, (0.0, 1.0)) + + using OrdinaryDiffEq + sol = solve(prob, Tsit5(), saveat = 0.1) + + @test sol.u[1] isa Vector{ComplexF64} +end