Skip to content

Commit d29673d

Browse files
committed
wip: print some info on NonLinMPC hessians
1 parent 20f2b2d commit d29673d

File tree

1 file changed

+56
-24
lines changed

1 file changed

+56
-24
lines changed

src/controller/nonlinmpc.jl

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -630,45 +630,31 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
630630
end
631631
if !isnothing(hess)
632632
prep_∇²J = prepare_hessian(Jfunc!, hess, Z̃_J, context_J...; strict)
633+
@warn "Here's the objective Hessian sparsity pattern:"
633634
display(sparsity_pattern(prep_∇²J))
634635
else
635636
prep_∇²J = nothing
636637
end
637638
∇J = Vector{JNT}(undef, nZ̃)
638639
∇²J = init_diffmat(JNT, hess, prep_∇²J, nZ̃, nZ̃)
639-
640-
641-
642-
function update_objective!(J, ∇J, Z̃, Z̃arg, hess::Nothing, grad::AbstractADType)
643-
if isdifferent(Z̃arg, Z̃)
644-
Z̃ .= Z̃arg
645-
J[], _ = value_and_gradient!(Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J...)
646-
end
647-
end
648-
function update_objective!(J, ∇J, Z̃, Z̃arg, hess::AbstractADType, grad::Nothing)
649-
if isdifferent(Z̃arg, Z̃)
650-
Z̃ .= Z̃arg
651-
J[], _ = value_gradient_and_hessian!(
652-
Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃, context_J...
653-
)
654-
#display(∇J)
655-
#display(∇²J)
656-
#println(∇²J)
657-
end
658-
end
659-
660640
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
661-
update_objective!(J, ∇J, Z̃_J, Z̃arg, hess, grad)
641+
update_diff_objective!(
642+
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
643+
)
662644
return J[]::T
663645
end
664646
∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
665647
function (Z̃arg)
666-
update_objective!(J, ∇J, Z̃_J, Z̃arg, hess, grad)
648+
update_diff_objective!(
649+
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
650+
)
667651
return ∇J[begin]
668652
end
669653
else # multivariate syntax (see JuMP.@operator doc):
670654
function (∇Jarg::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
671-
update_objective!(J, ∇J, Z̃_J, Z̃arg, hess, grad)
655+
update_diff_objective!(
656+
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
657+
)
672658
return ∇Jarg .= ∇J
673659
end
674660
end
@@ -784,6 +770,52 @@ function update_predictions!(
784770
return nothing
785771
end
786772

773+
"""
774+
update_diff_objective!(
775+
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J , context_J,
776+
grad::AbstractADType, hess::Nothing, Jfunc!, Z̃arg
777+
)
778+
779+
TBW
780+
"""
781+
function update_diff_objective!(
782+
Z̃_J, J, ∇J, ∇²J, prep_∇J, _ , context_J,
783+
grad::AbstractADType, hess::Nothing, Jfunc!::F, Z̃arg
784+
) where F <: Function
785+
if isdifferent(Z̃arg, Z̃_J)
786+
Z̃_J .= Z̃arg
787+
J[], _ = value_and_gradient!(Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context...)
788+
end
789+
return nothing
790+
end
791+
792+
function update_diff_objective!(
793+
Z̃_J, J, ∇J, ∇²J, _ , prep_∇²J, context_J,
794+
grad::Nothing, hess::AbstractADType, Jfunc!::F, Z̃arg
795+
) where F <: Function
796+
if isdifferent(Z̃arg, Z̃_J)
797+
Z̃_J .= Z̃arg
798+
J[], _ = value_gradient_and_hessian!(
799+
Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃_J, context_J...
800+
)
801+
@warn "Here's the current Hessian:"
802+
println(∇²J)
803+
end
804+
return nothing
805+
end
806+
807+
function update_diff_objective!(
808+
Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J,
809+
grad::AbstractADType, hess::AbstractADType, Jfunc!::F, Z̃arg
810+
) where F<: Function
811+
if isdifferent(Z̃arg, Z̃_J)
812+
Z̃_J .= Z̃arg # inefficient, as warned by validate_backends(), but still possible:
813+
hessian!(Jfunc!, ∇²J, prep_∇²J, hess, Z̃_J, context_J...)
814+
J[], _ = value_and_gradient!(Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J...)
815+
end
816+
return nothing
817+
end
818+
787819
@doc raw"""
788820
con_custom!(gc, mpc::NonLinMPC, Ue, Ŷe, ϵ) -> gc
789821

0 commit comments

Comments
 (0)