Skip to content

Commit 6a6443b

Browse files
authored
Change == to ignore measure-zero branches (#481)
1 parent 61e4dd4 commit 6a6443b

File tree

9 files changed

+196
-64
lines changed

9 files changed

+196
-64
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ForwardDiff"
22
uuid = "f6369f11-7733-5829-9624-2563aa707210"
3-
version = "0.10.32"
3+
version = "0.10.33"
44

55
[deps]
66
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"

src/dual.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,17 +384,40 @@ for pred in UNARY_PREDICATES
384384
@eval Base.$(pred)(d::Dual) = $(pred)(value(d))
385385
end
386386

387-
for pred in BINARY_PREDICATES
387+
# Before PR#481 this loop ran over this list:
388+
# BINARY_PREDICATES = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
389+
# Not a minimal set, as Base defines some in terms of others.
390+
for pred in [:isless, :<, :>, :(<=), :(>=)]
388391
@eval begin
389392
@define_binary_dual_op(
390393
Base.$(pred),
391394
$(pred)(value(x), value(y)),
392395
$(pred)(value(x), y),
393-
$(pred)(x, value(y))
396+
$(pred)(x, value(y)),
394397
)
395398
end
396399
end
397400

401+
Base.iszero(x::Dual) = iszero(value(x)) && iszero(partials(x)) # shortcut, equivalent to x == zero(x)
402+
403+
for pred in [:isequal, :(==)]
404+
@eval begin
405+
@define_binary_dual_op(
406+
Base.$(pred),
407+
$(pred)(value(x), value(y)) && $(pred)(partials(x), partials(y)),
408+
$(pred)(value(x), y) && iszero(partials(x)),
409+
$(pred)(x, value(y)) && iszero(partials(y)),
410+
)
411+
end
412+
end
413+
414+
@define_binary_dual_op(
415+
Base.:(!=),
416+
(!=)(value(x), value(y)) || (!=)(partials(x), partials(y)),
417+
(!=)(value(x), y) || !iszero(partials(x)),
418+
(!=)(x, value(y)) || !iszero(partials(y)),
419+
)
420+
398421
########################
399422
# Promotion/Conversion #
400423
########################

src/prelude.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, Rou
1010

1111
const UNARY_PREDICATES = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
1212

13-
const BINARY_PREDICATES = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
13+
const DEFAULT_CHUNK_THRESHOLD = 12
1414

1515
struct Chunk{N} end
1616

test/DualTest.jl

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ForwardDiff.:≺(::Int, ::Type{TestTag()}) = false
3030
ForwardDiff.:(::Type{TestTag}, ::Type{OuterTestTag}) = true
3131
ForwardDiff.:(::Type{OuterTestTag}, ::Type{TestTag}) = false
3232

33-
for N in (0,3), M in (0,4), V in (Int, Float32)
33+
@testset "Dual{Z,$V,$N} and Dual{Z,Dual{Z,$V,$M},$N}" for N in (0,3), M in (0,4), V in (Int, Float32)
3434
println(" ...testing Dual{TestTag(),$V,$N} and Dual{TestTag(),Dual{TestTag(),$V,$M},$N}")
3535

3636
PARTIALS = Partials{N,V}(ntuple(n -> intrand(V), N))
@@ -44,6 +44,13 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
4444
PARTIALS3 = Partials{N,V}(ntuple(n -> intrand(V), N))
4545
PRIMAL3 = intrand(V)
4646
FDNUM3 = Dual{TestTag()}(PRIMAL3, PARTIALS3)
47+
48+
if !allunique([PRIMAL, PRIMAL2, PRIMAL3])
49+
@info "testing with non-unique primals" PRIMAL PRIMAL2 PRIMAL3
50+
end
51+
if N > 0 && !allunique([PARTIALS, PARTIALS2, PARTIALS3])
52+
@info "testing with non-unique partials" PARTIALS PARTIALS2 PARTIALS3
53+
end
4754

4855
M_PARTIALS = Partials{M,V}(ntuple(m -> intrand(V), M))
4956
NESTED_PARTIALS = convert(Partials{N,Dual{TestTag(),V,M}}, PARTIALS)
@@ -231,15 +238,27 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
231238
@test ForwardDiff.isconstant(one(NESTED_FDNUM))
232239
@test ForwardDiff.isconstant(NESTED_FDNUM) == (N == 0)
233240

234-
@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2))
235-
@test isequal(PRIMAL, PRIMAL2) == isequal(FDNUM, FDNUM2)
236-
237-
@test isequal(NESTED_FDNUM, Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS2), NESTED_PARTIALS2))
238-
@test isequal(PRIMAL, PRIMAL2) == isequal(NESTED_FDNUM, NESTED_FDNUM2)
239-
240-
@test FDNUM == Dual{TestTag()}(PRIMAL, PARTIALS2)
241-
@test (PRIMAL == PRIMAL2) == (FDNUM == FDNUM2)
242-
@test (PRIMAL == PRIMAL2) == (NESTED_FDNUM == NESTED_FDNUM2)
241+
# Recall that FDNUM = Dual{TestTag()}(PRIMAL, PARTIALS) has N partials,
242+
# and FDNUM2 has everything with a 2, and all random numbers nonzero.
243+
# M is the length of M_PARTIALS, which affects:
244+
# NESTED_FDNUM = Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS), NESTED_PARTIALS)
245+
246+
@test (FDNUM == Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
247+
@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
248+
@test isequal(NESTED_FDNUM, Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS2), NESTED_PARTIALS2)) == ((M_PARTIALS == M_PARTIALS2) && (NESTED_PARTIALS == NESTED_PARTIALS2))
249+
250+
if PRIMAL == PRIMAL2
251+
@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
252+
@test isequal(FDNUM, FDNUM2) == (PARTIALS == PARTIALS2)
253+
254+
@test (FDNUM == FDNUM2) == (PARTIALS == PARTIALS2)
255+
@test (NESTED_FDNUM == NESTED_FDNUM2) == ((M_PARTIALS == M_PARTIALS2) && (NESTED_PARTIALS == NESTED_PARTIALS2))
256+
else
257+
@test !isequal(FDNUM, FDNUM2)
258+
259+
@test FDNUM != FDNUM2
260+
@test NESTED_FDNUM != NESTED_FDNUM2
261+
end
243262

