Skip to content

Commit bc87875

Browse files
Merge pull request #3698 from AayushSabharwal/as/concrete-getu
[v9] hotfix: fix new implementation of `concrete_getu`
2 parents 8cc641f + e9b27f9 commit bc87875

File tree

7 files changed

+25
-18
lines changed

7 files changed

+25
-18
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: IntegrationTest
22
on:
33
push:
4-
branches: [master]
4+
branches: [master, 'backport-v9']
55
tags: [v*]
66
pull_request:
77
paths-ignore:

.github/workflows/ReleaseTest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: ReleaseTest
22
on:
33
push:
4-
branches: [master]
4+
branches: [master, 'backport-v9']
55
tags: [v*]
66
pull_request:
77
paths-ignore:

.github/workflows/Tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ on:
55
branches:
66
- master
77
- 'release-'
8+
- 'backport-v9'
89
paths-ignore:
910
- 'docs/**'
1011
push:
1112
branches:
1213
- master
14+
- 'backport-v9'
1315
paths-ignore:
1416
- 'docs/**'
1517

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,8 @@ Generates a function that computes the observed value(s) `ts` in the system `sys
476476
- `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist.
477477
- `mkarray`: only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
478478
- `cse = true`: Whether to use Common Subexpression Elimination (CSE) to generate a more efficient function.
479+
- `wrap_delays = is_dde(sys)`: Whether to add an argument for the history function and use
480+
it to calculate all delayed variables.
479481
480482
## Returns
481483
@@ -514,7 +516,8 @@ function build_explicit_observed_function(sys, ts;
514516
op = Operator,
515517
throw = true,
516518
cse = true,
517-
mkarray = nothing)
519+
mkarray = nothing,
520+
wrap_delays = is_dde(sys))
518521
is_tuple = ts isa Tuple
519522
if is_tuple
520523
ts = collect(ts)
@@ -600,14 +603,15 @@ function build_explicit_observed_function(sys, ts;
600603
p_end = length(dvs) + length(inputs) + length(ps)
601604
fns = build_function_wrapper(
602605
sys, ts, args...; p_start, p_end, filter_observed = obsfilter,
603-
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse)
606+
output_type, mkarray, try_namespaced = true, expression = Val{true}, cse,
607+
wrap_delays)
604608
if fns isa Tuple
605609
if expression
606610
return return_inplace ? fns : fns[1]
607611
end
608612
oop, iip = eval_or_rgf.(fns; eval_expression, eval_module)
609613
f = GeneratedFunctionWrapper{(
610-
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
614+
p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}(
611615
oop, iip)
612616
return return_inplace ? (f, f) : f
613617
else
@@ -616,7 +620,7 @@ function build_explicit_observed_function(sys, ts;
616620
end
617621
f = eval_or_rgf(fns; eval_expression, eval_module)
618622
f = GeneratedFunctionWrapper{(
619-
p_start + is_dde(sys), length(args) - length(ps) + 1 + is_dde(sys), is_split(sys))}(
623+
p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}(
620624
f, nothing)
621625
return f
622626
end

src/systems/problem_utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,9 @@ end
631631
ObservedWrapper{TD}(f::F) where {TD, F} = ObservedWrapper{TD, F}(f)
632632

633633
function (ow::ObservedWrapper{true})(prob)
634-
ow.f(state_values(prob), parameter_values(prob), current_time(prob))
634+
# Edge case for steady state problems
635+
t = applicable(current_time, prob) ? current_time(prob) : Inf
636+
ow.f(state_values(prob), parameter_values(prob), t)
635637
end
636638

637639
function (ow::ObservedWrapper{false})(prob)
@@ -649,7 +651,7 @@ function. It does NOT work for solutions.
649651
"""
650652
Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector)
651653
@nospecialize
652-
obsfn = SymbolicIndexingInterface.observed(indp, syms)
654+
obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false)
653655
return ObservedWrapper{is_time_dependent(indp)}(obsfn)
654656
end
655657

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ Keyword arguments:
10831083
`available_vars` will not be searched for in the observed equations.
10841084
"""
10851085
function observed_equations_used_by(sys::AbstractSystem, exprs;
1086-
involved_vars = vars(exprs; op = Union{Shift, Differential}), obs = observed(sys), available_vars = [])
1086+
involved_vars = vars(exprs; op = Union{Shift, Differential, Initial}), obs = observed(sys), available_vars = [])
10871087
obsvars = getproperty.(obs, :lhs)
10881088
graph = observed_dependency_graph(obs)
10891089
if !(available_vars isa Set)

test/odesystem.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@ using OrdinaryDiffEq, Sundials
55
using DiffEqBase, SparseArrays
66
using StaticArrays
77
using Test
8-
using SymbolicUtils: issym
8+
using SymbolicUtils.Code
9+
using SymbolicUtils: Sym, issym
910
using ForwardDiff
1011
using ModelingToolkit: value
1112
using ModelingToolkit: t_nounits as t, D_nounits as D
13+
using Symbolics
14+
using Symbolics: unwrap
15+
using DiffEqBase: isinplace
1216

1317
# Define some variables
1418
@parameters σ ρ β
@@ -607,13 +611,6 @@ sys = complete(sys)
607611
@test_throws Any ODEFunction(sys)
608612

609613
@testset "Preface tests" begin
610-
using OrdinaryDiffEq
611-
using Symbolics
612-
using DiffEqBase: isinplace
613-
using ModelingToolkit
614-
using SymbolicUtils.Code
615-
using SymbolicUtils: Sym
616-
617614
c = [0]
618615
function f(c, du::AbstractVector{Float64}, u::AbstractVector{Float64}, p, t::Float64)
619616
c .= [c[1] + 1]
@@ -656,7 +653,9 @@ sys = complete(sys)
656653

657654
@named sys = ODESystem(eqs, t, us, ps; defaults = defs, preface = preface)
658655
sys = complete(sys)
659-
prob = ODEProblem(sys, [], (0.0, 1.0))
656+
# don't build initializeprob because it will use preface in other functions and
657+
# affect `c`
658+
prob = ODEProblem(sys, [], (0.0, 1.0); build_initializeprob = false)
660659
sol = solve(prob, Euler(); dt = 0.1)
661660

662661
@test c[1] == length(sol)

0 commit comments

Comments
 (0)