Skip to content

Commit 9923652

Browse files
refactor: remove FunctionalAffect
1 parent e907fb9 commit 9923652

File tree

4 files changed

+20
-109
lines changed

4 files changed

+20
-109
lines changed

ext/MTKFMIExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
235235

236236
# instance management callback which deallocates the instance when
237237
# necessary and notifies the FMU of completed integrator steps
238-
finalize_affect = MTK.FunctionalAffect(fmiFinalize!, [], [wrapper], [])
239-
step_affect = MTK.FunctionalAffect(Returns(nothing), [], [], [])
238+
finalize_affect = MTK.ImperativeAffect(fmiFinalize!; observed = (; wrapper))
239+
step_affect = MTK.ImperativeAffect(Returns((;)))
240240
instance_management_callback = MTK.SymbolicDiscreteCallback(
241241
(t == t - 1), step_affect; finalize = finalize_affect, reinitializealg = SciMLBase.NoInit())
242242

@@ -273,7 +273,7 @@ function MTK.FMIComponent(::Val{Ver}; fmu = nothing, tolerance = 1e-6,
273273
end
274274
initialize_affect = MTK.ImperativeAffect(fmiCSInitialize!; observed = cb_observed,
275275
modified = cb_modified, ctx = _functor)
276-
finalize_affect = MTK.FunctionalAffect(fmiFinalize!, [], [wrapper], [])
276+
finalize_affect = MTK.ImperativeAffect(fmiFinalize!; observed = (; wrapper))
277277
# the callback affect performs the stepping
278278
step_affect = MTK.ImperativeAffect(
279279
fmiCSStep!; observed = cb_observed, modified = cb_modified, ctx = _functor)
@@ -708,15 +708,15 @@ end
708708
"""
709709
$(TYPEDSIGNATURES)
710710
711-
An affect function for use inside a `FunctionalAffect`. This should be triggered at the
711+
An affect function for use inside an `ImperativeAffect`. This should be triggered at the
712712
end of the solve, regardless of whether it succeeded or failed. Expects `p` to be a
713713
1-length array containing the index of the instance wrapper (`FMI2InstanceWrapper` or
714714
`FMI3InstanceWrapper`) in the parameter object.
715715
"""
716-
function fmiFinalize!(integrator, u, p, ctx)
717-
wrapper_idx = p[1]
718-
wrapper = integrator.ps[wrapper_idx]
716+
function fmiFinalize!(m, o, ctx, integrator)
717+
wrapper = o.wrapper
719718
reset_instance!(wrapper)
719+
return (;)
720720
end
721721

722722
"""

src/systems/callbacks.jl

Lines changed: 8 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,7 @@
11
abstract type AbstractCallback end
22

3-
struct FunctionalAffect
4-
f::Any
5-
sts::Vector
6-
sts_syms::Vector{Symbol}
7-
pars::Vector
8-
pars_syms::Vector{Symbol}
9-
discretes::Vector
10-
ctx::Any
11-
end
12-
13-
function FunctionalAffect(f, sts, pars, discretes, ctx = nothing)
14-
# sts & pars contain either pairs: resistor.R => R, or Syms: R
15-
vs = [x isa Pair ? x.first : x for x in sts]
16-
vs_syms = Symbol[x isa Pair ? Symbol(x.second) : getname(x) for x in sts]
17-
length(vs_syms) == length(unique(vs_syms)) || error("Variables are not unique")
18-
19-
ps = [x isa Pair ? x.first : x for x in pars]
20-
ps_syms = Symbol[x isa Pair ? Symbol(x.second) : getname(x) for x in pars]
21-
length(ps_syms) == length(unique(ps_syms)) || error("Parameters are not unique")
22-
23-
FunctionalAffect(f, vs, vs_syms, ps, ps_syms, discretes, ctx)
24-
end
25-
26-
function FunctionalAffect(; f, sts, pars, discretes, ctx = nothing)
27-
FunctionalAffect(f, sts, pars, discretes, ctx)
28-
end
29-
30-
func(a::FunctionalAffect) = a.f
31-
context(a::FunctionalAffect) = a.ctx
32-
parameters(a::FunctionalAffect) = a.pars
33-
parameters_syms(a::FunctionalAffect) = a.pars_syms
34-
unknowns(a::FunctionalAffect) = a.sts
35-
unknowns_syms(a::FunctionalAffect) = a.sts_syms
36-
discretes(a::FunctionalAffect) = a.discretes
37-
38-
function Base.:(==)(a1::FunctionalAffect, a2::FunctionalAffect)
39-
isequal(a1.f, a2.f) && isequal(a1.sts, a2.sts) && isequal(a1.pars, a2.pars) &&
40-
isequal(a1.sts_syms, a2.sts_syms) && isequal(a1.pars_syms, a2.pars_syms) &&
41-
isequal(a1.ctx, a2.ctx)
42-
end
43-
44-
function Base.hash(a::FunctionalAffect, s::UInt)
45-
s = hash(a.f, s)
46-
s = hash(a.sts, s)
47-
s = hash(a.sts_syms, s)
48-
s = hash(a.pars, s)
49-
s = hash(a.pars_syms, s)
50-
s = hash(a.discretes, s)
51-
hash(a.ctx, s)
52-
end
53-
543
function has_functional_affect(cb)
55-
(affects(cb) isa FunctionalAffect || affects(cb) isa ImperativeAffect)
4+
affects(cb) isa ImperativeAffect
565
end
576

587
struct AffectSystem
@@ -97,7 +46,7 @@ function Base.hash(a::AffectSystem, s::UInt)
9746
hash(aff_to_sys(a), s)
9847
end
9948

100-
function vars!(vars, aff::Union{FunctionalAffect, AffectSystem}; op = Differential)
49+
function vars!(vars, aff::AffectSystem; op = Differential)
10150
for var in Iterators.flatten((unknowns(aff), parameters(aff), discretes(aff)))
10251
vars!(vars, var)
10352
end
@@ -161,7 +110,7 @@ end
161110
###############################
162111
###### Continuous events ######
163112
###############################
164-
const Affect = Union{AffectSystem, FunctionalAffect, ImperativeAffect}
113+
const Affect = Union{AffectSystem, ImperativeAffect}
165114

166115
"""
167116
SymbolicContinuousCallback(eqs::Vector{Equation}, affect = nothing, iv = nothing;
@@ -233,7 +182,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
233182
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
234183

235184
if isnothing(reinitializealg)
236-
if any(a -> (a isa FunctionalAffect || a isa ImperativeAffect),
185+
if any(a -> a isa ImperativeAffect,
237186
[affect, affect_neg, initialize, finalize])
238187
reinitializealg = SciMLBase.CheckInit()
239188
else
@@ -263,8 +212,8 @@ function SymbolicContinuousCallback(cb::Tuple, args...; kwargs...)
263212
end
264213

265214
make_affect(affect::Nothing; kwargs...) = nothing
266-
make_affect(affect::Tuple; kwargs...) = FunctionalAffect(affect...)
267-
make_affect(affect::NamedTuple; kwargs...) = FunctionalAffect(; affect...)
215+
make_affect(affect::Tuple; kwargs...) = ImperativeAffect(affect...)
216+
make_affect(affect::NamedTuple; kwargs...) = ImperativeAffect(; affect...)
268217
make_affect(affect::Affect; kwargs...) = affect
269218

270219
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
@@ -446,7 +395,7 @@ struct SymbolicDiscreteCallback <: AbstractCallback
446395
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
447396

448397
if isnothing(reinitializealg)
449-
if any(a -> (a isa FunctionalAffect || a isa ImperativeAffect),
398+
if any(a -> a isa ImperativeAffect,
450399
[affect, initialize, finalize])
451400
reinitializealg = SciMLBase.CheckInit()
452401
else
@@ -498,16 +447,6 @@ end
498447
############################################
499448
########## Namespacing Utilities ###########
500449
############################################
501-
function namespace_affects(affect::FunctionalAffect, s)
502-
FunctionalAffect(func(affect),
503-
renamespace.((s,), unknowns(affect)),
504-
unknowns_syms(affect),
505-
renamespace.((s,), parameters(affect)),
506-
parameters_syms(affect),
507-
renamespace.((s,), discretes(affect)),
508-
context(affect))
509-
end
510-
511450
function namespace_affects(affect::AffectSystem, s)
512451
AffectSystem(renamespace(s, system(affect)),
513452
renamespace.((s,), unknowns(affect)),
@@ -652,36 +591,6 @@ function compile_condition(
652591
return CompiledCondition{is_discrete(cbs)}(fs)
653592
end
654593

655-
"""
656-
Compile user-defined functional affect.
657-
"""
658-
function compile_functional_affect(affect::FunctionalAffect, sys; kwargs...)
659-
dvs = unknowns(sys)
660-
ps = parameters(sys)
661-
dvs_ind = Dict(reverse(en) for en in enumerate(dvs))
662-
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))
663-
664-
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
665-
p_inds = [(pind = parameter_index(sys, sym)) === nothing ? sym : pind
666-
for sym in parameters(affect)]
667-
else
668-
ps_ind = Dict(reverse(en) for en in enumerate(ps))
669-
p_inds = map(sym -> get(ps_ind, sym, sym), parameters(affect))
670-
end
671-
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
672-
# (MTK should keep these symbols)
673-
u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |>
674-
NamedTuple
675-
p = filter(x -> !isnothing(x[2]), collect(zip(parameters_syms(affect), p_inds))) |>
676-
NamedTuple
677-
678-
let u = u, p = p, user_affect = func(affect), ctx = context(affect)
679-
(integ) -> begin
680-
user_affect(integ, u, p, ctx)
681-
end
682-
end
683-
end
684-
685594
is_discrete(cb::AbstractCallback) = cb isa SymbolicDiscreteCallback
686595
is_discrete(cb::Vector{<:AbstractCallback}) = eltype(cb) isa SymbolicDiscreteCallback
687596

@@ -837,7 +746,7 @@ function compile_affect(
837746
elseif aff isa AffectSystem
838747
f = compile_equational_affect(aff, sys; kwargs...)
839748
wrap_save_discretes(f, save_idxs)
840-
elseif aff isa FunctionalAffect || aff isa ImperativeAffect
749+
elseif aff isa ImperativeAffect
841750
f = compile_functional_affect(aff, sys; kwargs...)
842751
wrap_save_discretes(f, save_idxs)
843752
end

src/systems/imperative_affect.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Where we use Setfield to copy the tuple `m` with a new value for `x`, then retur
2828
`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it
2929
in the returned tuple, in which case the associated field will not be updated.
3030
"""
31-
@kwdef struct ImperativeAffect
31+
struct ImperativeAffect
3232
f::Any
3333
obs::Vector
3434
obs_syms::Vector{Symbol}
@@ -63,6 +63,9 @@ function ImperativeAffect(
6363
ImperativeAffect(
6464
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
6565
end
66+
function ImperativeAffect(; f, kwargs...)
67+
ImperativeAffect(f; kwargs...)
68+
end
6669

6770
function Base.show(io::IO, mfa::ImperativeAffect)
6871
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")

src/systems/index_cache.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,7 @@ function IndexCache(sys::AbstractSystem)
117117
affs = [affs]
118118
end
119119
for affect in affs
120-
if affect isa AffectSystem || affect isa FunctionalAffect ||
121-
affect isa ImperativeAffect
120+
if affect isa AffectSystem || affect isa ImperativeAffect
122121
union!(discs, unwrap.(discretes(affect)))
123122
elseif isnothing(affect)
124123
continue

0 commit comments

Comments
 (0)