@@ -630,45 +630,31 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
630
630
end
631
631
if ! isnothing (hess)
632
632
prep_∇²J = prepare_hessian (Jfunc!, hess, Z̃_J, context_J... ; strict)
633
+ @warn " Here's the objective Hessian sparsity pattern:"
633
634
display (sparsity_pattern (prep_∇²J))
634
635
else
635
636
prep_∇²J = nothing
636
637
end
637
638
∇J = Vector {JNT} (undef, nZ̃)
638
639
∇²J = init_diffmat (JNT, hess, prep_∇²J, nZ̃, nZ̃)
639
-
640
-
641
-
642
- function update_objective! (J, ∇J, Z̃, Z̃arg, hess:: Nothing , grad:: AbstractADType )
643
- if isdifferent (Z̃arg, Z̃)
644
- Z̃ .= Z̃arg
645
- J[], _ = value_and_gradient! (Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J... )
646
- end
647
- end
648
- function update_objective! (J, ∇J, Z̃, Z̃arg, hess:: AbstractADType , grad:: Nothing )
649
- if isdifferent (Z̃arg, Z̃)
650
- Z̃ .= Z̃arg
651
- J[], _ = value_gradient_and_hessian! (
652
- Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃, context_J...
653
- )
654
- # display(∇J)
655
- # display(∇²J)
656
- # println(∇²J)
657
- end
658
- end
659
-
660
640
function Jfunc (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
661
- update_objective! (J, ∇J, Z̃_J, Z̃arg, hess, grad)
641
+ update_diff_objective! (
642
+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
643
+ )
662
644
return J[]:: T
663
645
end
664
646
∇Jfunc! = if nZ̃ == 1 # univariate syntax (see JuMP.@operator doc):
665
647
function (Z̃arg)
666
- update_objective! (J, ∇J, Z̃_J, Z̃arg, hess, grad)
648
+ update_diff_objective! (
649
+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
650
+ )
667
651
return ∇J[begin ]
668
652
end
669
653
else # multivariate syntax (see JuMP.@operator doc):
670
654
function (∇Jarg:: AbstractVector{T} , Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
671
- update_objective! (J, ∇J, Z̃_J, Z̃arg, hess, grad)
655
+ update_diff_objective! (
656
+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J, grad, hess, Jfunc!, Z̃arg
657
+ )
672
658
return ∇Jarg .= ∇J
673
659
end
674
660
end
@@ -784,6 +770,52 @@ function update_predictions!(
784
770
return nothing
785
771
end
786
772
773
+ """
774
+ update_diff_objective!(
775
+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J , context_J,
776
+ grad::AbstractADType, hess::Nothing, Jfunc!, Z̃arg
777
+ )
778
+
779
+ TBW
780
+ """
781
+ function update_diff_objective! (
782
+ Z̃_J, J, ∇J, ∇²J, prep_∇J, _ , context_J,
783
+ grad:: AbstractADType , hess:: Nothing , Jfunc!:: F , Z̃arg
784
+ ) where F <: Function
785
+ if isdifferent (Z̃arg, Z̃_J)
786
+ Z̃_J .= Z̃arg
787
+ J[], _ = value_and_gradient! (Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context... )
788
+ end
789
+ return nothing
790
+ end
791
+
792
+ function update_diff_objective! (
793
+ Z̃_J, J, ∇J, ∇²J, _ , prep_∇²J, context_J,
794
+ grad:: Nothing , hess:: AbstractADType , Jfunc!:: F , Z̃arg
795
+ ) where F <: Function
796
+ if isdifferent (Z̃arg, Z̃_J)
797
+ Z̃_J .= Z̃arg
798
+ J[], _ = value_gradient_and_hessian! (
799
+ Jfunc!, ∇J, ∇²J, prep_∇²J, hess, Z̃_J, context_J...
800
+ )
801
+ @warn " Here's the current Hessian:"
802
+ println (∇²J)
803
+ end
804
+ return nothing
805
+ end
806
+
807
+ function update_diff_objective! (
808
+ Z̃_J, J, ∇J, ∇²J, prep_∇J, prep_∇²J, context_J,
809
+ grad:: AbstractADType , hess:: AbstractADType , Jfunc!:: F , Z̃arg
810
+ ) where F<: Function
811
+ if isdifferent (Z̃arg, Z̃_J)
812
+ Z̃_J .= Z̃arg # inefficient, as warned by validate_backends(), but still possible:
813
+ hessian! (Jfunc!, ∇²J, prep_∇²J, hess, Z̃_J, context_J... )
814
+ J[], _ = value_and_gradient! (Jfunc!, ∇J, prep_∇J, grad, Z̃_J, context_J... )
815
+ end
816
+ return nothing
817
+ end
818
+
787
819
@doc raw """
788
820
con_custom!(gc, mpc::NonLinMPC, Ue, Ŷe, ϵ) -> gc
789
821
0 commit comments