Commit 20f2b2d

wip: NonLinMPC wip hessian backend
1 parent b0cb917 commit 20f2b2d

4 files changed, +124 −67 lines changed

src/ModelPredictiveControl.jl
Lines changed: 3 additions & 1 deletion

@@ -8,8 +8,10 @@ using RecipesBase
 using ProgressLogging
 
 using DifferentiationInterface: ADTypes.AbstractADType, AutoForwardDiff, AutoSparse
-using DifferentiationInterface: gradient!, jacobian!, prepare_gradient, prepare_jacobian
+using DifferentiationInterface: prepare_gradient, prepare_jacobian, prepare_hessian
+using DifferentiationInterface: gradient!, jacobian!, hessian!
 using DifferentiationInterface: value_and_gradient!, value_and_jacobian!
+using DifferentiationInterface: value_gradient_and_hessian!
 using DifferentiationInterface: Constant, Cache
 using SparseConnectivityTracer: TracerSparsityDetector
 using SparseMatrixColorings: GreedyColoringAlgorithm, sparsity_pattern
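For orientation, a minimal sketch (not part of the diff) of the second-order DifferentiationInterface calls these new imports enable, with a toy objective and ForwardDiff.jl assumed as the underlying AD package:

```julia
# Toy sketch of DifferentiationInterface's second-order API. `prepare_hessian` is
# called once; the preparation object is then reused at every evaluation, which is
# the pattern get_optim_functions follows below.
using DifferentiationInterface: AutoForwardDiff, prepare_hessian, value_gradient_and_hessian!
import ForwardDiff

f(x) = sum(abs2, x) + x[1]*x[2]                  # toy objective (hypothetical)
x    = ones(3)
∇f   = similar(x)                                # gradient buffer, filled in-place
∇²f  = Matrix{Float64}(undef, 3, 3)              # Hessian buffer, filled in-place
prep = prepare_hessian(f, AutoForwardDiff(), x)  # one-time preparation
y, _, _ = value_gradient_and_hessian!(f, ∇f, ∇²f, prep, AutoForwardDiff(), x)
```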

src/controller/nonlinmpc.jl
Lines changed: 82 additions & 52 deletions

@@ -13,9 +13,9 @@ struct NonLinMPC{
     SE<:StateEstimator,
     TM<:TranscriptionMethod,
     JM<:JuMP.GenericModel,
-    GB<:AbstractADType,
-    JB<:AbstractADType,
+    GB<:Union{Nothing, AbstractADType},
     HB<:Union{Nothing, AbstractADType},
+    JB<:AbstractADType,
     PT<:Any,
     JEfunc<:Function,
     GCfunc<:Function

@@ -27,8 +27,8 @@ struct NonLinMPC{
     optim::JM
     con::ControllerConstraint{NT, GCfunc}
     gradient::GB
+    hessian ::HB
     jacobian::JB
-    hessian::HB
     Z̃::Vector{NT}
     ŷ::Vector{NT}
     Hp::Int
@@ -65,15 +65,15 @@ struct NonLinMPC{
     function NonLinMPC{NT}(
         estim::SE,
         Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, Ewt, JE::JEfunc, gc!::GCfunc, nc, p::PT,
-        transcription::TM, optim::JM, gradient::GB, jacobian::JB, hessian::HB
+        transcription::TM, optim::JM, gradient::GB, hessian::HB, jacobian::JB,
     ) where {
         NT<:Real,
         SE<:StateEstimator,
         TM<:TranscriptionMethod,
         JM<:JuMP.GenericModel,
-        GB<:AbstractADType,
-        JB<:AbstractADType,
+        GB<:Union{Nothing, AbstractADType},
         HB<:Union{Nothing, AbstractADType},
+        JB<:AbstractADType,
         PT<:Any,
         JEfunc<:Function,
         GCfunc<:Function,
@@ -110,9 +110,9 @@ struct NonLinMPC{
         nZ̃ = get_nZ(estim, transcription, Hp, Hc) + nϵ
         Z̃ = zeros(NT, nZ̃)
         buffer = PredictiveControllerBuffer(estim, transcription, Hp, Hc, nϵ)
-        mpc = new{NT, SE, TM, JM, GB, JB, HB, PT, JEfunc, GCfunc}(
+        mpc = new{NT, SE, TM, JM, GB, HB, JB, PT, JEfunc, GCfunc}(
             estim, transcription, optim, con,
-            gradient, jacobian, hessian,
+            gradient, hessian, jacobian,
             Z̃, ŷ,
             Hp, Hc, nϵ,
             weights,
@@ -205,12 +205,14 @@ This controller allocates memory at each time step for the optimization.
 - `transcription=SingleShooting()` : a [`TranscriptionMethod`](@ref) for the optimization.
 - `optim=JuMP.Model(Ipopt.Optimizer)` : nonlinear optimizer used in the predictive
    controller, provided as a [`JuMP.Model`](@extref) object (default to [`Ipopt`](https://github.com/jump-dev/Ipopt.jl) optimizer).
-- `gradient=AutoForwardDiff()` : an `AbstractADType` backend for the gradient of the objective
-   function, see [`DifferentiationInterface` doc](@extref DifferentiationInterface List).
+- `hessian=nothing` : an `AbstractADType` backend for the Hessian of the objective function
+   (see [`DifferentiationInterface` doc](@extref DifferentiationInterface List)), or
+   `nothing` for the LBFGS approximation provided by `optim` (details in Extended Help).
+- `gradient=isnothing(hessian) ? AutoForwardDiff() : nothing` : an `AbstractADType` backend
+   for the gradient of the objective function (see `hessian` for the options), or `nothing`
+   to retrieve lower-order derivatives from `hessian`.
 - `jacobian=default_jacobian(transcription)` : an `AbstractADType` backend for the Jacobian
-   of the nonlinear constraints, see `gradient` above for the options (default in Extended Help).
-- `hessian=nothing` : an `AbstractADType` backend for the Hessian of the objective function,
-   see `gradient` above for the options, use `nothing` for the LBFGS approximation of `optim`.
+   of the nonlinear constraints (see `hessian` for the options, defaults in Extended Help).
 - additional keyword arguments are passed to [`UnscentedKalmanFilter`](@ref) constructor
    (or [`SteadyKalmanFilter`](@ref), for [`LinModel`](@ref)).
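A hedged usage sketch of the new keyword interaction (toy model and horizons, not from the diff; the API is still wip):

```julia
# With `hessian` set, `gradient` defaults to `nothing` and the gradient is retrieved
# from the Hessian backend; with both at their defaults, Ipopt uses its LBFGS
# approximation. The 1-state model below is a hypothetical placeholder.
using ModelPredictiveControl
using DifferentiationInterface: AutoForwardDiff
f!(ẋ, x, u, _, _) = (ẋ .= -x .+ u; nothing)
h!(y, x, _, _)    = (y .= x; nothing)
model = NonLinModel(f!, h!, 0.1, 1, 1, 1)                            # toy model
mpc_lbfgs = NonLinMPC(model; Hp=10, Hc=2)                            # hessian=nothing
mpc_exact = NonLinMPC(model; Hp=10, Hc=2, hessian=AutoForwardDiff()) # exact AD Hessian
```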
@@ -264,16 +266,16 @@ NonLinMPC controller with a sample time Ts = 10.0 s, Ipopt optimizer, UnscentedK
     exception: if `transcription` is not a [`SingleShooting`](@ref), the `jacobian` argument
     defaults to this [sparse backend](@extref DifferentiationInterface AutoSparse-object):
     ```julia
-    AutoSparse(
+    sparseAD = AutoSparse(
         AutoForwardDiff();
         sparsity_detector  = TracerSparsityDetector(),
         coloring_algorithm = GreedyColoringAlgorithm()
     )
     ```
-    Also, the `hessian` argument defaults to `nothing` meaning the built-in second-order
-    approximation of `solver`. Otherwise, a sparse backend like above is recommended to test
-    different `hessian` methods. Optimizers generally benefit from exact derivatives like AD.
-    However, the [`NonLinModel`](@ref) state-space functions must be compatible with this
+    Also, the `hessian` argument defaults to `nothing` meaning the LBFGS approximation of
+    `optim`. Otherwise, a sparse backend like above is recommended to test a different
+    `hessian` method. In general, optimizers benefit from exact derivatives like AD.
+    However, the [`NonLinModel`](@ref) state-space functions must be compatible with this
     feature. See [`JuMP` documentation](@extref JuMP Common-mistakes-when-writing-a-user-defined-operator)
     for common mistakes when writing these functions.
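Following that recommendation, a sparse Hessian backend would presumably be assembled the same way as the sparse Jacobian default above; a sketch (the `sparseH` name and the commented usage are hypothetical):

```julia
# Sketch: sparse Hessian backend in the same style as the sparse Jacobian default.
# Second-order sparsity is detected by TracerSparsityDetector, and the greedy
# coloring reduces the number of AD passes needed to fill the Hessian.
using DifferentiationInterface, SparseConnectivityTracer, SparseMatrixColorings
sparseH = AutoSparse(
    AutoForwardDiff();
    sparsity_detector  = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm()
)
# mpc = NonLinMPC(model; Hp=10, hessian=sparseH)   # hypothetical usage
```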
@@ -299,16 +301,16 @@ function NonLinMPC(
     p = model.p,
     transcription::TranscriptionMethod = DEFAULT_NONLINMPC_TRANSCRIPTION,
     optim::JuMP.GenericModel = JuMP.Model(DEFAULT_NONLINMPC_OPTIMIZER, add_bridges=false),
-    gradient::AbstractADType = DEFAULT_NONLINMPC_GRADIENT,
+    hessian ::Union{Nothing, AbstractADType} = nothing,
+    gradient::Union{Nothing, AbstractADType} = isnothing(hessian) ? DEFAULT_NONLINMPC_GRADIENT : nothing,
     jacobian::AbstractADType = default_jacobian(transcription),
-    hessian::Union{Nothing, AbstractADType} = nothing,
     kwargs...
 )
     estim = UnscentedKalmanFilter(model; kwargs...)
     return NonLinMPC(
         estim;
         Hp, Hc, Mwt, Nwt, Lwt, Cwt, Ewt, JE, gc, nc, p, M_Hp, N_Hc, L_Hp,
-        transcription, optim, gradient, jacobian, hessian
+        transcription, optim, gradient, hessian, jacobian
     )
 end

@@ -331,16 +333,16 @@ function NonLinMPC(
     p = model.p,
     transcription::TranscriptionMethod = DEFAULT_NONLINMPC_TRANSCRIPTION,
     optim::JuMP.GenericModel = JuMP.Model(DEFAULT_NONLINMPC_OPTIMIZER, add_bridges=false),
-    gradient::AbstractADType = DEFAULT_NONLINMPC_GRADIENT,
+    hessian ::Union{Nothing, AbstractADType} = nothing,
+    gradient::Union{Nothing, AbstractADType} = isnothing(hessian) ? DEFAULT_NONLINMPC_GRADIENT : nothing,
     jacobian::AbstractADType = default_jacobian(transcription),
-    hessian::Union{Nothing, AbstractADType} = nothing,
     kwargs...
 )
     estim = SteadyKalmanFilter(model; kwargs...)
     return NonLinMPC(
         estim;
         Hp, Hc, Mwt, Nwt, Lwt, Cwt, Ewt, JE, gc, nc, p, M_Hp, N_Hc, L_Hp,
-        transcription, optim, gradient, jacobian, hessian
+        transcription, optim, gradient, hessian, jacobian
     )
 end

@@ -387,9 +389,9 @@ function NonLinMPC(
     p = estim.model.p,
     transcription::TranscriptionMethod = DEFAULT_NONLINMPC_TRANSCRIPTION,
     optim::JuMP.GenericModel = JuMP.Model(DEFAULT_NONLINMPC_OPTIMIZER, add_bridges=false),
-    gradient::AbstractADType = DEFAULT_NONLINMPC_GRADIENT,
+    hessian ::Union{Nothing, AbstractADType} = nothing,
+    gradient::Union{Nothing, AbstractADType} = isnothing(hessian) ? DEFAULT_NONLINMPC_GRADIENT : nothing,
     jacobian::AbstractADType = default_jacobian(transcription),
-    hessian::Union{Nothing, AbstractADType} = nothing,
 ) where {
     NT<:Real,
     SE<:StateEstimator{NT}

@@ -403,7 +405,7 @@ function NonLinMPC(
     gc! = get_mutating_gc(NT, gc)
     return NonLinMPC{NT}(
         estim, Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, Ewt, JE, gc!, nc, p,
-        transcription, optim, gradient, jacobian, hessian
+        transcription, optim, gradient, hessian, jacobian
     )
 end

@@ -551,10 +553,12 @@ function init_optimization!(mpc::NonLinMPC, model::SimModel, optim::JuMP.Generic
             JuMP.set_attribute(optim, "nlp_scaling_max_gradient", 10.0/C)
         end
     end
+    validate_backends(mpc.gradient, mpc.hessian)
     Jfunc, ∇Jfunc!, ∇²Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs! = get_optim_functions(
         mpc, optim
     )
-    @operator(optim, J, nZ̃, Jfunc, ∇Jfunc!)
+    Jargs = isnothing(∇²Jfunc!) ? (Jfunc, ∇Jfunc!) : (Jfunc, ∇Jfunc!, ∇²Jfunc!)
+    @operator(optim, J, nZ̃, Jargs...)
     @objective(optim, Min, J(Z̃var...))
     init_nonlincon!(mpc, model, transcription, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!)
     set_nonlincon!(mpc, model, transcription, optim)
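The `Jargs` splat works because `JuMP.@operator` accepts either a `(f, ∇f)` pair or a `(f, ∇f, ∇²f)` triple. A self-contained toy illustration (operator and model names are hypothetical, not from the package):

```julia
# JuMP user-defined operator with an optional Hessian callback: in the multivariate
# syntax, the Hessian function fills (at least) the lower triangle of H in-place.
using JuMP, Ipopt
model = Model(Ipopt.Optimizer)
f(x...)      = sum(abs2, x)
∇f(g, x...)  = (g .= 2 .* x; nothing)
∇²f(H, x...) = (for i in eachindex(x); H[i, i] = 2.0; end; nothing)
@operator(model, op_f, 2, f, ∇f, ∇²f)   # drop ∇²f to fall back on LBFGS
@variable(model, x[1:2] >= -1)
@objective(model, Min, op_f(x...))
optimize!(model)
```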
@@ -591,7 +595,7 @@ Inspired from: [User-defined operators with vector outputs](@extref JuMP User-de
 function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT<:Real
     # ----------- common cache for Jfunc, gfuncs and geqfuncs ----------------------------
     model = mpc.estim.model
-    grad, jac = mpc.gradient, mpc.jacobian
+    grad, hess, jac = mpc.gradient, mpc.hessian, mpc.jacobian
     nu, ny, nx̂, nϵ, nk = model.nu, model.ny, mpc.estim.nx̂, mpc.nϵ, model.nk
     Hp, Hc = mpc.Hp, mpc.Hc
     ng, nc, neq = length(mpc.con.i_g), mpc.con.nc, mpc.con.neq
@@ -613,32 +617,58 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
         return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
     end
-    Z̃_∇J = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
-    ∇J_context = (
+    Z̃_J = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
+    context_J = (
         Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
         Cache(Û0), Cache(K0), Cache(X̂0),
         Cache(gc), Cache(g), Cache(geq),
     )
-    ∇J_prep = prepare_gradient(Jfunc!, grad, Z̃_∇J, ∇J_context...; strict)
-    ∇J = Vector{JNT}(undef, nZ̃)
-    function update_objective!(J, ∇J, Z̃, Z̃arg)
+    if !isnothing(grad)
+        prep_∇J = prepare_gradient(Jfunc!, grad, Z̃_J, context_J...; strict)
+    else
+        prep_∇J = nothing
+    end
+    if !isnothing(hess)
+        prep_∇²J = prepare_hessian(Jfunc!, hess, Z̃_J, context_J...; strict)
+        display(sparsity_pattern(prep_∇²J))
+    else
+        prep_∇²J = nothing
+    end
+    ∇J = Vector{JNT}(undef, nZ̃)
+    ∇²J = init_diffmat(JNT, hess, prep_∇²J, nZ̃, nZ̃)
+
+    function update_objective!(J, ∇J, Z̃, Z̃arg, hess::Nothing, grad::AbstractADType)
+        if isdifferent(Z̃arg, Z̃)
+            Z̃ .= Z̃arg
+            J[], _ = value_and_gradient!(Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J...)
+        end
+    end
+    function update_objective!(J, ∇J, Z̃, Z̃arg, hess::AbstractADType, grad::Nothing)
         if isdifferent(Z̃arg, Z̃)
             Z̃ .= Z̃arg
-            J[], _ = value_and_gradient!(Jfunc!, ∇J, ∇J_prep, grad, Z̃_∇J, ∇J_context...)
+            J[], _ = value_gradient_and_hessian!(
+                Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃, context_J...
+            )
+            #display(∇J)
+            #display(∇²J)
+            #println(∇²J)
         end
-    end
+    end
+
     function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
-        update_objective!(J, ∇J, Z̃_∇J, Z̃arg)
+        update_objective!(J, ∇J, Z̃_J, Z̃arg, hess, grad)
         return J[]::T
     end
     ∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
         function (Z̃arg)
-            update_objective!(J, ∇J, Z̃_∇J, Z̃arg)
+            update_objective!(J, ∇J, Z̃_J, Z̃arg, hess, grad)
             return ∇J[begin]
         end
     else # multivariate syntax (see JuMP.@operator doc):
         function (∇Jarg::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
-            update_objective!(J, ∇J, Z̃_∇J, Z̃arg)
+            update_objective!(J, ∇J, Z̃_J, Z̃arg, hess, grad)
             return ∇Jarg .= ∇J
         end
     end
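The two `update_objective!` methods above select the derivative path by dispatching on which backend is `nothing`. A runnable toy of the same pattern, outside the package (all names hypothetical):

```julia
# Toy of the backend-dispatch pattern: `nothing` in one slot selects the first-order
# or the second-order update path via multiple dispatch, with no runtime branch.
using ADTypes: AbstractADType, AutoForwardDiff
using DifferentiationInterface: prepare_gradient, prepare_hessian,
    value_and_gradient!, value_gradient_and_hessian!
import ForwardDiff

f(x) = sum(abs2, x)
x, J, ∇J, ∇²J = ones(2), Ref(0.0), zeros(2), zeros(2, 2)
prep_g = prepare_gradient(f, AutoForwardDiff(), x)
prep_h = prepare_hessian(f, AutoForwardDiff(), x)

update!(::Nothing, grad::AbstractADType) =          # gradient-only path
    (J[] = first(value_and_gradient!(f, ∇J, prep_g, grad, x)))
update!(hess::AbstractADType, ::Nothing) =          # gradient + Hessian path
    (J[] = first(value_gradient_and_hessian!(f, ∇J, ∇²J, prep_h, hess, x)))

update!(nothing, AutoForwardDiff())   # fills ∇J only
update!(AutoForwardDiff(), nothing)   # fills ∇J and ∇²J
```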
@@ -648,27 +678,27 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
         return g
     end
-    Z̃_∇g = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
-    ∇g_context = (
+    Z̃_g = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
+    context_g = (
         Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
         Cache(Û0), Cache(K0), Cache(X̂0),
         Cache(gc), Cache(geq),
     )
     # temporarily enable all the inequality constraints for sparsity detection:
     mpc.con.i_g[1:end-nc] .= true
-    ∇g_prep = prepare_jacobian(gfunc!, g, jac, Z̃_∇g, ∇g_context...; strict)
+    ∇g_prep = prepare_jacobian(gfunc!, g, jac, Z̃_g, context_g...; strict)
     mpc.con.i_g[1:end-nc] .= false
     ∇g = init_diffmat(JNT, jac, ∇g_prep, nZ̃, ng)
     function update_con!(g, ∇g, Z̃, Z̃arg)
         if isdifferent(Z̃arg, Z̃)
             Z̃ .= Z̃arg
-            value_and_jacobian!(gfunc!, g, ∇g, ∇g_prep, jac, Z̃, ∇g_context...)
+            value_and_jacobian!(gfunc!, g, ∇g, ∇g_prep, jac, Z̃, context_g...)
         end
     end
     gfuncs = Vector{Function}(undef, ng)
     for i in eachindex(gfuncs)
         gfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
-            update_con!(g, ∇g, Z̃_∇g, Z̃arg)
+            update_con!(g, ∇g, Z̃_g, Z̃arg)
             return g[i]::T
         end
         gfuncs[i] = gfunc_i
@@ -677,12 +707,12 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
     for i in eachindex(∇gfuncs!)
         ∇gfuncs_i! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
             function (Z̃arg::T) where T<:Real
-                update_con!(g, ∇g, Z̃_∇g, Z̃arg)
+                update_con!(g, ∇g, Z̃_g, Z̃arg)
                 return ∇g[i, begin]
             end
         else # multivariate syntax (see JuMP.@operator doc):
             function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
-                update_con!(g, ∇g, Z̃_∇g, Z̃arg)
+                update_con!(g, ∇g, Z̃_g, Z̃arg)
                 return ∇g_i .= @views ∇g[i, :]
             end
         end
@@ -693,24 +723,24 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
         update_predictions!(ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, K0, X̂0, gc, g, geq, mpc, Z̃)
         return geq
     end
-    Z̃_∇geq = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
-    ∇geq_context = (
+    Z̃_geq = fill(myNaN, nZ̃) # NaN to force update_predictions! at first call
+    context_geq = (
         Cache(ΔŨ), Cache(x̂0end), Cache(Ue), Cache(Ŷe), Cache(U0), Cache(Ŷ0),
         Cache(Û0), Cache(K0), Cache(X̂0),
         Cache(gc), Cache(g)
     )
-    ∇geq_prep = prepare_jacobian(geqfunc!, geq, jac, Z̃_∇geq, ∇geq_context...; strict)
+    ∇geq_prep = prepare_jacobian(geqfunc!, geq, jac, Z̃_geq, context_geq...; strict)
     ∇geq = init_diffmat(JNT, jac, ∇geq_prep, nZ̃, neq)
     function update_con_eq!(geq, ∇geq, Z̃, Z̃arg)
         if isdifferent(Z̃arg, Z̃)
             Z̃ .= Z̃arg
-            value_and_jacobian!(geqfunc!, geq, ∇geq, ∇geq_prep, jac, Z̃, ∇geq_context...)
+            value_and_jacobian!(geqfunc!, geq, ∇geq, ∇geq_prep, jac, Z̃, context_geq...)
         end
     end
     geqfuncs = Vector{Function}(undef, neq)
     for i in eachindex(geqfuncs)
         geqfunc_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
-            update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃arg)
+            update_con_eq!(geq, ∇geq, Z̃_geq, Z̃arg)
             return geq[i]::T
         end
         geqfuncs[i] = geqfunc_i
@@ -721,7 +751,7 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
         # constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
         ∇geqfuncs_i! =
             function (∇geq_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
-                update_con_eq!(geq, ∇geq, Z̃_∇geq, Z̃arg)
+                update_con_eq!(geq, ∇geq, Z̃_geq, Z̃arg)
                 return ∇geq_i .= @views ∇geq[i, :]
             end
         ∇geqfuncs![i] = ∇geqfuncs_i!
