Skip to content

Commit 06d5699

Browse files
authored
Merge branch 'SciML:master' into iss3707
2 parents b936c8b + a0ce384 commit 06d5699

17 files changed

+347
-101
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <[email protected]> and contributors"]
4-
version = "10.5.0"
4+
version = "10.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

benchmark/benchmarks.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ModelingToolkitStandardLibrary.Electrical
44
using ModelingToolkitStandardLibrary.Mechanical.Rotational
55
using ModelingToolkitStandardLibrary.Blocks
66
using OrdinaryDiffEqDefault
7+
using ModelingToolkit: t_nounits as t, D_nounits as D
78

89
const SUITE = BenchmarkGroup()
910

@@ -45,12 +46,33 @@ end
4546

4647
@named model = DCMotor()
4748

49+
# first call
50+
mtkcompile(model)
4851
SUITE["mtkcompile"] = @benchmarkable mtkcompile($model)
4952

5053
model = mtkcompile(model)
5154
u0 = unknowns(model) .=> 0.0
5255
tspan = (0.0, 6.0)
53-
SUITE["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)
5456

5557
prob = ODEProblem(model, u0, tspan)
58+
SUITE["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)
59+
60+
# first call
61+
init(prob)
5662
SUITE["init"] = @benchmarkable init($prob)
63+
64+
large_param_init = SUITE["large_parameter_init"] = BenchmarkGroup()
65+
66+
N = 25
67+
@variables x(t)[1:N]
68+
@parameters A[1:N, 1:N]
69+
70+
defval = collect(x) * collect(x)'
71+
@mtkcompile model = System(
72+
[D(x) ~ x], t, [x], [A]; defaults = [A => defval], guesses = [A => fill(NaN, N, N)])
73+
74+
u0 = [x => rand(N)]
75+
prob = ODEProblem(model, u0, tspan)
76+
large_param_init["ODEProblem"] = @benchmarkable ODEProblem($model, $u0, $tspan)
77+
78+
large_param_init["init"] = @benchmarkable init($prob)

src/bipartite_graph.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,13 +535,39 @@ function set_neighbors!(g::BipartiteGraph, i::Integer, new_neighbors)
535535
end
536536
end
537537

538-
function delete_srcs!(g::BipartiteGraph, srcs)
538+
function delete_srcs!(g::BipartiteGraph{I}, srcs; rm_verts = false) where {I}
539539
for s in srcs
540540
set_neighbors!(g, s, ())
541541
end
542+
if rm_verts
543+
old_to_new_idxs = collect(one(I):I(nsrcs(g)))
544+
for s in srcs
545+
old_to_new_idxs[s] = zero(I)
546+
end
547+
offset = zero(I)
548+
for i in eachindex(old_to_new_idxs)
549+
if iszero(old_to_new_idxs[i])
550+
offset += one(I)
551+
continue
552+
end
553+
old_to_new_idxs[i] -= offset
554+
end
555+
556+
if g.badjlist isa AbstractVector
557+
for i in 1:ndsts(g)
558+
for j in eachindex(g.badjlist[i])
559+
g.badjlist[i][j] = old_to_new_idxs[g.badjlist[i][j]]
560+
end
561+
filter!(!iszero, g.badjlist[i])
562+
end
563+
end
564+
deleteat!(g.fadjlist, srcs)
565+
end
542566
g
543567
end
544-
delete_dsts!(g::BipartiteGraph, srcs) = delete_srcs!(invview(g), srcs)
568+
function delete_dsts!(g::BipartiteGraph, srcs; rm_verts = false)
569+
delete_srcs!(invview(g), srcs; rm_verts)
570+
end
545571

546572
###
547573
### Edges iteration

src/problems/initializationproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const
3939
for k in keys(op)
4040
has_u0_ics |= is_variable(sys, k) || isdifferential(k) ||
4141
symbolic_type(k) == ArraySymbolic() &&
42-
is_sized_array_symbolic(k) && is_variable(sys, first(collect(k)))
42+
is_sized_array_symbolic(k) && is_variable(sys, unwrap(first(wrap(k))))
4343
end
4444
if !has_u0_ics && get_initializesystem(sys) !== nothing
4545
isys = get_initializesystem(sys; initialization_eqs, check_units)
@@ -79,7 +79,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const
7979
@warn errmsg
8080
end
8181

82-
uninit = setdiff(unknowns(sys), [unknowns(isys); observables(isys)])
82+
uninit = setdiff(unknowns(sys), unknowns(isys), observables(isys))
8383

8484
# TODO: throw on uninitialized arrays
8585
filter!(x -> !(x isa Symbolics.Arr), uninit)

src/structural_transformation/symbolics_tearing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,8 @@ function update_simplified_system!(
960960
obs_sub[eq.lhs] = eq.rhs
961961
end
962962
# TODO: compute the dependency correctly so that we don't have to do this
963-
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs]
963+
obs = [fast_substitute(observed(sys), obs_sub); solved_eqs;
964+
fast_substitute(state.additional_observed, obs_sub)]
964965

965966
unknown_idxs = filter(
966967
i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars))

src/systems/abstractsystem.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,23 @@ has_equations(::AbstractSystem) = true
805805
806806
Invalidate cached jacobians, etc.
807807
"""
808-
invalidate_cache!(sys::AbstractSystem) = sys
808+
function invalidate_cache!(sys::AbstractSystem)
809+
has_metadata(sys) || return sys
810+
empty!(getmetadata(sys, MutableCacheKey, nothing))
811+
return sys
812+
end
813+
814+
# `::MetadataT` but that is defined later
815+
function refreshed_metadata(meta::Base.ImmutableDict)
816+
newmeta = MetadataT()
817+
for (k, v) in meta
818+
if k === MutableCacheKey
819+
v = MutableCacheT()
820+
end
821+
newmeta = Base.ImmutableDict(newmeta, k => v)
822+
end
823+
return newmeta
824+
end
809825

810826
function Setfield.get(obj::AbstractSystem, ::Setfield.PropertyLens{field}) where {field}
811827
getfield(obj, field)
@@ -815,6 +831,8 @@ end
815831
args = map(fieldnames(obj)) do fn
816832
if fn in fieldnames(patch)
817833
:(patch.$fn)
834+
elseif fn == :metadata
835+
:($refreshed_metadata(getfield(obj, $(Meta.quot(fn)))))
818836
else
819837
:(getfield(obj, $(Meta.quot(fn))))
820838
end
@@ -2507,7 +2525,15 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
25072525
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
25082526
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
25092527
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
2510-
meta = merge(get_metadata(basesys), get_metadata(sys))
2528+
meta = MetadataT()
2529+
for kvp in get_metadata(basesys)
2530+
kvp[1] == MutableCacheKey && continue
2531+
meta = Base.ImmutableDict(meta, kvp)
2532+
end
2533+
for kvp in get_metadata(sys)
2534+
kvp[1] == MutableCacheKey && continue
2535+
meta = Base.ImmutableDict(meta, kvp)
2536+
end
25112537
syss = union(get_systems(basesys), get_systems(sys))
25122538
args = length(ivs) == 0 ? (eqs, sts, ps) : (eqs, ivs[1], sts, ps)
25132539
kwargs = (observed = obs, continuous_events = cevs,
@@ -2705,7 +2731,9 @@ function process_parameter_equations(sys::AbstractSystem)
27052731
is_sized_array_symbolic(sym) &&
27062732
all(Base.Fix1(is_parameter, sys), collect(sym))
27072733
end
2708-
if !isparameter(eq.lhs)
2734+
# Everything in `varsbuf` is a parameter, so this is a cheap `is_parameter`
2735+
# check.
2736+
if !(eq.lhs in varsbuf)
27092737
throw(ArgumentError("""
27102738
LHS of parameter dependency equation must be a single parameter. Found \
27112739
$(eq.lhs).

src/systems/connectors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,8 @@ function expand_connections(sys::AbstractSystem; tol = 1e-10)
872872
eqs = [equations(sys); ceqs; stream_eqs]
873873
# substitute `instream(..)` expressions with their new values
874874
for i in eachindex(eqs)
875-
eqs[i] = fixpoint_sub(eqs[i], instream_subs; maxiters = length(instream_subs))
875+
eqs[i] = fixpoint_sub(
876+
eqs[i], instream_subs; maxiters = max(length(instream_subs), 10))
876877
end
877878
# get the defaults for domain networks
878879
d_defs = domain_defaults(sys, domain_csets)

0 commit comments

Comments
 (0)