Skip to content

Commit 61c2741

Browse files
authored
Methods with and without extras (#313)
* Methods with and without extras * Typos * Fix * SecondOrder for hvp * Polyester * Typo * Typo * Add correctness tests * Extras type * Typo * Other typo * Typo * Hessian doc
1 parent 0b7a2f9 commit 61c2741

File tree

11 files changed

+772
-765
lines changed

11 files changed

+772
-765
lines changed

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ using DifferentiationInterface:
1111
NoGradientExtras,
1212
NoHessianExtras,
1313
NoJacobianExtras,
14-
PushforwardExtras
14+
PushforwardExtras,
15+
PushforwardDerivativeExtras
1516
using DocStringExtensions
1617
using LinearAlgebra: mul!
1718
using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian!

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,32 @@ end
3131

3232
## Derivative
3333

34-
function DI.prepare_derivative(f, backend::AutoPolyesterForwardDiff, x)
34+
function DI.prepare_derivative(
35+
f, backend::AutoPolyesterForwardDiff, x
36+
)::PushforwardDerivativeExtras
3537
return DI.prepare_derivative(f, single_threaded(backend), x)
3638
end
3739

3840
function DI.value_and_derivative(
39-
f, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras
41+
f, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras
4042
)
4143
return DI.value_and_derivative(f, single_threaded(backend), x, extras)
4244
end
4345

4446
function DI.value_and_derivative!(
45-
f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras
47+
f, der, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras
4648
)
4749
return DI.value_and_derivative!(f, der, single_threaded(backend), x, extras)
4850
end
4951

50-
function DI.derivative(f, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras)
52+
function DI.derivative(
53+
f, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras
54+
)
5155
return DI.derivative(f, single_threaded(backend), x, extras)
5256
end
5357

5458
function DI.derivative!(
55-
f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras
59+
f, der, backend::AutoPolyesterForwardDiff, x, extras::PushforwardDerivativeExtras
5660
)
5761
return DI.derivative!(f, der, single_threaded(backend), x, extras)
5862
end

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -72,84 +72,86 @@ end
7272

7373
## One argument
7474

75+
function value_and_derivative(f::F, backend::AbstractADType, x) where {F}
76+
return value_and_derivative(f, backend, x, prepare_derivative(f, backend, x))
77+
end
78+
79+
function value_and_derivative!(f::F, der, backend::AbstractADType, x) where {F}
80+
return value_and_derivative!(f, der, backend, x, prepare_derivative(f, backend, x))
81+
end
82+
83+
function derivative(f::F, backend::AbstractADType, x) where {F}
84+
return derivative(f, backend, x, prepare_derivative(f, backend, x))
85+
end
86+
87+
function derivative!(f::F, der, backend::AbstractADType, x) where {F}
88+
return derivative!(f, der, backend, x, prepare_derivative(f, backend, x))
89+
end
90+
7591
function value_and_derivative(
76-
f::F,
77-
backend::AbstractADType,
78-
x,
79-
extras::DerivativeExtras=prepare_derivative(f, backend, x),
92+
f::F, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
8093
) where {F}
8194
return value_and_pushforward(f, backend, x, one(x), extras.pushforward_extras)
8295
end
8396

8497
function value_and_derivative!(
85-
f::F,
86-
der,
87-
backend::AbstractADType,
88-
x,
89-
extras::DerivativeExtras=prepare_derivative(f, backend, x),
98+
f::F, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
9099
) where {F}
91100
return value_and_pushforward!(f, der, backend, x, one(x), extras.pushforward_extras)
92101
end
93102

94103
function derivative(
95-
f::F,
96-
backend::AbstractADType,
97-
x,
98-
extras::DerivativeExtras=prepare_derivative(f, backend, x),
104+
f::F, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
99105
) where {F}
100106
return pushforward(f, backend, x, one(x), extras.pushforward_extras)
101107
end
102108

103109
function derivative!(
104-
f::F,
105-
der,
106-
backend::AbstractADType,
107-
x,
108-
extras::DerivativeExtras=prepare_derivative(f, backend, x),
110+
f::F, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
109111
) where {F}
110112
return pushforward!(f, der, backend, x, one(x), extras.pushforward_extras)
111113
end
112114

113115
## Two arguments
114116

117+
function value_and_derivative(f!::F, y, backend::AbstractADType, x) where {F}
118+
return value_and_derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x))
119+
end
120+
121+
function value_and_derivative!(f!::F, y, der, backend::AbstractADType, x) where {F}
122+
return value_and_derivative!(
123+
f!, y, der, backend, x, prepare_derivative(f!, y, backend, x)
124+
)
125+
end
126+
127+
function derivative(f!::F, y, backend::AbstractADType, x) where {F}
128+
return derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x))
129+
end
130+
131+
function derivative!(f!::F, y, der, backend::AbstractADType, x) where {F}
132+
return derivative!(f!, y, der, backend, x, prepare_derivative(f!, y, backend, x))
133+
end
134+
115135
function value_and_derivative(
116-
f!::F,
117-
y,
118-
backend::AbstractADType,
119-
x,
120-
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
136+
f!::F, y, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
121137
) where {F}
122138
return value_and_pushforward(f!, y, backend, x, one(x), extras.pushforward_extras)
123139
end
124140

