diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index dcae2c08f..f70fb3a77 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -213,6 +213,7 @@ function is_discrete_expression(indp, expr) length(ts_idxs) > 1 || length(ts_idxs) == 1 && only(ts_idxs) != ContinuousTimeseries() end +# These are the two main documented user-facing interpolation API functions (out-of-place and in-place versions) function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} if t isa IndexedClock @@ -225,9 +226,12 @@ function (sol::AbstractODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing if t isa IndexedClock t = canonicalize_indexed_clock(t, sol) end - sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) + sol(v, t, deriv, idxs, continuity) end +# Below are many internal dispatches for different combinations of arguments to the main API +# TODO: could use a clever rewrite, since a lot of reused code has accumulated + function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::Nothing, continuity) where {deriv} sol.interp(t, idxs, deriv, sol.prob.p, continuity) @@ -365,6 +369,43 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, return DiffEqArray(u, t, p, sol; discretes) end +function (sol::AbstractODESolution)( + v, t::Union{Number, AbstractVector{<:Number}}, ::Type{deriv}, + idxs::Union{Nothing, Integer, AbstractArray{<:Integer}}, continuity) where {deriv} + return sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) +end +function (sol::AbstractODESolution)( + v, t::Union{Number, AbstractVector{<:Number}}, ::Type{deriv}, idxs, + continuity) where {deriv} + if idxs isa AbstractArray && any(idx -> idx == NotSymbolic(), symbolic_type.(idxs)) || + !(idxs isa AbstractArray) && symbolic_type(idxs) == NotSymbolic() + error("Incorrect specification of `idxs`") + end + error_if_observed_derivative(sol, idxs, deriv) + p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + getter = getsym(sol, idxs) # TODO: breaks type inference and allocates + if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs) + u = zeros(eltype(sol), size(sol)[1]) + if t isa AbstractVector + for ti in eachindex(t) + sol.interp(u, t[ti], nothing, deriv, p, continuity) + state = ProblemState(; u = u, p = p, t = t[ti]) + if eltype(v) <: Number + v[ti] = getter(state) + else + v[ti] .= getter(state) + end + end + else # t isa Number + sol.interp(u, t, nothing, deriv, p, continuity) + state = ProblemState(; u = u, p = p, t = t) + v .= getter(state) + end + return v + end + error("In-place interpolation with discretes is not implemented.") +end + struct DDESolutionHistoryWrapper{T} sol::T end diff --git a/test/downstream/solution_interface.jl b/test/downstream/solution_interface.jl index b59424dd2..00790e40a 100644 --- a/test/downstream/solution_interface.jl +++ b/test/downstream/solution_interface.jl @@ -1,7 +1,7 @@ using ModelingToolkit, OrdinaryDiffEq, RecursiveArrayTools, StochasticDiffEq, Test using StochasticDiffEq using SymbolicIndexingInterface -using ModelingToolkit: t_nounits as t, D_nounits as D +using ModelingToolkit: observed, t_nounits as t, D_nounits as D using Plots: Plots, plot ### Tests on non-layered model (everything should work). ### @@ -148,6 +148,35 @@ sol9 = sol(0.0:1.0:10.0, idxs = 2) sol10 = sol(0.1, idxs = 2) @test sol10 isa Real +# in-place interpolation with single (unknown) symbolic index +ts = 0.0:0.1:10.0 +out = zeros(eltype(sol), size(ts)) +idxs = unknowns(sys)[1] +@test sol(out, ts; idxs) == sol(ts; idxs) +@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs)) +@test_nowarn @inferred sol(out, ts; idxs) + +# in-place interpolation with single (observed) symbolic index +idxs = observed(sys)[1].lhs +@test sol(out, ts; idxs) == sol(ts; idxs) +@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs)) +@test_nowarn @inferred sol(out, ts; idxs) + +# in-place interpolation with multiple (unknown+observed) symbolic indices +idxs = [unknowns(sys)[1], observed(sys)[1].lhs] +out = [zeros(eltype(sol), size(idxs)) for _ in eachindex(ts)] +@test sol(out, ts; idxs) == sol(ts; idxs).u +@test (@allocated sol(out, ts; idxs)) < (@allocated sol(ts; idxs)) +@test_nowarn @inferred sol(out, ts; idxs) + +# same as above, but with one time value only +@test sol(out[1], ts[1]; idxs) == sol(ts[1]; idxs) +#@test (@allocated sol(out[1], ts[1]; idxs)) < (@allocated sol(ts[1]; idxs)) # TODO: reduce allocations and fix +@test_nowarn @inferred sol(out[1], ts[1]; idxs) + +idxs = [unknowns(sys)[1], 1] +@test_throws "Incorrect specification of `idxs`" sol(out, ts; idxs) + @testset "Plot idxs" begin @variables x(t) y(t) @parameters p