Skip to content

Commit 4940dbc

Browse files
feat: add LinearFunction as a late-binding for creating LinearProblem
1 parent 12c2bab commit 4940dbc

File tree

1 file changed

+56
-26
lines changed

1 file changed

+56
-26
lines changed

src/problems/linearproblem.jl

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,42 @@
1+
struct LinearFunction{iip, I} <: SciMLBase.AbstractSciMLFunction{iip}
2+
interface::I
3+
A::AbstractMatrix
4+
b::AbstractVector
5+
end
6+
7+
function LinearFunction{iip}(
8+
sys::System; expression = Val{false}, check_compatibility = true,
9+
sparse = false, eval_expression = false, eval_module = @__MODULE__,
10+
checkbounds = false, cse = true, kwargs...) where {iip}
11+
check_complete(sys, LinearProblem)
12+
check_compatibility && check_compatible_system(LinearProblem, sys)
13+
14+
A, b = calculate_A_b(sys; sparse)
15+
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
16+
eval_module, checkbounds, cse, kwargs...)
17+
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
18+
eval_module, checkbounds, cse, kwargs...)
19+
observedfun = ObservedFunctionCache(
20+
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
21+
cse)
22+
23+
if expression == Val{true}
24+
symbolic_interface = quote
25+
update_A = $update_A
26+
update_b = $update_b
27+
sys = $sys
28+
observedfun = $observedfun
29+
$(SciMLBase.SymbolicLinearInterface)(
30+
update_A, update_b, sys, observedfun, nothing)
31+
end
32+
else
33+
symbolic_interface = SciMLBase.SymbolicLinearInterface(
34+
update_A, update_b, sys, observedfun, nothing)
35+
end
36+
37+
return LinearFunction{iip, typeof(symbolic_interface)}(symbolic_interface, A, b)
38+
end
39+
140
function SciMLBase.LinearProblem(sys::System, op; kwargs...)
241
SciMLBase.LinearProblem{true}(sys, op; kwargs...)
342
end
@@ -14,8 +53,8 @@ function SciMLBase.LinearProblem{iip}(
1453
check_complete(sys, LinearProblem)
1554
check_compatibility && check_compatible_system(LinearProblem, sys)
1655

17-
_, u0, p = process_SciMLProblem(
18-
EmptySciMLFunction{iip}, sys, op; check_length, expression,
56+
f, u0, p = process_SciMLProblem(
57+
LinearFunction{iip}, sys, op; check_length, expression,
1958
build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype,
2059
kwargs...)
2160

@@ -32,25 +71,21 @@ function SciMLBase.LinearProblem{iip}(
3271
u0_eltype = something(u0_eltype, floatT)
3372

3473
u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype)
74+
symbolic_interface = f.interface
75+
A, b = get_A_b_from_LinearFunction(
76+
sys, f, p; eval_expression, eval_module, expression, u0_constructor)
3577

36-
A, b = calculate_A_b(sys; sparse)
37-
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
38-
eval_module, checkbounds, cse, kwargs...)
39-
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
40-
eval_module, checkbounds, cse, kwargs...)
41-
observedfun = ObservedFunctionCache(
42-
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
43-
cse)
78+
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
79+
args = (; A, b, p)
4480

81+
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
82+
end
83+
84+
function get_A_b_from_LinearFunction(
85+
sys::System, f::LinearFunction, p; eval_expression = false,
86+
eval_module = @__MODULE__, expression = Val{false}, u0_constructor = identity)
87+
@unpack A, b, interface = f
4588
if expression == Val{true}
46-
symbolic_interface = quote
47-
update_A = $update_A
48-
update_b = $update_b
49-
sys = $sys
50-
observedfun = $observedfun
51-
$(SciMLBase.SymbolicLinearInterface)(
52-
update_A, update_b, sys, observedfun, nothing)
53-
end
5489
get_A = build_explicit_observed_function(
5590
sys, A; param_only = true, eval_expression, eval_module)
5691
if sparse
@@ -61,16 +96,11 @@ function SciMLBase.LinearProblem{iip}(
6196
A = u0_constructor(get_A(p))
6297
b = u0_constructor(get_b(p))
6398
else
64-
symbolic_interface = SciMLBase.SymbolicLinearInterface(
65-
update_A, update_b, sys, observedfun, nothing)
66-
A = u0_constructor(update_A(p))
67-
b = u0_constructor(update_b(p))
99+
A = u0_constructor(interface.update_A!(p))
100+
b = u0_constructor(interface.update_b!(p))
68101
end
69102

70-
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
71-
args = (; A, b, p)
72-
73-
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
103+
return A, b
74104
end
75105

76106
# For remake

0 commit comments

Comments
 (0)