125141
function value_and_derivative!(
126-
f!::F,
127-
y,
128-
der,
129-
backend::AbstractADType,
130-
x,
131-
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
142+
f!::F, y, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
132143
) where {F}
133144
return value_and_pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras)
134145
end
135146

136147
function derivative(
137-
f!::F,
138-
y,
139-
backend::AbstractADType,
140-
x,
141-
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
148+
f!::F, y, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
142149
) where {F}
143150
return pushforward(f!, y, backend, x, one(x), extras.pushforward_extras)
144151
end
145152

146153
function derivative!(
147-
f!::F,
148-
y,
149-
der,
150-
backend::AbstractADType,
151-
x,
152-
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
154+
f!::F, y, der, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
153155
) where {F}
154156
return pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras)
155157
end

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,34 +62,42 @@ end
6262

6363
## One argument
6464

65+
function value_and_gradient(f::F, backend::AbstractADType, x) where {F}
66+
return value_and_gradient(f, backend, x, prepare_gradient(f, backend, x))
67+
end
68+
69+
function value_and_gradient!(f::F, der, backend::AbstractADType, x) where {F}
70+
return value_and_gradient!(f, der, backend, x, prepare_gradient(f, backend, x))
71+
end
72+
73+
function gradient(f::F, backend::AbstractADType, x) where {F}
74+
return gradient(f, backend, x, prepare_gradient(f, backend, x))
75+
end
76+
77+
function gradient!(f::F, der, backend::AbstractADType, x) where {F}
78+
return gradient!(f, der, backend, x, prepare_gradient(f, backend, x))
79+
end
80+
6581
function value_and_gradient(
66-
f::F, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x)
82+
f::F, backend::AbstractADType, x, extras::PullbackGradientExtras
6783
) where {F}
6884
return value_and_pullback(f, backend, x, one(eltype(x)), extras.pullback_extras)
6985
end
7086

7187
function value_and_gradient!(
72-
f::F,
73-
grad,
74-
backend::AbstractADType,
75-
x,
76-
extras::GradientExtras=prepare_gradient(f, backend, x),
88+
f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras
7789
) where {F}
7890
return value_and_pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras)
7991
end
8092

8193
function gradient(
82-
f::F, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x)
94+
f::F, backend::AbstractADType, x, extras::PullbackGradientExtras
8395
) where {F}
8496
return pullback(f, backend, x, one(eltype(x)), extras.pullback_extras)
8597
end
8698

8799
function gradient!(
88-
f::F,
89-
grad,
90-
backend::AbstractADType,
91-
x,
92-
extras::GradientExtras=prepare_gradient(f, backend, x),
100+
f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras
93101
) where {F}
94102
return pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras)
95103
end

0 commit comments

Comments
 (0)