244263
@test isless(Dual{TestTag()}(1, PARTIALS), Dual{TestTag()}(2, PARTIALS2))
245264
@test !(isless(Dual{TestTag()}(1, PARTIALS), Dual{TestTag()}(1, PARTIALS2)))
@@ -344,7 +363,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
344363
@test typeof(WIDE_NESTED_FDNUM) === Dual{TestTag(),Dual{TestTag(),WIDE_T,M},N}
345364

346365
@test value(WIDE_FDNUM) == PRIMAL
347-
@test value(WIDE_NESTED_FDNUM) == PRIMAL
366+
@test (value(WIDE_NESTED_FDNUM) == PRIMAL) == (M == 0)
348367

349368
@test convert(Dual, FDNUM) === FDNUM
350369
@test convert(Dual, NESTED_FDNUM) === NESTED_FDNUM
@@ -395,6 +414,8 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
395414
#----------#
396415

397416
if M > 0 && N > 0
417+
# Recall that FDNUM = Dual{TestTag()}(PRIMAL, PARTIALS) has N partials,
418+
# all random numbers nonzero, and FDNUM2 another draw. M only affects NESTED_FDNUM.
398419
@test Dual{1}(FDNUM) / Dual{1}(PRIMAL) === Dual{1}(FDNUM / PRIMAL)
399420
@test Dual{1}(PRIMAL) / Dual{1}(FDNUM) === Dual{1}(PRIMAL / FDNUM)
400421
@test_broken Dual{1}(FDNUM) / FDNUM2 === Dual{1}(FDNUM / FDNUM2)
@@ -413,6 +434,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
413434

414435
# Exponentiation #
415436
#----------------#
437+
416438
# If V == Int, the LHS terms are Int's. Large inputs cause integer overflow
417439
# within the generic fallback of `isapprox`, resulting in a DomainError.
418440
# Promote to Float64 to avoid issues.
@@ -442,7 +464,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
442464
@test abs(NESTED_FDNUM) === NESTED_FDNUM
443465

