Skip to content

Commit 8eaba05

Browse files
jClugstor and devmotion
authored
Fix exponentiation for NaNMath.pow (#717)
* fix NaNMath exponentiation * reuse code * fix * add tests * Update src/dual.jl Co-authored-by: David Widmann <[email protected]> * import NaNMath * oops, no begin * Update test/GradientTest.jl Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: David Widmann <[email protected]>
1 parent 7e9d778 commit 8eaba05

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

src/dual.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ end
552552
# exponentiation #
553553
#----------------#
554554

555-
for f in (:(Base.:^), :(NaNMath.pow))
555+
for (f, log) in ((:(Base.:^), :(Base.log)), (:(NaNMath.pow), :(NaNMath.log)))
556556
@eval begin
557557
@define_binary_dual_op(
558558
$f,
@@ -565,7 +565,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
565565
elseif iszero(vx) && vy > 0
566566
logval = zero(vx)
567567
else
568-
logval = expv * log(vx)
568+
logval = expv * ($log)(vx)
569569
end
570570
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
571571
return Dual{Txy}(expv, new_partials)
@@ -583,7 +583,7 @@ for f in (:(Base.:^), :(NaNMath.pow))
583583
begin
584584
v = value(y)
585585
expv = ($f)(x, v)
586-
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
586+
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*($log)(x)
587587
return Dual{Ty}(expv, deriv * partials(y))
588588
end
589589
)

test/DerivativeTest.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DerivativeTest
22

33
import Calculus
4+
import NaNMath
45

56
using Test
67
using Random
@@ -93,6 +94,17 @@ end
9394
@test (x -> ForwardDiff.derivative(y -> x^y, 1.5))(0.0) === 0.0
9495
end
9596

97+
@testset "exponentiation with NaNMath" begin
98+
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(NaN, x), 1.0))
99+
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,NaN), 1.0))
100+
@test !isnan(ForwardDiff.derivative(x -> NaNMath.pow(1.0, x),1.0))
101+
@test isnan(ForwardDiff.derivative(x -> NaNMath.pow(x,0.5), -1.0))
102+
103+
@test isnan(ForwardDiff.derivative(x -> x^NaN, 2.0))
104+
@test ForwardDiff.derivative(x -> x^2.0,2.0) == 4.0
105+
@test_throws DomainError ForwardDiff.derivative(x -> x^0.5, -1.0)
106+
end
107+
96108
@testset "dimension error for derivative" begin
97109
@test_throws DimensionMismatch ForwardDiff.derivative(sum, fill(2pi, 3))
98110
end

test/GradientTest.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module GradientTest
22

33
import Calculus
4+
import NaNMath
45

56
using Test
67
using LinearAlgebra
@@ -200,6 +201,16 @@ end
200201
@test ForwardDiff.gradient(L -> logdet(L), Matrix(L)) ≈ [1.0 -1.3333333333333337; 0.0 1.666666666666667]
201202
end
202203

204+
@testset "gradient for exponential with NaNMath" begin
205+
@test isnan(ForwardDiff.gradient(x -> NaNMath.pow(x[1],x[1]), [NaN, 1.0])[1])
206+
@test ForwardDiff.gradient(x -> NaNMath.pow(x[1], x[2]), [1.0, 1.0]) == [1.0, 0.0]
207+
@test isnan(ForwardDiff.gradient((x) -> NaNMath.pow(x[1], x[2]), [-1.0, 0.5])[1])
208+
209+
@test isnan(ForwardDiff.gradient(x -> x[1]^x[2], [NaN, 1.0])[1])
210+
@test ForwardDiff.gradient(x -> x[1]^x[2], [1.0, 1.0]) == [1.0, 0.0]
211+
@test_throws DomainError ForwardDiff.gradient(x -> x[1]^x[2], [-1.0, 0.5])
212+
end
213+
203214
@testset "branches in mul!" begin
204215
a, b = rand(3,3), rand(3,3)
205216

0 commit comments

Comments (0)