Skip to content

Commit ed0bc33

Browse files
Support complex output in derivative (#583)
1 parent 62d557b commit ed0bc33

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/derivative.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ Set `check` to `Val{false}()` to disable tag checking. This can lead to perturba
7070
end
7171

7272
derivative(f, x::AbstractArray) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number. Perhaps you meant gradient(f, x)?"))
73+
derivative(f, x::Complex) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives). Separate real and imaginary parts of the input."))
7374

7475
#####################
7576
# result extraction #
@@ -78,9 +79,13 @@ derivative(f, x::AbstractArray) = throw(DimensionMismatch("derivative(f, x) expe
7879
# non-mutating #
7980
#--------------#
8081

81-
@inline extract_derivative(::Type{T}, y::Dual) where {T} = partials(T, y, 1)
8282
@inline extract_derivative(::Type{T}, y::Real) where {T} = zero(y)
83+
@inline extract_derivative(::Type{T}, y::Complex) where {T} = zero(y)
84+
@inline extract_derivative(::Type{T}, y::Dual) where {T} = partials(T, y, 1)
8385
@inline extract_derivative(::Type{T}, y::AbstractArray) where {T} = map(d -> extract_derivative(T,d), y)
86+
@inline function extract_derivative(::Type{T}, y::Complex{TD}) where {T, TD <: Dual}
87+
complex(partials(T, real(y), 1), partials(T, imag(y), 1))
88+
end
8489

8590
# mutating #
8691
#----------#

test/DerivativeTest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,8 @@ end
100100
@test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3))
101101
end
102102

103+
@testset "complex output" begin
104+
@test ForwardDiff.derivative(x -> (1+im)*x, 0) == (1+im)
105+
end
106+
103107
end # module

0 commit comments

Comments
 (0)