444466
if V != Int
445-
for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
467+
@testset "$f" for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
446468
if f in (:/, :rem2pi)
447469
continue # Skip these rules
448470
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
@@ -502,10 +524,14 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
502524
else
503525
@test dx isa Complex{<:Dual{TestTag()}}
504526
@test dy isa Complex{<:Dual{TestTag()}}
505-
@test real(value(dx)) == real(actualval)
506-
@test real(value(dy)) == real(actualval)
507-
@test imag(value(dx)) == imag(actualval)
508-
@test imag(value(dy)) == imag(actualval)
527+
# @test real(value(dx)) == real(actualval)
528+
# @test real(value(dy)) == real(actualval)
529+
# @test imag(value(dx)) == imag(actualval)
530+
# @test imag(value(dy)) == imag(actualval)
531+
@test value(real(dx)) == real(actualval)
532+
@test value(real(dy)) == real(actualval)
533+
@test value(imag(dx)) == imag(actualval)
534+
@test value(imag(dy)) == imag(actualval)
509535
@test partials(real(dx), 1) real(actualdx) nans=true
510536
@test partials(real(dy), 1) real(actualdy) nans=true
511537
@test partials(imag(dx), 1) imag(actualdx) nans=true
@@ -568,6 +594,10 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
568594
end
569595
end
570596

597+
#############
598+
# bug fixes #
599+
#############
600+
571601
@testset "Exponentiation of zero" begin
572602
x0 = 0.0
573603
x1 = Dual{:t1}(x0, 1.0)

test/GradientTest.jl

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module GradientTest
33
import Calculus
44

55
using Test
6+
using LinearAlgebra
67
using ForwardDiff
78
using ForwardDiff: Dual, Tag
89
using StaticArrays
@@ -19,7 +20,7 @@ x = [0.1, 0.2, 0.3]
1920
v = f(x)
2021
g = [-9.4, 15.6, 52.0]
2122

22-
for c in (1, 2, 3), tag in (nothing, Tag(f, eltype(x)))
23+
@testset "Rosenbrock, chunk size = $c and tag = $(repr(tag))" for c in (1, 2, 3), tag in (nothing, Tag(f, eltype(x)))
2324
println(" ...running hardcoded test with chunk size = $c and tag = $(repr(tag))")
2425
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{c}(), tag)
2526

@@ -55,7 +56,7 @@ cfgx = ForwardDiff.GradientConfig(sin, x)
5556
# test vs. Calculus.jl #
5657
########################
5758

58-
for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
59+
@testset "$f" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
5960
v = f(X)
6061
g = ForwardDiff.gradient(f, X)
6162
@test isapprox(g, Calculus.gradient(f, X), atol=FINITEDIFF_ERROR)
@@ -83,9 +84,9 @@ end
8384

8485
println(" ...testing specialized StaticArray codepaths")
8586

86-
x = rand(3, 3)
87+
@testset "$T" for T in (StaticArrays.SArray, StaticArrays.MArray)
88+
x = rand(3, 3)
8789

88-
for T in (StaticArrays.SArray, StaticArrays.MArray)
8990
sx = T{Tuple{3,3}}(x)
9091

9192
cfg = ForwardDiff.GradientConfig(nothing, x)
@@ -148,6 +149,10 @@ end
148149
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0])
149150
end
150151

152+
#############
153+
# bug fixes #
154+
#############
155+
151156
# Issue 399
152157
@testset "chunk size zero" begin
153158
f_const(x) = 1.0
@@ -162,11 +167,55 @@ end
162167
@test_throws DimensionMismatch ForwardDiff.gradient(identity, fill(2pi, 10^6)) # chunk_mode_gradient
163168
end
164169

170+
# Issue 548
165171
@testset "ArithmeticStyle" begin
166172
function f(p)
167173
sum(collect(0.0:p[1]:p[2]))
168174
end
169175
@test ForwardDiff.gradient(f, [0.2,25.0]) == [7875.0, 0.0]
170176
end
171177

