Skip to content

Commit 1592fe9

Browse files
Merge pull request #628 from JuliaDiff/kc/extension_staticarrays
make StaticArrays dependency into an extension on v0.11-dev
2 parents 4b143a1 + df7e6e0 commit 1592fe9

File tree

8 files changed

+151
-125
lines changed

8 files changed

+151
-125
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1616
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1717

18+
[weakdeps]
19+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
20+
21+
[extensions]
22+
ForwardDiffStaticArraysExt = "StaticArrays"
23+
1824
[compat]
1925
Calculus = "0.5"
2026
CommonSubexpressions = "0.3"
@@ -33,7 +39,8 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
3339
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
3440
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
3541
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
42+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3643
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3744

3845
[targets]
39-
test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils"]
46+
test = ["Calculus", "DiffTests", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"]

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
module ForwardDiffStaticArraysExt
2+
3+
using ForwardDiff, StaticArrays
4+
using ForwardDiff.LinearAlgebra
5+
using ForwardDiff.DiffResults
6+
using ForwardDiff: Dual, partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk,
7+
gradient, hessian, jacobian, gradient!, hessian!, jacobian!,
8+
extract_gradient!, extract_jacobian!, extract_value!,
9+
vector_mode_gradient, vector_mode_gradient!,
10+
vector_mode_jacobian, vector_mode_jacobian!, valtype, value, _lyap_div!
11+
using DiffResults: DiffResult, ImmutableDiffResult, MutableDiffResult
12+
13+
@generated function dualize(::Type{T}, x::StaticArray) where T
14+
N = length(x)
15+
dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...)
16+
V = StaticArrays.similar_type(x, Dual{T,eltype(x),N})
17+
return quote
18+
chunk = Chunk{$N}()
19+
$(Expr(:meta, :inline))
20+
return $V($(dx))
21+
end
22+
end
23+
24+
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
25+
26+
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
27+
λ,Q = eigen(Symmetric(value.(parent(A))))
28+
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
29+
Dual{Tg}.(λ, tuple.(parts...))
30+
end
31+
32+
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
33+
λ = eigvals(A)
34+
_,Q = eigen(Symmetric(value.(parent(A))))
35+
parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
36+
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
37+
end
38+
39+
# Gradient
40+
@inline ForwardDiff.gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
41+
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
42+
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
43+
44+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x)
45+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x)
46+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x)
47+
48+
@generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray}
49+
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
50+
return quote
51+
$(Expr(:meta, :inline))
52+
V = StaticArrays.similar_type(S, valtype($y))
53+
return V($result)
54+
end
55+
end
56+
57+
@inline function ForwardDiff.vector_mode_gradient(f, x::StaticArray)
58+
T = typeof(Tag(f, eltype(x)))
59+
return extract_gradient(T, static_dual_eval(T, f, x), x)
60+
end
61+
62+
@inline function ForwardDiff.vector_mode_gradient!(result, f, x::StaticArray)
63+
T = typeof(Tag(f, eltype(x)))
64+
return extract_gradient!(T, result, static_dual_eval(T, f, x))
65+
end
66+
67+
# Jacobian
68+
@inline ForwardDiff.jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
69+
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
70+
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
71+
72+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x)
73+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x)
74+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x)
75+
76+
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
77+
M, N = length(ydual), length(x)
78+
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
79+
return quote
80+
$(Expr(:meta, :inline))
81+
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
82+
return V($result)
83+
end
84+
end
85+
86+
@inline function ForwardDiff.vector_mode_jacobian(f, x::StaticArray)
87+
T = typeof(Tag(f, eltype(x)))
88+
return extract_jacobian(T, static_dual_eval(T, f, x), x)
89+
end
90+
91+
function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where T
92+
result = similar(ydual, valtype(eltype(ydual)), length(ydual), length(x))
93+
return extract_jacobian!(T, result, ydual, length(x))
94+
end
95+
96+
@inline function ForwardDiff.vector_mode_jacobian!(result, f, x::StaticArray)
97+
T = typeof(Tag(f, eltype(x)))
98+
ydual = static_dual_eval(T, f, x)
99+
result = extract_jacobian!(T, result, ydual, length(x))
100+
result = extract_value!(T, result, ydual)
101+
return result
102+
end
103+
104+
@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f, x::StaticArray)
105+
T = typeof(Tag(f, eltype(x)))
106+
ydual = static_dual_eval(T, f, x)
107+
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
108+
result = DiffResults.value!(d -> value(T,d), result, ydual)
109+
return result
110+
end
111+
112+
# Hessian
113+
ForwardDiff.hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
114+
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
115+
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
116+
117+
ForwardDiff.hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x)
118+
119+
ForwardDiff.hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x))
120+
121+
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x)
122+
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x)
123+
124+
function ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray)
125+
T = typeof(Tag(f, eltype(x)))
126+
d1 = dualize(T, x)
127+
d2 = dualize(T, d1)
128+
fd2 = f(d2)
129+
val = value(T,value(T,fd2))
130+
grad = extract_gradient(T,value(T,fd2), x)
131+
hess = extract_jacobian(T,partials(T,fd2), x)
132+
result = DiffResults.hessian!(result, hess)
133+
result = DiffResults.gradient!(result, grad)
134+
result = DiffResults.value!(result, val)
135+
return result
136+
end
137+
138+
end

