Skip to content

Commit 37c1d50

Browse files
authored
Specialize on functions in StaticArrays extension (#721)
1 parent 3acc8a6 commit 37c1d50

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using DiffResults: DiffResult, ImmutableDiffResult, MutableDiffResult
2121
end
2222
end
2323

24-
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
24+
@inline static_dual_eval(::Type{T}, f::F, x::StaticArray) where {T,F} = f(dualize(T, x))
2525

2626
# To fix method ambiguity issues:
2727
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
@@ -35,13 +35,13 @@ end
3535
ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDiff._lyap_div!(A, λ)
3636

3737
# Gradient
38-
@inline ForwardDiff.gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
39-
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
40-
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
38+
@inline ForwardDiff.gradient(f::F, x::StaticArray) where {F} = vector_mode_gradient(f, x)
39+
@inline ForwardDiff.gradient(f::F, x::StaticArray, cfg::GradientConfig) where {F} = gradient(f, x)
40+
@inline ForwardDiff.gradient(f::F, x::StaticArray, cfg::GradientConfig, ::Val) where {F} = gradient(f, x)
4141

42-
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x)
43-
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x)
44-
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x)
42+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where {F} = vector_mode_gradient!(result, f, x)
43+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig) where {F} = gradient!(result, f, x)
44+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig, ::Val) where {F} = gradient!(result, f, x)
4545

4646
@generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray}
4747
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
@@ -52,24 +52,24 @@ ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDi
5252
end
5353
end
5454

55-
@inline function ForwardDiff.vector_mode_gradient(f, x::StaticArray)
55+
@inline function ForwardDiff.vector_mode_gradient(f::F, x::StaticArray) where {F}
5656
T = typeof(Tag(f, eltype(x)))
5757
return extract_gradient(T, static_dual_eval(T, f, x), x)
5858
end
5959

60-
@inline function ForwardDiff.vector_mode_gradient!(result, f, x::StaticArray)
60+
@inline function ForwardDiff.vector_mode_gradient!(result, f::F, x::StaticArray) where {F}
6161
T = typeof(Tag(f, eltype(x)))
6262
return extract_gradient!(T, result, static_dual_eval(T, f, x))
6363
end
6464

6565
# Jacobian
66-
@inline ForwardDiff.jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
67-
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
68-
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
66+
@inline ForwardDiff.jacobian(f::F, x::StaticArray) where {F} = vector_mode_jacobian(f, x)
67+
@inline ForwardDiff.jacobian(f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian(f, x)
68+
@inline ForwardDiff.jacobian(f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian(f, x)
6969

70-
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x)
71-
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x)
72-
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x)
70+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where {F} = vector_mode_jacobian!(result, f, x)
71+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian!(result, f, x)
72+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian!(result, f, x)
7373

7474
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
7575
M, N = length(ydual), length(x)
@@ -81,7 +81,7 @@ end
8181
end
8282
end
8383

84-
@inline function ForwardDiff.vector_mode_jacobian(f, x::StaticArray)
84+
@inline function ForwardDiff.vector_mode_jacobian(f::F, x::StaticArray) where {F}
8585
T = typeof(Tag(f, eltype(x)))
8686
return extract_jacobian(T, static_dual_eval(T, f, x), x)
8787
end
@@ -91,15 +91,15 @@ function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where
9191
return extract_jacobian!(T, result, ydual, length(x))
9292
end
9393

94-
@inline function ForwardDiff.vector_mode_jacobian!(result, f, x::StaticArray)
94+
@inline function ForwardDiff.vector_mode_jacobian!(result, f::F, x::StaticArray) where {F}
9595
T = typeof(Tag(f, eltype(x)))
9696
ydual = static_dual_eval(T, f, x)
9797
result = extract_jacobian!(T, result, ydual, length(x))
9898
result = extract_value!(T, result, ydual)
9999
return result
100100
end
101101

102-
@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f, x::StaticArray)
102+
@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F}
103103
T = typeof(Tag(f, eltype(x)))
104104
ydual = static_dual_eval(T, f, x)
105105
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
@@ -108,18 +108,18 @@ end
108108
end
109109

110110
# Hessian
111-
ForwardDiff.hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
112-
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
113-
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
111+
ForwardDiff.hessian(f::F, x::StaticArray) where {F} = jacobian(y -> gradient(f, y), x)
112+
ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig) where {F} = hessian(f, x)
113+
ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian(f, x)
114114

115-
ForwardDiff.hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x)
115+
ForwardDiff.hessian!(result::AbstractArray, f::F, x::StaticArray) where {F} = jacobian!(result, y -> gradient(f, y), x)
116116

117-
ForwardDiff.hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x))
117+
ForwardDiff.hessian!(result::MutableDiffResult, f::F, x::StaticArray) where {F} = hessian!(result, f, x, HessianConfig(f, result, x))
118118

119-
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x)
120-
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x)
119+
ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig) where {F} = hessian!(result, f, x)
120+
ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian!(result, f, x)
121121

122-
function ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray)
122+
function ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F}
123123
T = typeof(Tag(f, eltype(x)))
124124
d1 = dualize(T, x)
125125
d2 = dualize(T, d1)

0 commit comments

Comments
 (0)