Skip to content

Commit 9dd39bd

Browse files
authored
Merge pull request #797 from nsajko/zero_one_identically_same_pullback
make `rrule` return identical pullback for `zero` as for `one`
2 parents 2c6621c + 363aa6c commit 9dd39bd

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.66.0"
3+
version = "1.67.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function frule((_, _), ::typeof(zero), x)
1111
end
1212

1313
function rrule(::typeof(zero), x)
14-
zero_pullback(_) = (NoTangent(), ZeroTangent())
14+
zero_pullback = Returns((NoTangent(), ZeroTangent()))
1515
return (zero(x), zero_pullback)
1616
end
1717

@@ -22,7 +22,7 @@ function frule((_, _), ::typeof(one), x)
2222
end
2323

2424
function rrule(::typeof(one), x)
25-
one_pullback(_) = (NoTangent(), ZeroTangent())
25+
one_pullback = Returns((NoTangent(), ZeroTangent()))
2626
return (one(x), one_pullback)
2727
end
2828

test/rulesets/Base/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ end
44

55
@testset "base.jl" begin
66
@testset "zero/one" begin
7+
@test last(rrule(zero, 0.1)) === last(rrule(one, 0.2f0))
78
for f in [zero, one]
89
for x in [1.0, 1.0im, [10.0+im 11.0-im; 12.0+2im 13.0-3im]]
910
test_frule(f, x)

0 commit comments

Comments
 (0)