Skip to content

Commit aafb664

Browse files
committed
make rrule return identical pullback for zero as for one
Could be a minor compilation latency and/or type stability win for some uses.
1 parent 2c6621c commit aafb664

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/rulesets/Base/base.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@
44
@scalar_rule copysign(y, x) (ifelse(signbit(x)!=signbit(y), -one(y), +one(y)), NoTangent())
55
@scalar_rule transpose(x) true
66

7+
# TODO: define using `Returns((NoTangent(), ZeroTangent()))` when support for Julia v1.6 is dropped
8+
function _pullback_for_constant(::Any)
9+
(NoTangent(), ZeroTangent())
10+
end
11+
712
# `zero`
813

914
function frule((_, _), ::typeof(zero), x)
1015
return (zero(x), ZeroTangent())
1116
end
1217

1318
function rrule(::typeof(zero), x)
14-
zero_pullback(_) = (NoTangent(), ZeroTangent())
15-
return (zero(x), zero_pullback)
19+
return (zero(x), _pullback_for_constant)
1620
end
1721

1822
# `one`
@@ -22,8 +26,7 @@ function frule((_, _), ::typeof(one), x)
2226
end
2327

2428
function rrule(::typeof(one), x)
25-
one_pullback(_) = (NoTangent(), ZeroTangent())
26-
return (one(x), one_pullback)
29+
return (one(x), _pullback_for_constant)
2730
end
2831

2932

test/rulesets/Base/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ end
44

55
@testset "base.jl" begin
66
@testset "zero/one" begin
7+
@test last(rrule(zero, 0.1)) === last(rrule(one, 0.2f0))
78
for f in [zero, one]
89
for x in [1.0, 1.0im, [10.0+im 11.0-im; 12.0+2im 13.0-3im]]
910
test_frule(f, x)

0 commit comments

Comments
 (0)