src/ForwardDiff.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
module ForwardDiff
22

33
using DiffRules, DiffResults
4-
using DiffResults: DiffResult, MutableDiffResult, ImmutableDiffResult
5-
using StaticArrays
4+
using DiffResults: DiffResult, MutableDiffResult
65
using Preferences
76
using Random
87
using LinearAlgebra
@@ -23,6 +22,10 @@ include("gradient.jl")
2322
include("jacobian.jl")
2423
include("hessian.jl")
2524

25+
if !isdefined(Base, :get_extension)
26+
include("../ext/ForwardDiffStaticArraysExt.jl")
27+
end
28+
2629
export DiffResults
2730

2831
end # module

src/apiutils.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,6 @@ end
1818
# vector mode function evaluation #
1919
###################################
2020

21-
@generated function dualize(::Type{T}, x::StaticArray) where T
22-
N = length(x)
23-
dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...)
24-
V = StaticArrays.similar_type(x, Dual{T,eltype(x),N})
25-
return quote
26-
chunk = Chunk{$N}()
27-
$(Expr(:meta, :inline))
28-
return $V($(dx))
29-
end
30-
end
31-
32-
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
33-
3421
function vector_mode_dual_eval!(f::F, cfg::Union{JacobianConfig,GradientConfig}, x) where {F}
3522
xdual = cfg.duals
3623
seed!(xdual, x, cfg.seeds)

src/dual.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -726,12 +726,6 @@ function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N
726726
Dual{Tg}.(λ, tuple.(parts...))
727727
end
728728