178+
@testset "det with branches" begin
179+
# Issue 197
180+
det2(A) = return (
181+
A[1,1]*(A[2,2]*A[3,3]-A[2,3]*A[3,2]) -
182+
A[1,2]*(A[2,1]*A[3,3]-A[2,3]*A[3,1]) +
183+
A[1,3]*(A[2,1]*A[3,2]-A[2,2]*A[3,1])
184+
)
185+
186+
A = [1 0 0; 0 2 0; 0 pi 3]
187+
@test det2(A) == det(A) == 6
188+
@test istril(A)
189+
190+
∇A = [6 0 0; 0 3 -pi; 0 0 2]
191+
@test ForwardDiff.gradient(det2, A) ∇A
192+
@test ForwardDiff.gradient(det, A) ∇A
193+
194+
# And issue 407
195+
@test ForwardDiff.hessian(det, A) ForwardDiff.hessian(det2, A)
196+
197+
# https://discourse.julialang.org/t/forwarddiff-and-zygote-return-wrong-jacobian-for-log-det-l/77961
198+
S = [1.0 0.8; 0.8 1.0]
199+
L = cholesky(S).L
200+
@test ForwardDiff.gradient(L -> log(det(L)), Matrix(L)) [1.0 -1.3333333333333337; 0.0 1.666666666666667]
201+
@test ForwardDiff.gradient(L -> logdet(L), Matrix(L)) [1.0 -1.3333333333333337; 0.0 1.666666666666667]
202+
end
203+
204+
@testset "branches in mul!" begin
205+
a, b = rand(3,3), rand(3,3)
206+
207+
# Issue 536, version with 3-arg *, Julia 1.7:
208+
@test ForwardDiff.derivative(x -> sum(x*a*b), 0.0) sum(a * b)
209+
210+
if VERSION >= v"1.3"
211+
# version with just mul!
212+
dx = ForwardDiff.derivative(0.0) do x
213+
c = similar(a, typeof(x))
214+
mul!(c, a, b, x, false)
215+
sum(c)
216+
end
217+
@test dx sum(a * b)
218+
end
219+
end
220+
172221
end # module

test/HessianTest.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module HessianTest
33
import Calculus
44

55
using Test
6+
using LinearAlgebra
67
using ForwardDiff
78
using ForwardDiff: Dual, Tag
89
using StaticArrays
@@ -157,4 +158,11 @@ for T in (StaticArrays.SArray, StaticArrays.MArray)
157158
@test DiffResults.hessian(sresult3) == DiffResults.hessian(result)
158159
end
159160

161+
@testset "branches in dot" begin
162+
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/551
163+
H = [1 2 3; 4 5 6; 7 8 9];
164+
@test ForwardDiff.hessian(x->dot(x,H,x), fill(0.00001, 3)) [2 6 10; 6 10 14; 10 14 18]
165+
@test ForwardDiff.hessian(x->dot(x,H,x), zeros(3)) [2 6 10; 6 10 14; 10 14 18]
166+
end
167+
160168
end # module

test/JacobianTest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ for T in (StaticArrays.SArray, StaticArrays.MArray)
226226
@test DiffResults.jacobian(sresult3) == DiffResults.jacobian(result)
227227
end
228228

229+
#########
230+
# misc. #
231+
#########
232+
229233
@testset "dimension errors for jacobian" begin
230234
@test_throws DimensionMismatch ForwardDiff.jacobian(identity, 2pi) # input
231235
@test_throws DimensionMismatch ForwardDiff.jacobian(sum, fill(2pi, 2)) # vector_mode_jacobian

test/PartialsTest.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using ForwardDiff: Partials
77

88
samerng() = MersenneTwister(1)
99

10-
for N in (0, 3), T in (Int, Float32, Float64)
10+
@testset "Partials{$N,$T}" for N in (0, 3), T in (Int, Float32, Float64)
1111
println(" ...testing Partials{$N,$T}")
1212

1313
VALUES = (rand(T,N)...,)

0 commit comments

Comments
 (0)