Skip to content

Commit 12269ee

Browse files
committed
added: huge performance improvement by caching Jacobians of the constraints
The splatting syntax of `JuMP` forces use to compute the Jacobians of the inequality and equality constraints as a multiple gradients (concatenated in the Jacobian matrices). This means that `jacobian!` function of the AD tools were called redundantly `ng` and `neq` times, for a specific decision vector value `Z̃`. This is wasteful. A caching mechanism was implemented to store the Jacobians of the constraints and reuse them when needed. The performance improvement is about 5-10x faster now on `NonLinMPC` with `NonLinModel`.
1 parent 8ae0694 commit 12269ee

File tree

3 files changed

+38
-29
lines changed

3 files changed

+38
-29
lines changed

src/controller/nonlinmpc.jl

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
563563
function update_simulations!(
564564
Z̃arg::Union{NTuple{N, T}, AbstractVector{T}}, Z̃cache
565565
) where {N, T<:Real}
566-
if any(cache !== arg for (cache, arg) in zip(Z̃cache, Z̃arg)) # new Z̃, update:
566+
if isdifferent(Z̃cache, Z̃arg) # new Z̃, update:
567567
for i in eachindex(Z̃cache)
568568
# Z̃cache .= Z̃arg is type unstable with Z̃arg::NTuple{N, FowardDiff.Dual}
569569
Z̃cache[i] = Z̃arg[i]
@@ -587,13 +587,6 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
587587
end
588588
return nothing
589589
end
590-
# --------------------- normal cache for the AD functions ----------------------------
591-
Z̃arg_vec = Vector{JNT}(undef, nZ̃)
592-
∇J = Vector{JNT}(undef, nZ̃) # gradient of objective J
593-
g_vec = Vector{JNT}(undef, ng)
594-
∇g = Matrix{JNT}(undef, ng, nZ̃) # Jacobian of inequality constraints g
595-
geq_vec = Vector{JNT}(undef, neq)
596-
∇geq = Matrix{JNT}(undef, neq, nZ̃) # Jacobian of equality constraints geq
597590
# --------------------- objective functions -------------------------------------------
598591
function Jfunc(Z̃arg::Vararg{T, N}) where {N, T<:Real}
599592
update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
@@ -608,18 +601,20 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
608601
Ue, Ŷe = get_tmp(Ue_cache, T), get_tmp(Ŷe_cache, T)
609602
U0, Ŷ0 = get_tmp(U0_cache, T), get_tmp(Ŷ0_cache, T)
610603
return obj_nonlinprog!(Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
611-
end
612-
∇J_buffer = GradientBuffer(Jfunc_vec, Z̃arg_vec)
604+
end
605+
Z̃_∇J = fill(myNaN, nZ̃)
606+
∇J = Vector{JNT}(undef, nZ̃) # gradient of objective J
607+
∇J_buffer = GradientBuffer(Jfunc_vec, Z̃_∇J)
613608
∇Jfunc! = if nZ̃ == 1
614609
function (Z̃arg::T) where T<:Real
615-
Z̃arg_vec .= Z̃arg
616-
gradient!(∇J, ∇J_buffer, Z̃arg_vec)
610+
Z̃_∇J .= Z̃arg
611+
gradient!(∇J, ∇J_buffer, Z̃_∇J)
617612
return ∇J[begin] # univariate syntax, see JuMP.@operator doc
618613
end
619614
else
620615
function (∇J::AbstractVector{T}, Z̃arg::Vararg{T, N}) where {N, T<:Real}
621-
Z̃arg_vec .= Z̃arg
622-
gradient!(∇J, ∇J_buffer, Z̃arg_vec)
616+
Z̃_∇J .= Z̃arg
617+
gradient!(∇J, ∇J_buffer, Z̃_∇J)
623618
return ∇J # multivariate syntax, see JuMP.@operator doc
624619
end
625620
end
@@ -638,21 +633,27 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
638633
g .= get_tmp(g_cache, T)
639634
return g
640635
end
641-
∇g_buffer = JacobianBuffer(gfunc_vec!, g_vec, Z̃arg_vec)
642-
∇gfuncs! = Vector{Function}(undef, ng)
636+
Z̃_∇g = fill(myNaN, nZ̃)
637+
g_vec = Vector{JNT}(undef, ng)
638+
∇g = Matrix{JNT}(undef, ng, nZ̃) # Jacobian of inequality constraints g
639+
∇g_buffer = JacobianBuffer(gfunc_vec!, g_vec, Z̃_∇g)
640+
∇gfuncs! = Vector{Function}(undef, ng)
643641
for i in eachindex(∇gfuncs!)
644642
∇gfuncs![i] = if nZ̃ == 1
645643
function (Z̃arg::T) where T<:Real
646-
Z̃arg_vec .= Z̃arg
647-
jacobian!(∇g, ∇g_buffer, g_vec, Z̃arg_vec)
648-
return ∇g[i, begin] # univariate syntax, see JuMP.@operator doc
644+
if isdifferent(Z̃arg, Z̃_∇g)
645+
Z̃_∇g .= Z̃arg
646+
jacobian!(∇g, ∇g_buffer, g_vec, Z̃_∇g)
647+
end
648+
return ∇g[i, begin] # univariate syntax, see JuMP.@operator doc
649649
end
650650
else
651651
function (∇g_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
652-
Z̃arg_vec .= Z̃arg
653-
jacobian!(∇g, ∇g_buffer, g_vec, Z̃arg_vec)
654-
∇g_i .= @views ∇g[i, :]
655-
return ∇g_i # multivariate syntax, see JuMP.@operator doc
652+
if isdifferent(Z̃arg, Z̃_∇g)
653+
Z̃_∇g .= Z̃arg
654+
jacobian!(∇g, ∇g_buffer, g_vec, Z̃_∇g)
655+
end
656+
return ∇g_i .= @views ∇g[i, :] # multivariate syntax, see JuMP.@operator doc
656657
end
657658
end
658659
end
@@ -671,17 +672,21 @@ function get_optim_functions(mpc::NonLinMPC, ::JuMP.GenericModel{JNT}) where JNT
671672
geq .= get_tmp(geq_cache, T)
672673
return geq
673674
end
674-
∇geq_buffer = JacobianBuffer(geqfunc_vec!, geq_vec, Z̃arg_vec)
675-
∇geqfuncs! = Vector{Function}(undef, neq)
675+
Z̃_∇geq = fill(myNaN, nZ̃) # NaN to force update at 1st call
676+
geq_vec = Vector{JNT}(undef, neq)
677+
∇geq = Matrix{JNT}(undef, neq, nZ̃) # Jacobian of equality constraints geq
678+
∇geq_buffer = JacobianBuffer(geqfunc_vec!, geq_vec, Z̃_∇geq)
679+
∇geqfuncs! = Vector{Function}(undef, neq)
676680
for i in eachindex(∇geqfuncs!)
677681
# only multivariate syntax, univariate is impossible since nonlinear equality
678682
# constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
679683
∇geqfuncs![i] =
680684
function (∇geq_i, Z̃arg::Vararg{T, N}) where {N, T<:Real}
681-
Z̃arg_vec .= Z̃arg
682-
jacobian!(∇geq, ∇geq_buffer, geq_vec, Z̃arg_vec)
683-
∇geq_i .= @views ∇geq[i, :]
684-
return ∇geq_i
685+
if isdifferent(Z̃arg, Z̃_∇geq)
686+
Z̃_∇geq .= Z̃arg
687+
jacobian!(∇geq, ∇geq_buffer, geq_vec, Z̃_∇geq)
688+
end
689+
return ∇geq_i .= @views ∇geq[i, :]
685690
end
686691
end
687692
return Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!

src/estimator/mhe/construct.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,7 @@ function get_optim_functions(
13151315
x̄_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), Nc)
13161316
û0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), Nc)
13171317
ŷ0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŷ), Nc)
1318+
# --------------------- update simulation function ------------------------------------
13181319
function update_simulations!(Z̃, Z̃tup::NTuple{N, T}) where {N, T <:Real}
13191320
if any(new !== old for (new, old) in zip(Z̃tup, Z̃)) # new Z̃tup, update predictions:
13201321
Z̃1 = Z̃tup[begin]

src/general.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ function limit_solve_time(optim::GenericModel, Ts)
9191
end
9292
end
9393

94+
"Verify that x and y elements are different using `!==`."
95+
isdifferent(x, y) = any(xi !== yi for (xi, yi) in zip(x, y))
96+
9497
"Generate a block diagonal matrix repeating `n` times the matrix `A`."
9598
repeatdiag(A, n::Int) = kron(I(n), A)
9699

0 commit comments

Comments
 (0)