Skip to content

Commit 62d557b

Browse files
authored
Fix DiffRules-based definitions for complex-valued functions (#577)
* Fix DiffRules-based definitions for complex-valued functions * Update tests * Update Project.toml
1 parent e11936c commit 62d557b

File tree

3 files changed

+75
-15
lines changed

3 files changed

+75
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ForwardDiff"
22
uuid = "f6369f11-7733-5829-9624-2563aa707210"
3-
version = "0.10.26"
3+
version = "0.10.27"
44

55
[deps]
66
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"

src/dual.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,38 @@ macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_b
195195
return esc(defs)
196196
end
197197

198+
# Support complex-valued functions such as `hankelh1`
199+
function dual_definition_retval(::Val{T}, val::Real, deriv::Real, partial::Partials) where {T}
200+
return Dual{T}(val, deriv * partial)
201+
end
202+
function dual_definition_retval(::Val{T}, val::Real, deriv1::Real, partial1::Partials, deriv2::Real, partial2::Partials) where {T}
203+
return Dual{T}(val, _mul_partials(partial1, partial2, deriv1, deriv2))
204+
end
205+
function dual_definition_retval(::Val{T}, val::Complex, deriv::Union{Real,Complex}, partial::Partials) where {T}
206+
reval, imval = reim(val)
207+
if deriv isa Real
208+
p = deriv * partial
209+
return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p)))
210+
else
211+
rederiv, imderiv = reim(deriv)
212+
return Complex(Dual{T}(reval, rederiv * partial), Dual{T}(imval, imderiv * partial))
213+
end
214+
end
215+
function dual_definition_retval(::Val{T}, val::Complex, deriv1::Union{Real,Complex}, partial1::Partials, deriv2::Union{Real,Complex}, partial2::Partials) where {T}
216+
reval, imval = reim(val)
217+
if deriv1 isa Real && deriv2 isa Real
218+
p = _mul_partials(partial1, partial2, deriv1, deriv2)
219+
return Complex(Dual{T}(reval, p), Dual{T}(imval, zero(p)))
220+
else
221+
rederiv1, imderiv1 = reim(deriv1)
222+
rederiv2, imderiv2 = reim(deriv2)
223+
return Complex(
224+
Dual{T}(reval, _mul_partials(partial1, partial2, rederiv1, rederiv2)),
225+
Dual{T}(imval, _mul_partials(partial1, partial2, imderiv1, imderiv2)),
226+
)
227+
end
228+
end
229+
198230
function unary_dual_definition(M, f)
199231
FD = ForwardDiff
200232
Mf = M == :Base ? f : :($M.$f)
@@ -206,7 +238,7 @@ function unary_dual_definition(M, f)
206238
@inline function $M.$f(d::$FD.Dual{T}) where T
207239
x = $FD.value(d)
208240
$work
209-
return $FD.Dual{T}(val, deriv * $FD.partials(d))
241+
return $FD.dual_definition_retval(Val{T}(), val, deriv, $FD.partials(d))
210242
end
211243
end
212244
end
@@ -236,17 +268,17 @@ function binary_dual_definition(M, f)
236268
begin
237269
vx, vy = $FD.value(x), $FD.value(y)
238270
$xy_work
239-
return $FD.Dual{Txy}(val, $FD._mul_partials($FD.partials(x), $FD.partials(y), dvx, dvy))
271+
return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y))
240272
end,
241273
begin
242274
vx = $FD.value(x)
243275
$x_work
244-
return $FD.Dual{Tx}(val, dvx * $FD.partials(x))
276+
return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x))
245277
end,
246278
begin
247279
vy = $FD.value(y)
248280
$y_work
249-
return $FD.Dual{Ty}(val, dvy * $FD.partials(y))
281+
return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y))
250282
end
251283
)
252284
end

test/DualTest.jl

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
440440

441441
if V != Int
442442
for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
443-
if f in (:hankelh1, :hankelh1x, :hankelh2, :hankelh2x, :/, :rem2pi)
443+
if f in (:/, :rem2pi)
444444
continue # Skip these rules
445445
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
446446
continue # Skip rules for methods not defined in the current scope
@@ -457,9 +457,20 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
457457
end
458458
@eval begin
459459
x = rand() + $modifier
460-
dx = $M.$f(Dual{TestTag()}(x, one(x)))
461-
@test value(dx) == $M.$f(x)
462-
@test partials(dx, 1) == $deriv
460+
dx = @inferred $M.$f(Dual{TestTag()}(x, one(x)))
461+
actualval = $M.$f(x)
462+
@assert actualval isa Real || actualval isa Complex
463+
if actualval isa Real
464+
@test dx isa Dual{TestTag()}
465+
@test value(dx) == actualval
466+
@test partials(dx, 1) == $deriv
467+
else
468+
@test dx isa Complex{<:Dual{TestTag()}}
469+
@test value(real(dx)) == real(actualval)
470+
@test value(imag(dx)) == imag(actualval)
471+
@test partials(real(dx), 1) == real($deriv)
472+
@test partials(imag(dx), 1) == imag($deriv)
473+
end
463474
end
464475
elseif arity == 2
465476
derivs = DiffRules.diffrule(M, f, :x, :y)
@@ -472,14 +483,31 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
472483
end
473484
@eval begin
474485
x, y = $x, $y
475-
dx = $M.$f(Dual{TestTag()}(x, one(x)), y)
476-
dy = $M.$f(x, Dual{TestTag()}(y, one(y)))
486+
dx = @inferred $M.$f(Dual{TestTag()}(x, one(x)), y)
487+
dy = @inferred $M.$f(x, Dual{TestTag()}(y, one(y)))
477488
actualdx = $(derivs[1])
478489
actualdy = $(derivs[2])
479-
@test value(dx) == $M.$f(x, y)
480-
@test value(dy) == value(dx)
481-
@test partials(dx, 1) actualdx nans=true
482-
@test partials(dy, 1) actualdy nans=true
490+
actualval = $M.$f(x, y)
491+
@assert actualval isa Real || actualval isa Complex
492+
if actualval isa Real
493+
@test dx isa Dual{TestTag()}
494+
@test dy isa Dual{TestTag()}
495+
@test value(dx) == actualval
496+
@test value(dy) == actualval
497+
@test partials(dx, 1) actualdx nans=true
498+
@test partials(dy, 1) actualdy nans=true
499+
else
500+
@test dx isa Complex{<:Dual{TestTag()}}
501+
@test dy isa Complex{<:Dual{TestTag()}}
502+
@test real(value(dx)) == real(actualval)
503+
@test real(value(dy)) == real(actualval)
504+
@test imag(value(dx)) == imag(actualval)
505+
@test imag(value(dy)) == imag(actualval)
506+
@test partials(real(dx), 1) real(actualdx) nans=true
507+
@test partials(real(dy), 1) real(actualdy) nans=true
508+
@test partials(imag(dx), 1) imag(actualdx) nans=true
509+
@test partials(imag(dy), 1) imag(actualdy) nans=true
510+
end
483511
end
484512
end
485513
end

0 commit comments

Comments
 (0)