Skip to content

Commit 0e8e475

Browse files
authored
Add rrule for multiarg array addition (#462)
1 parent df00ef8 commit 0e8e475

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.8.20"
3+
version = "0.8.21"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/arraymath.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,17 @@ function rrule(::typeof(-), x::AbstractArray)
268268
end
269269
return -x, negation_pullback
270270
end
271+
272+
273+
#####
274+
##### Addition (Multiarg `+`)
275+
#####
276+
277+
function rrule(::typeof(+), arrs::AbstractArray...)
278+
y = +(arrs...)
279+
arr_axs = map(axes, arrs)
280+
function add_pullback(dy)
281+
return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...)
282+
end
283+
return y, add_pullback
284+
end

test/rulesets/Base/arraymath.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,9 @@
148148
test_rrule(-, A)
149149
test_rrule(-, Diagonal(A); output_tangent=Diagonal(Ā))
150150
end
151+
152+
@testset "addition" begin
153+
test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4))
154+
test_rrule(+, randn(3), randn(3,1), randn(3,1,1))
155+
end
151156
end

0 commit comments

Comments
 (0)