Skip to content

Commit f1ff9fe

Browse files
committed
added: ExtendedKalmanFilter now uses DI.jl
Bonus : it is now allocation-free !
1 parent 14f8a84 commit f1ff9fe

File tree

3 files changed

+50
-57
lines changed

3 files changed

+50
-57
lines changed

src/estimator/execute.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ where ``\mathbf{x̂_0}(k+1)`` is stored in `x̂next0` argument. The method mutat
3636
function signature for conciseness.
3737
"""
3838
function f̂!(x̂next0, û0, estim::StateEstimator, model::SimModel, x̂0, u0, d0)
39-
return f!(x̂next0, û0, model, estim.As, estim.Cs_u, x̂0, u0, d0)
39+
return !(x̂next0, û0, model, estim.As, estim.Cs_u, x̂0, u0, d0)
4040
end
4141

4242
"""

src/estimator/kalman.jl

Lines changed: 48 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -909,10 +909,10 @@ struct ExtendedKalmanFilter{
909909
::Matrix{NT}
910910
::Matrix{NT}
911911
Ĥm ::Matrix{NT}
912-
direct::Bool
913-
corrected::Vector{Bool}
914912
jacobian::JB
915913
linfunc!::LF
914+
direct::Bool
915+
corrected::Vector{Bool}
916916
buffer::StateEstimatorBuffer{NT}
917917
function ExtendedKalmanFilter{NT}(
918918
model::SM, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; jacobian::JB, linfunc!::LF, direct=true
@@ -935,7 +935,7 @@ struct ExtendedKalmanFilter{
935935
Ĥ, Ĥm = zeros(NT, ny, nx̂), zeros(NT, nym, nx̂)
936936
corrected = [false]
937937
buffer = StateEstimatorBuffer{NT}(nu, nx̂, nym, ny, nd)
938-
return new{NT, SM}(
938+
return new{NT, SM, JB, LF}(
939939
model,
940940
lastu0, x̂op, f̂op, x̂0, P̂,
941941
i_ym, nx̂, nym, nyu, nxs,
@@ -1005,7 +1005,7 @@ function ExtendedKalmanFilter(
10051005
P̂_0 = Hermitian(diagm(NT[σP_0; σPint_u_0; σPint_ym_0].^2), :L)
10061006
= Hermitian(diagm(NT[σQ; σQint_u; σQint_ym ].^2), :L)
10071007
= Hermitian(diagm(NT[σR;].^2), :L)
1008-
linfunc! = get_ekf_linfunc(model, i_ym, nint_u, nint_ym, jacobian)
1008+
linfunc! = get_ekf_linfunc(NT, model, i_ym, nint_u, nint_ym, jacobian)
10091009
return ExtendedKalmanFilter{NT}(
10101010
model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; jacobian, linfunc!, direct
10111011
)
@@ -1024,48 +1024,52 @@ function ExtendedKalmanFilter(
10241024
model::SM, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; jacobian=AutoForwardDiff(), direct=true
10251025
) where {NT<:Real, SM<:SimModel{NT}}
10261026
P̂_0, Q̂, R̂ = to_mat(P̂_0), to_mat(Q̂), to_mat(R̂)
1027-
linfunc! = get_ekf_linfunc(model, i_ym, nint_u, nint_ym, jacobian)
1027+
linfunc! = get_ekf_linfunc(NT, model, i_ym, nint_u, nint_ym, jacobian)
10281028
return ExtendedKalmanFilter{NT}(
10291029
model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; jacobian, direct, linfunc!
10301030
)
10311031
end
10321032

1033-
function get_ekf_linfunc(model, i_ym, nint_u, nint_ym, jacobian)
1033+
"""
1034+
get_ekf_linfunc(NT, model, i_ym, nint_u, nint_ym, jacobian) -> linfunc!
1035+
1036+
Return the `linfunc!` function that computes the Jacobians of the augmented model.
1037+
1038+
The function has the following signature:
1039+
```
1040+
linfunc!(x̂0next, ŷ0, F̂, Ĥ, backend, x̂0, cst_u0, cst_d0) -> nothing
1041+
```
1042+
#TODO: continue here
1043+
"""
1044+
function get_ekf_linfunc(NT, model, i_ym, nint_u, nint_ym, jacobian)
10341045
As, Cs_u, Cs_y = init_estimstoch(model, i_ym, nint_u, nint_ym)
1035-
function f̂_ekf!(x̂0next, x̂0, model, As, Cs_u, u0, d0, û0)
1036-
return f̂!(x̂next0, û0, model, As, Cs_u, x̂0, u0, d0)
1037-
end
1038-
function ĥ_ekf!(ŷ0, x̂0, model, Cs_y, d0)
1039-
return ĥ!(ŷ0, model, Cs_y, x̂0, d0)
1040-
end
1046+
f̂_ekf!(x̂0next, x̂0, û0, u0, d0) = f̂!(x̂0next, û0, model, As, Cs_u, x̂0, u0, d0)
1047+
ĥ_ekf!(ŷ0, x̂0, d0) = ĥ!(ŷ0, model, Cs_y, x̂0, d0)
10411048
strict = Val(true)
1042-
#TODO: continue here:
1043-
xnext = zeros(NT, nx)
1044-
y = zeros(NT, ny)
1045-
x = zeros(NT, nx)
1046-
u = zeros(NT, nu)
1047-
d = zeros(NT, nd)
1048-
cst_x = Constant(x)
1049-
cst_u = Constant(u)
1050-
cst_d = Constant(d)
1051-
A_prep = prepare_jacobian(f_x!, xnext, backend, x, cst_u, cst_d; strict)
1052-
Bu_prep = prepare_jacobian(f_u!, xnext, backend, u, cst_x, cst_d; strict)
1053-
Bd_prep = prepare_jacobian(f_d!, xnext, backend, d, cst_x, cst_u; strict)
1054-
C_prep = prepare_jacobian(h_x!, y, backend, x, cst_d ; strict)
1055-
Dd_prep = prepare_jacobian(h_d!, y, backend, d, cst_x ; strict)
1056-
function linfunc!(xnext, y, A, Bu, C, Bd, Dd, backend, x, u, d, cst_x, cst_u, cst_d)
1057-
# all the arguments before `x` are mutated in this function
1058-
jacobian!(f_x!, xnext, A, A_prep, backend, x, cst_u, cst_d)
1059-
jacobian!(f_u!, xnext, Bu, Bu_prep, backend, u, cst_x, cst_d)
1060-
jacobian!(f_d!, xnext, Bd, Bd_prep, backend, d, cst_x, cst_u)
1061-
jacobian!(h_x!, y, C, C_prep, backend, x, cst_d)
1062-
jacobian!(h_d!, y, Dd, Dd_prep, backend, d, cst_x)
1049+
nu, ny, nd = model.nu, model.ny, model.nd
1050+
nx̂ = model.nx + size(As, 1)
1051+
x̂0next = zeros(NT, nx̂)
1052+
ŷ0 = zeros(NT, ny)
1053+
x̂0 = zeros(NT, nx̂)
1054+
tmp_û0 = Cache(zeros(NT, nu))
1055+
cst_u0 = Constant(zeros(NT, nu))
1056+
cst_d0 = Constant(zeros(NT, nd))
1057+
F̂_prep = prepare_jacobian(f̂_ekf!, x̂0next, jacobian, x̂0, tmp_û0, cst_u0, cst_d0; strict)
1058+
Ĥ_prep = prepare_jacobian(ĥ_ekf!, ŷ0, jacobian, x̂0, cst_d0; strict)
1059+
# main method to compute both Jacobians, it mutates all args before `backend`:
1060+
function linfunc!(x̂0next, ŷ0, F̂, Ĥ, backend, x̂0, cst_u0, cst_d0)
1061+
jacobian!(f̂_ekf!, x̂0next, F̂, F̂_prep, backend, x̂0, tmp_û0, cst_u0, cst_d0)
1062+
jacobian!(ĥ_ekf!, ŷ0, Ĥ, Ĥ_prep, backend, x̂0, cst_d0)
10631063
return nothing
10641064
end
1065+
# two additional methods to only compute one of the two Jacobians at a time:
1066+
function linfunc!(x̂0next, ŷ0::Nothing, F̂, Ĥ::Nothing, backend, x̂0, cst_u0, cst_d0)
1067+
return jacobian!(f̂_ekf!, x̂0next, F̂, F̂_prep, backend, x̂0, tmp_û0, cst_u0, cst_d0)
1068+
end
1069+
function linfunc!(x̂0next::Nothing, ŷ0, F̂::Nothing, Ĥ, backend, x̂0, _ , cst_d0)
1070+
return jacobian!(ĥ_ekf!, ŷ0, Ĥ, Ĥ_prep, backend, x̂0, cst_d0)
1071+
end
10651072
return linfunc!
1066-
1067-
1068-
10691073
end
10701074

10711075
"""
@@ -1075,15 +1079,10 @@ Do the same but for the [`ExtendedKalmanFilter`](@ref).
10751079
"""
10761080
function correct_estimate!(estim::ExtendedKalmanFilter, y0m, d0)
10771081
model, x̂0 = estim.model, estim.x̂0
1078-
ŷ0 = estim.buffer.
1079-
1080-
1081-
ĥAD! = (ŷ0, x̂0) -> ĥ!(ŷ0, estim, model, x̂0, d0)
1082-
ForwardDiff.jacobian!(estim.Ĥ, ĥAD!, ŷ0, x̂0)
1083-
1084-
1082+
cst_d0 = Constant(d0)
1083+
ŷ0, Ĥ = estim.buffer.ŷ, estim.
1084+
estim.linfunc!(nothing, ŷ0, nothing, Ĥ, estim.jacobian, x̂0, nothing, cst_d0)
10851085
estim.Ĥm .= @views estim.Ĥ[estim.i_ym, :]
1086-
10871086
return correct_estimate_kf!(estim, y0m, d0, estim.Ĥm)
10881087
end
10891088

@@ -1129,22 +1128,16 @@ prediction step equations are provided below. The correction step is skipped if
11291128
function update_estimate!(estim::ExtendedKalmanFilter{NT}, y0m, d0, u0) where NT<:Real
11301129
model, x̂0 = estim.model, estim.x̂0
11311130
nx̂, nu = estim.nx̂, model.nu
1131+
cst_u0, cst_d0 = Constant(u0), Constant(d0)
11321132
if !estim.direct
1133-
ŷ0 = estim.buffer.
1134-
ĥAD! = (ŷ0, x̂0) -> ĥ!(ŷ0, estim, model, x̂0, d0)
1135-
ForwardDiff.jacobian!(estim.Ĥ, ĥAD!, ŷ0, x̂0)
1133+
ŷ0, Ĥ = estim.buffer.ŷ, estim.
1134+
estim.linfunc!(nothing, ŷ0, nothing, Ĥ, estim.jacobian, x̂0, nothing, cst_d0)
11361135
estim.Ĥm .= @views estim.Ĥ[estim.i_ym, :]
11371136
correct_estimate_kf!(estim, y0m, d0, estim.Ĥm)
11381137
end
11391138
x̂0corr = estim.x̂0
1140-
# concatenate x̂0next and û0 vectors to allows û0 vector with dual numbers for AD:
1141-
# TODO: remove this allocation using estim.buffer
1142-
x̂0nextû = Vector{NT}(undef, nx̂ + nu)
1143-
f̂AD! = (x̂0nextû, x̂0corr) -> @views f̂!(
1144-
x̂0nextû[1:nx̂], x̂0nextû[nx̂+1:end], estim, model, x̂0corr, u0, d0
1145-
)
1146-
ForwardDiff.jacobian!(estim.F̂_û, f̂AD!, x̂0nextû, x̂0corr)
1147-
estim.F̂ .= @views estim.F̂_û[1:estim.nx̂, :]
1139+
x̂0next, F̂ = estim.buffer.x̂, estim.
1140+
estim.linfunc!(x̂0next, nothing, F̂, nothing, estim.jacobian, x̂0corr, cst_u0, cst_d0)
11481141
return predict_estimate_kf!(estim, u0, d0, estim.F̂)
11491142
end
11501143

src/model/linearization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function get_linearization_func(NT, f!, h!, nu, nx, ny, nd, p, backend)
3232
C_prep = prepare_jacobian(h_x!, y, backend, x, cst_d ; strict)
3333
Dd_prep = prepare_jacobian(h_d!, y, backend, d, cst_x ; strict)
3434
function linfunc!(xnext, y, A, Bu, C, Bd, Dd, backend, x, u, d, cst_x, cst_u, cst_d)
35-
# all the arguments before `x` are mutated in this function
35+
# all the arguments before `backend` are mutated in this function
3636
jacobian!(f_x!, xnext, A, A_prep, backend, x, cst_u, cst_d)
3737
jacobian!(f_u!, xnext, Bu, Bu_prep, backend, u, cst_x, cst_d)
3838
jacobian!(f_d!, xnext, Bd, Bd_prep, backend, d, cst_x, cst_u)

0 commit comments

Comments
 (0)