Skip to content

Commit 93eb113

Browse files
AlexRobsonAlex Robson
andauthored
Add repeat rules from Zygote (#460)
* Add repeat rules from Zygote. Add tests * Bump version * Restore original typing and defaults. Take size arguments out of closure * Add unthunk * Handle positional arguments with defaults. Mark zero arrays as broken * Revise comment on broken tests * Work through some test edge cases on 1.0 Co-authored-by: Alex Robson <[email protected]>
1 parent bbb88f7 commit 93eb113

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.8.17"
3+
version = "0.8.18"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,59 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
2020
return reshape(A, dims...), reshape_pullback
2121
end
2222

23+
#####
24+
##### `repeat`
25+
#####
26+
27+
function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))
28+
29+
function repeat_pullback(ȳ)
30+
dY = unthunk(ȳ)
31+
Δ′ = zero(xs)
32+
S = size(xs)
33+
34+
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
35+
for (dest_idx, val) in pairs(IndexCartesian(), dY)
36+
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
37+
# wrap around based on original size S.
38+
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
39+
Δ′[src_idx...] += val
40+
end
41+
return (NoTangent(), Δ′)
42+
end
43+
44+
return repeat(xs; inner = inner, outer = outer), repeat_pullback
45+
end
46+
47+
function rrule(::typeof(repeat), xs::AbstractVector, m::Integer)
48+
49+
d1 = size(xs, 1)
50+
function repeat_pullback(ȳ)
51+
Δ′ = dropdims(sum(reshape(ȳ, d1, :); dims=2); dims=2)
52+
return (NoTangent(), Δ′, NoTangent())
53+
end
54+
55+
return repeat(xs, m), repeat_pullback
56+
end
57+
58+
function rrule(::typeof(repeat), xs::AbstractVecOrMat, m::Integer, n::Integer)
59+
d1, d2 = size(xs, 1), size(xs, 2)
60+
function repeat_pullback(ȳ)
61+
ȳ′ = reshape(ȳ, d1, m, d2, n)
62+
return NoTangent(), reshape(sum(ȳ′; dims=(2,4)), (d1, d2)), NoTangent(), NoTangent()
63+
end
64+
65+
return repeat(xs, m, n), repeat_pullback
66+
end
67+
68+
function rrule(T::typeof(repeat), xs::AbstractVecOrMat, m::Integer)
69+
70+
# Workaround use of positional default (i.e. repeat(xs, m, n = 1)))
71+
y, full_pb = rrule(T, xs, m, 1)
72+
repeat_pullback(ȳ) = full_pb(ȳ)[1:3]
73+
return y, repeat_pullback
74+
end
75+
2376
#####
2477
##### `hcat`
2578
#####

test/rulesets/Base/array.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,34 @@
44
test_rrule(reshape, rand(4, 5), 2, :)
55
end
66

7+
@testset "repeat" begin
8+
test_rrule(repeat, rand(4, ))
9+
test_rrule(repeat, rand(4, ), 2)
10+
test_rrule(repeat, rand(4, 5))
11+
test_rrule(repeat, rand(4, 5); fkwargs = (outer=(1,2),))
12+
test_rrule(repeat, rand(4, 5); fkwargs = (inner=(1,2), outer=(1,3)))
13+
14+
if VERSION>=v"1.6"
15+
# repeat([1 2; 3 4], inner=(2,4), outer=(1,1,1,3)) fails for v<1.6
16+
test_rrule(repeat, rand(4, 5); fkwargs = (inner=(2,4), outer=(1,1,1,3)))
17+
end
18+
test_rrule(repeat, rand(4, 5), 2; check_inferred=VERSION>=v"1.5")
19+
test_rrule(repeat, rand(4, 5), 2, 3)
20+
21+
# zero-arrays: broken
22+
@test_broken rrule(repeat, fill(1.0), 2) !== nothing
23+
@test_broken rrule(repeat, fill(1.0), 2, 3) !== nothing
24+
25+
# These dispatch but probably needs
26+
# https://github.com/JuliaDiff/FiniteDifferences.jl/issues/179
27+
# test_rrule(repeat, fill(1.0); fkwargs = (inner=2,))
28+
# test_rrule(repeat, fill(1.0); fkwargs = (inner=2, outer=3,))
29+
30+
@test rrule(repeat, [1,2,3], 4)[2](ones(12))[2] == [4,4,4]
31+
@test rrule(repeat, [1,2,3], outer=4)[2](ones(12))[2] == [4,4,4]
32+
33+
end
34+
735
@testset "hcat" begin
836
test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3); check_inferred=VERSION>v"1.1")
937
test_rrule(hcat, rand(), rand(1,2), rand(1,2,1); check_inferred=VERSION>v"1.1")

0 commit comments

Comments
 (0)