729-
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
730-
λ,Q = eigen(Symmetric(value.(parent(A))))
731-
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
732-
Dual{Tg}.(λ, tuple.(parts...))
733-
end
734-
735729
function LinearAlgebra.eigvals(A::Hermitian{<:Complex{<:Dual{Tg,T,N}}}) where {Tg,T<:Real,N}
736730
λ,Q = eigen(Hermitian(value.(real.(parent(A))) .+ im .* value.(imag.(parent(A)))))
737731
parts = ntuple(j -> diag(real.(Q' * (getindex.(partials.(real.(A)) .+ im .* partials.(imag.(A)), j)) * Q)), N)
@@ -761,13 +755,6 @@ function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
761755
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
762756
end
763757

764-
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
765-
λ = eigvals(A)
766-
_,Q = eigen(Symmetric(value.(parent(A))))
767-
parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
768-
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
769-
end
770-
771758
function LinearAlgebra.eigen(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
772759
λ = eigvals(A)
773760
_,Q = eigen(SymTridiagonal(value.(parent(A))))

src/gradient.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,29 +43,12 @@ function gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArr
4343
return result
4444
end
4545

46-
@inline gradient(f::F, x::StaticArray) where F = vector_mode_gradient(f, x)
47-
@inline gradient(f::F, x::StaticArray, cfg::GradientConfig) where F = gradient(f, x)
48-
@inline gradient(f::F, x::StaticArray, cfg::GradientConfig, ::Val) where F = gradient(f, x)
49-
50-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where F = vector_mode_gradient!(result, f, x)
51-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig) where F = gradient!(result, f, x)
52-
@inline gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig, ::Val) where F = gradient!(result, f, x)
53-
5446
gradient(f, x::Real) = throw(DimensionMismatch("gradient(f, x) expects that x is an array. Perhaps you meant derivative(f, x)?"))
5547

5648
#####################
5749
# result extraction #
5850
#####################
5951

60-
@generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray}
61-
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
62-
return quote
63-
$(Expr(:meta, :inline))
64-
V = StaticArrays.similar_type(S, valtype($y))
65-
return V($result)
66-
end
67-
end
68-
6952
function extract_gradient!(::Type{T}, result::DiffResult, y::Real) where {T}
7053
result = DiffResults.value!(result, y)
7154
grad = DiffResults.gradient(result)
@@ -117,16 +100,6 @@ function vector_mode_gradient!(result, f::F, x, cfg::GradientConfig{T}) where {T
117100
return result
118101
end
119102

120-
@inline function vector_mode_gradient(f, x::StaticArray)
121-
T = typeof(Tag(f, eltype(x)))
122-
return extract_gradient(T, static_dual_eval(T, f, x), x)
123-
end
124-
125-
@inline function vector_mode_gradient!(result, f, x::StaticArray)
126-
T = typeof(Tag(f, eltype(x)))
127-
return extract_gradient!(T, result, static_dual_eval(T, f, x))
128-
end
129-
130103
##############
131104
# chunk mode #
132105
##############

src/hessian.jl

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,28 +69,3 @@ function hessian!(result::DiffResult, f::F, x::AbstractArray, cfg::HessianConfig
6969
jacobian!(DiffResults.hessian(result), ∇f!, DiffResults.gradient(result), x, cfg.jacobian_config, Val{false}())
7070
return ∇f!.result
7171
end
72-
73-
hessian(f::F, x::StaticArray) where F = jacobian(y -> gradient(f, y), x)
74-
hessian(f::F, x::StaticArray, cfg::HessianConfig) where F = hessian(f, x)
75-
hessian(f::F, x::StaticArray, cfg::HessianConfig, ::Val) where F = hessian(f, x)
76-
77-
hessian!(result::AbstractArray, f::F, x::StaticArray) where F = jacobian!(result, y -> gradient(f, y), x)
78-
79-
hessian!(result::MutableDiffResult, f::F, x::StaticArray) where F = hessian!(result, f, x, HessianConfig(f, result, x))
80-
81-
hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig) where F = hessian!(result, f, x)
82-
hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where F = hessian!(result, f, x)
83-
84-
function hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where F
85-
T = typeof(Tag(f, eltype(x)))
86-
d1 = dualize(T, x)
87-
d2 = dualize(T, d1)
88-
fd2 = f(d2)
89-
val = value(T,value(T,fd2))
90-
grad = extract_gradient(T,value(T,fd2), x)
91-
hess = extract_jacobian(T,partials(T,fd2), x)
92-
result = DiffResults.hessian!(result, hess)
93-
result = DiffResults.gradient!(result, grad)
94-
result = DiffResults.value!(result, val)
95-
return result
96-
end

src/jacobian.jl

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -86,35 +86,12 @@ function jacobian!(result::Union{AbstractArray,DiffResult}, f!::F, y::AbstractAr
8686
return result
8787
end
8888

89-
@inline jacobian(f::F, x::StaticArray) where F = vector_mode_jacobian(f, x)
90-
@inline jacobian(f::F, x::StaticArray, cfg::JacobianConfig) where F = jacobian(f, x)
91-
@inline jacobian(f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where F = jacobian(f, x)
92-
93-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where F = vector_mode_jacobian!(result, f, x)
94-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where F = jacobian!(result, f, x)
95-
@inline jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where F = jacobian!(result, f, x)
96-
9789
jacobian(f, x::Real) = throw(DimensionMismatch("jacobian(f, x) expects that x is an array. Perhaps you meant derivative(f, x)?"))
9890

9991
#####################
10092
# result extraction #
10193
#####################
10294

103-
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
104-
M, N = length(ydual), length(x)
105-
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
106-
return quote
107-
$(Expr(:meta, :inline))
108-
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
109-
return V($result)
110-
end
111-
end
112-
113-
function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where T
114-
result = similar(ydual, valtype(eltype(ydual)), length(ydual), length(x))
115-
return extract_jacobian!(T, result, ydual, length(x))
116-
end
117-
11895
function extract_jacobian!(::Type{T}, result::AbstractArray, ydual::AbstractArray, n) where {T}
11996
out_reshaped = reshape(result, length(ydual), n)
12097
ydual_reshaped = vec(ydual)
@@ -184,27 +161,6 @@ function vector_mode_jacobian!(result, f!::F, y, x, cfg::JacobianConfig{T}) wher
184161
return result
185162
end
186163

187-
@inline function vector_mode_jacobian(f, x::StaticArray)
188-
T = typeof(Tag(f, eltype(x)))
189-
return extract_jacobian(T, static_dual_eval(T, f, x), x)
190-
end
191-
192-
@inline function vector_mode_jacobian!(result, f, x::StaticArray)
193-
T = typeof(Tag(f, eltype(x)))
194-
ydual = static_dual_eval(T, f, x)
195-
result = extract_jacobian!(T, result, ydual, length(x))
196-
result = extract_value!(T, result, ydual)
197-
return result
198-
end
199-
200-
@inline function vector_mode_jacobian!(result::ImmutableDiffResult, f, x::StaticArray)
201-
T = typeof(Tag(f, eltype(x)))
202-
ydual = static_dual_eval(T, f, x)
203-
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
204-
result = DiffResults.value!(d -> value(T,d), result, ydual)
205-
return result
206-
end
207-
208164
const JACOBIAN_ERROR = DimensionMismatch("jacobian(f, x) expects that f(x) is an array. Perhaps you meant gradient(f, x)?")
209165

210166
# chunk mode #

0 commit comments

Comments
 (0)