Skip to content

Commit cedf59c

Browse files
feat: use LinearProblem for linear SCCs in SCCNonlinearProblem
1 parent 4940dbc commit cedf59c

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

src/problems/sccnonlinearproblem.jl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010

1111
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
1212
exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
13-
eval_expression = false, eval_module = @__MODULE__, cse = true)
13+
eval_expression = false, eval_module = @__MODULE__, cse = true, sparse = false)
1414
ps = parameters(sys; initial_parameters = true)
1515
rps = reorder_parameters(sys, ps)
1616
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
@@ -39,9 +39,22 @@ end
3939
struct SCCNonlinearFunction{iip} end
4040

4141
function SCCNonlinearFunction{iip}(
42-
sys::System, _eqs, _dvs, _obs, cachesyms; eval_expression = false,
42+
sys::System, _eqs, _dvs, _obs, cachesyms, op; eval_expression = false,
4343
eval_module = @__MODULE__, cse = true, kwargs...) where {iip}
4444
ps = parameters(sys; initial_parameters = true)
45+
subsys = System(
46+
_eqs, _dvs, ps; observed = _obs, name = nameof(sys), defaults = defaults(sys))
47+
@set! subsys.parameter_dependencies = parameter_dependencies(sys)
48+
if get_index_cache(sys) !== nothing
49+
@set! subsys.index_cache = subset_unknowns_observed(
50+
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
51+
@set! subsys.complete = true
52+
end
53+
# generate linear problem instead
54+
if isaffine(subsys)
55+
return LinearFunction{iip}(
56+
subsys; eval_expression, eval_module, cse, cachesyms, kwargs...)
57+
end
4558
rps = reorder_parameters(sys, ps)
4659

4760
obs_assignments = [eq.lhs eq.rhs for eq in _obs]
@@ -54,14 +67,6 @@ function SCCNonlinearFunction{iip}(
5467
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
5568
f = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)
5669

57-
subsys = System(_eqs, _dvs, ps; observed = _obs,
58-
parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
59-
if get_index_cache(sys) !== nothing
60-
@set! subsys.index_cache = subset_unknowns_observed(
61-
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
62-
@set! subsys.complete = true
63-
end
64-
6570
return NonlinearFunction{iip}(f; sys = subsys)
6671
end
6772

@@ -70,7 +75,7 @@ function SciMLBase.SCCNonlinearProblem(sys::System, args...; kwargs...)
7075
end
7176

7277
function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = false,
73-
eval_module = @__MODULE__, cse = true, kwargs...) where {iip}
78+
eval_module = @__MODULE__, cse = true, u0_constructor = identity, kwargs...) where {iip}
7479
if !iscomplete(sys) || get_tearing_state(sys) === nothing
7580
error("A simplified `System` is required. Call `mtkcompile` on the system before creating an `SCCNonlinearProblem`.")
7681
end
@@ -112,7 +117,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
112117
obs = observed(sys)
113118

114119
_, u0, p = process_SciMLProblem(
115-
EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, kwargs...)
120+
EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, u0_constructor, kwargs...)
116121

117122
explicitfuns = []
118123
nlfuns = []
@@ -223,7 +228,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
223228
get(cachevars, T, [])
224229
end)
225230
f = SCCNonlinearFunction{iip}(
226-
sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, cse, kwargs...)
231+
sys, _eqs, _dvs, _obs, cachebufsyms, op;
232+
eval_expression, eval_module, cse, kwargs...)
227233
push!(nlfuns, f)
228234
end
229235

@@ -244,7 +250,14 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
244250
for (f, vscc) in zip(nlfuns, var_sccs)
245251
_u0 = SymbolicUtils.Code.create_array(
246252
typeof(u0), eltype(u0), Val(1), Val(length(vscc)), u0[vscc]...)
247-
prob = NonlinearProblem(f, _u0, p)
253+
if f isa LinearFunction
254+
symbolic_interface = f.interface
255+
A, b = get_A_b_from_LinearFunction(
256+
sys, f, p; eval_expression, eval_module, u0_constructor)
257+
prob = LinearProblem(A, b, p; u0 = _u0, f = symbolic_interface)
258+
else
259+
prob = NonlinearProblem(f, _u0, p)
260+
end
248261
push!(subprobs, prob)
249262
end
250263

0 commit comments

Comments
 (0)