Skip to content

Commit 4ad9758

Browse files
authored
Merge pull request #471 from JuliaDiff/mz/ithunk
Change argument order in InplaceableThunk
2 parents 6d6fa9e + c36e5ad commit 4ad9758

File tree

9 files changed

+80
-81
lines changed

9 files changed

+80
-81
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.8.22"
3+
version = "0.8.23"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -10,7 +10,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
13-
ChainRulesCore = "0.10.9"
13+
ChainRulesCore = "0.10.12"
1414
ChainRulesTestUtils = "0.7.9"
1515
Compat = "3.31"
1616
FiniteDifferences = "0.12.8"

src/rulesets/Base/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
9797
end
9898
if ndimsX > 0
9999
# Here InplaceableThunk breaks @inferred, removed for now
100-
# InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
100+
# InplaceableThunk(dX -> dX .+= view(dY, ind...), @thunk(dY[ind...]))
101101
dY[ind...]
102102
else
103103
# This is a hack to perhaps avoid GPU scalar indexing

src/rulesets/Base/arraymath.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ function rrule(
2929
return (
3030
NoTangent(),
3131
InplaceableThunk(
32+
X̄ -> mul!(X̄, Ȳ, B', true, true),
3233
@thunk(Ȳ * B'),
33-
X̄ -> mul!(X̄, Ȳ, B', true, true)
3434
),
3535
InplaceableThunk(
36+
X̄ -> mul!(X̄, A', Ȳ, true, true),
3637
@thunk(A' * Ȳ),
37-
X̄ -> mul!(X̄, A', Ȳ, true, true)
3838
)
3939
)
4040
end
@@ -52,12 +52,12 @@ function rrule(
5252
return (
5353
NoTangent(),
5454
InplaceableThunk(
55+
X̄ -> mul!(X̄, Ȳ, vec(B'), true, true),
5556
@thunk(Ȳ * vec(B')),
56-
X̄ -> mul!(X̄, Ȳ, vec(B'), true, true)
5757
),
5858
InplaceableThunk(
59+
X̄ -> mul!(X̄, A', Ȳ, true, true),
5960
@thunk(A' * Ȳ),
60-
X̄ -> mul!(X̄, A', Ȳ, true, true)
6161
)
6262
)
6363
end
@@ -73,8 +73,8 @@ function rrule(
7373
NoTangent(),
7474
@thunk(dot(Ȳ, B)'),
7575
InplaceableThunk(
76+
X̄ -> mul!(X̄, conj(A), Ȳ, true, true),
7677
@thunk(A' * Ȳ),
77-
X̄ -> mul!(X̄, conj(A), Ȳ, true, true)
7878
)
7979
)
8080
end
@@ -89,8 +89,8 @@ function rrule(
8989
return (
9090
NoTangent(),
9191
InplaceableThunk(
92+
X̄ -> mul!(X̄, conj(A), Ȳ, true, true),
9293
@thunk(A' * Ȳ),
93-
X̄ -> mul!(X̄, conj(A), Ȳ, true, true)
9494
),
9595
@thunk(dot(Ȳ, B)'),
9696
)
@@ -114,12 +114,12 @@ function rrule(
114114
Ȳ = unthunk(ȳ)
115115
matmul = (
116116
InplaceableThunk(
117+
dA -> mul!(dA, Ȳ, B', true, true),
117118
@thunk(Ȳ * B'),
118-
dA -> mul!(dA, Ȳ, B', true, true)
119119
),
120120
InplaceableThunk(
121+
dB -> mul!(dB, A', Ȳ, true, true),
121122
@thunk(A' * Ȳ),
122-
dB -> mul!(dB, A', Ȳ, true, true)
123123
)
124124
)
125125
addon = if z isa Bool
@@ -128,8 +128,8 @@ function rrule(
128128
@thunk(sum(Ȳ))
129129
else
130130
InplaceableThunk(
131+
dz -> sum!(dz, Ȳ; init=false),
131132
@thunk(sum!(similar(z, eltype(Ȳ)), Ȳ)),
132-
dz -> sum!(dz, Ȳ; init=false)
133133
)
134134
end
135135
(NoTangent(), matmul..., addon)
@@ -147,12 +147,12 @@ function rrule(
147147
function muladd_pullback_2(ȳ)
148148
dy = unthunk(ȳ)
149149
ut_thunk = InplaceableThunk(
150+
dut -> dut .+= v' .* dy,
150151
@thunk(v' .* dy),
151-
dut -> dut .+= v' .* dy
152152
)
153153
v_thunk = InplaceableThunk(
154+
dv -> dv .+= ut' .* dy,
154155
@thunk(ut' .* dy),
155-
dv -> dv .+= ut' .* dy
156156
)
157157
(NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : dy)
158158
end
@@ -178,8 +178,8 @@ function rrule(
178178
@thunk(sum(Ȳ))
179179
else
180180
InplaceableThunk(
181+
dz -> sum!(dz, Ȳ; init=false),
181182
@thunk(sum!(similar(z, eltype(Ȳ)), Ȳ)),
182-
dz -> sum!(dz, Ȳ; init=false)
183183
)
184184
end
185185
(NoTangent(), proj..., addon)
@@ -240,8 +240,8 @@ function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::Commuta
240240
function slash_pullback_scalar(ȳ)
241241
Ȳ = unthunk(ȳ)
242242
Athunk = InplaceableThunk(
243-
@thunk(Ȳ / conj(b)),
244243
dA -> dA .+= Ȳ ./ conj(b),
244+
@thunk(Ȳ / conj(b)),
245245
)
246246
bthunk = @thunk(-dot(A,Ȳ) / conj(b^2))
247247
return (NoTangent(), Athunk, bthunk)
@@ -264,7 +264,7 @@ end
264264

265265
function rrule(::typeof(-), x::AbstractArray)
266266
function negation_pullback(ȳ)
267-
return NoTangent(), InplaceableThunk(@thunk(-ȳ), ā -> ā .-= ȳ)
267+
return NoTangent(), InplaceableThunk(ā -> ā .-= ȳ, @thunk(-ȳ))
268268
end
269269
return -x, negation_pullback
270270
end

src/rulesets/Base/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
1717
end
1818

1919
x̄ = InplaceableThunk(
20+
getindex_add!,
2021
@thunk(getindex_add!(zero(x))),
21-
getindex_add!
2222
)
2323
īnds = broadcast(_ -> NoTangent(), inds)
2424
return (NoTangent(), x̄, īnds...)

src/rulesets/Base/mapreduce.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
1111
function sum_pullback(ȳ)
1212
# broadcasting the two works out the size no-matter `dims`
1313
x̄ = InplaceableThunk(
14+
x -> x .+= ȳ,
1415
@thunk(broadcast(lasttuple, x, ȳ)),
15-
x -> x .+= ȳ
1616
)
1717
return (NoTangent(), x̄)
1818
end
@@ -93,8 +93,8 @@ function rrule(
9393
y = sum(abs2, x; dims=dims)
9494
function sum_abs2_pullback(ȳ)
9595
x_thunk = InplaceableThunk(
96+
dx -> dx .+= 2 .* real.(ȳ) .* x,
9697
@thunk(2 .* real.(ȳ) .* x),
97-
dx -> dx .+= 2 .* real.(ȳ) .* x
9898
)
9999
return (NoTangent(), NoTangent(), x_thunk)
100100
end
@@ -122,16 +122,6 @@ function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:Commutativ
122122
function prod_pullback(ȳ)
123123
dy = unthunk(ȳ)
124124
x_thunk = InplaceableThunk(
125-
# Out-of-place versions
126-
@thunk if dims === (:)
127-
∇prod(x, dy, y)
128-
elseif any(iszero, x) # Then, and only then, will ./x lead to NaN
129-
vald = dims isa Colon ? nothing : dims isa Integer ? Val(Int(dims)) : Val(Tuple(dims))
130-
∇prod_dims(vald, x, dy, y) # val(Int(dims)) is about 2x faster than Val(Tuple(dims))
131-
else
132-
conj.(y ./ x) .* dy
133-
end
134-
,
135125
# In-place versions -- same branching
136126
dx -> if dims === (:)
137127
∇prod!(dx, x, dy, y)
@@ -140,6 +130,15 @@ function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:Commutativ
140130
∇prod_dims!(dx, vald, x, dy, y)
141131
else
142132
dx .+= conj.(y ./ x) .* dy
133+
end,
134+
# Out-of-place versions
135+
@thunk if dims === (:)
136+
∇prod(x, dy, y)
137+
elseif any(iszero, x) # Then, and only then, will ./x lead to NaN
138+
vald = dims isa Colon ? nothing : dims isa Integer ? Val(Int(dims)) : Val(Tuple(dims))
139+
∇prod_dims(vald, x, dy, y) # val(Int(dims)) is about 2x faster than Val(Tuple(dims))
140+
else
141+
conj.(y ./ x) .* dy
143142
end
144143
)
145144
return (NoTangent(), x_thunk)

src/rulesets/Base/sort.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin
88
return Δxs
99
end
1010

11-
Δxs = InplaceableThunk(@thunk(partialsort_add!(zero(xs))), partialsort_add!)
11+
Δxs = InplaceableThunk(partialsort_add!, @thunk(partialsort_add!(zero(xs))))
1212

1313
return NoTangent(), Δxs, NoTangent()
1414
end
@@ -27,7 +27,7 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...)
2727
return Δxs
2828
end
2929

30-
Δxs = InplaceableThunk(@thunk(sort_add!(zero(Δys))), sort_add!)
30+
Δxs = InplaceableThunk(sort_add!, @thunk(sort_add!(zero(Δys))))
3131

3232
return NoTangent(), Δxs
3333
end

src/rulesets/LinearAlgebra/blas.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,30 +111,30 @@ function rrule(::typeof(gemv), tA::Char, α::T, A::AbstractMatrix{T},
111111
ȳ = unthunk(Ȳ)
112112
if uppercase(tA) === 'N'
113113
∂A = InplaceableThunk(
114+
Ā -> ger!(α', ȳ, x, Ā),
114115
@thunk(α' * ȳ * x'),
115-
Ā -> ger!(α', ȳ, x, Ā)
116116
)
117117
∂x = InplaceableThunk(
118+
x̄ -> gemv!('C', α', A, ȳ, one(T), x̄),
118119
@thunk(gemv('C', α', A, ȳ)),
119-
x̄ -> gemv!('C', α', A, ȳ, one(T), x̄)
120120
)
121121
elseif uppercase(tA) === 'C'
122122
∂A = InplaceableThunk(
123+
Ā -> ger!(α, x, ȳ, Ā),
123124
@thunk(α * x * ȳ'),
124-
Ā -> ger!(α, x, ȳ, Ā)
125125
)
126126
∂x = InplaceableThunk(
127+
x̄ -> gemv!('N', α', A, ȳ, one(T), x̄),
127128
@thunk(gemv('N', α', A, ȳ)),
128-
x̄ -> gemv!('N', α', A, ȳ, one(T), x̄)
129129
)
130130
else # uppercase(tA) === 'T'
131131
∂A = InplaceableThunk(
132+
Ā -> conj!(ger!(α, x, ȳ, conj!(Ā))),
132133
@thunk(conj(α * x * ȳ')),
133-
Ā -> conj!(ger!(α, x, ȳ, conj!(Ā)))
134134
)
135135
∂x = InplaceableThunk(
136+
x̄ -> gemv!('N', α', conj(A), ȳ, one(T), x̄),
136137
@thunk(gemv('N', α', conj(A), ȳ)),
137-
x̄ -> gemv!('N', α', conj(A), ȳ, one(T), x̄)
138138
)
139139
end
140140
return (NoTangent(), NoTangent(), @thunk(dot(y, ȳ) / α'), ∂A, ∂x)
@@ -167,88 +167,88 @@ function rrule(
167167
if uppercase(tA) === 'N'
168168
if uppercase(tB) === 'N'
169169
∂A = InplaceableThunk(
170+
Ā -> gemm!('N', 'C', α', C̄, B, β, Ā),
170171
@thunk(gemm('N', 'C', α', C̄, B)),
171-
Ā -> gemm!('N', 'C', α', C̄, B, β, Ā)
172172
)
173173
∂B = InplaceableThunk(
174+
B̄ -> gemm!('C', 'N', α', A, C̄, β, B̄),
174175
@thunk(gemm('C', 'N', α', A, C̄)),
175-
B̄ -> gemm!('C', 'N', α', A, C̄, β, B̄)
176176
)
177177
elseif uppercase(tB) === 'C'
178178
∂A = InplaceableThunk(
179+
Ā -> gemm!('N', 'N', α', C̄, B, β, Ā),
179180
@thunk(gemm('N', 'N', α', C̄, B)),
180-
Ā -> gemm!('N', 'N', α', C̄, B, β, Ā)
181181
)
182182
∂B = InplaceableThunk(
183+
B̄ -> gemm!('C', 'N', α, C̄, A, β, B̄),
183184
@thunk(gemm('C', 'N', α, C̄, A)),
184-
B̄ -> gemm!('C', 'N', α, C̄, A, β, B̄)
185185
)
186186
else # uppercase(tB) === 'T'
187187
∂A = InplaceableThunk(
188+
Ā -> gemm!('N', 'N', α', C̄, conj(B), β, Ā),
188189
@thunk(gemm('N', 'N', α', C̄, conj(B))),
189-
Ā -> gemm!('N', 'N', α', C̄, conj(B), β, Ā)
190190
)
191191
∂B = InplaceableThunk(
192+
B̄ -> conj!(gemm!('C', 'N', α, C̄, A, β, conj!(B̄))),
192193
@thunk(conj(gemm('C', 'N', α, C̄, A))),
193-
B̄ -> conj!(gemm!('C', 'N', α, C̄, A, β, conj!(B̄)))
194194
)
195195
end
196196
elseif uppercase(tA) === 'C'
197197
if uppercase(tB) === 'N'
198198
∂A = InplaceableThunk(
199+
Ā -> gemm!('N', 'C', α, B, C̄, β, Ā),
199200
@thunk(gemm('N', 'C', α, B, C̄)),
200-
Ā -> gemm!('N', 'C', α, B, C̄, β, Ā)
201201
)
202202
∂B = InplaceableThunk(
203+
B̄ -> gemm!('N', 'N', α', A, C̄, β, B̄),
203204
@thunk(gemm('N', 'N', α', A, C̄)),
204-
B̄ -> gemm!('N', 'N', α', A, C̄, β, B̄)
205205
)
206206
elseif uppercase(tB) === 'C'
207207
∂A = InplaceableThunk(
208+
Ā -> gemm!('C', 'C', α, B, C̄, β, Ā),
208209
@thunk(gemm('C', 'C', α, B, C̄)),
209-
Ā -> gemm!('C', 'C', α, B, C̄, β, Ā)
210210
)
211211
∂B = InplaceableThunk(
212+
B̄ -> gemm!('C', 'C', α, C̄, A, β, B̄),
212213
@thunk(gemm('C', 'C', α, C̄, A)),
213-
B̄ -> gemm!('C', 'C', α, C̄, A, β, B̄)
214214
)
215215
else # uppercase(tB) === 'T'
216216
∂A = InplaceableThunk(
217+
Ā -> gemm!('T', 'C', α, B, C̄, β, Ā),
217218
@thunk(gemm('T', 'C', α, B, C̄)),
218-
Ā -> gemm!('T', 'C', α, B, C̄, β, Ā)
219219
)
220220
∂B = InplaceableThunk(
221+
B̄ -> gemm!('T', 'T', α', C̄, A, β, B̄),
221222
@thunk(gemm('T', 'T', α', C̄, A)),
222-
B̄ -> gemm!('T', 'T', α', C̄, A, β, B̄)
223223
)
224224
end
225225
else # uppercase(tA) === 'T'
226226
if uppercase(tB) === 'N'
227227
∂A = InplaceableThunk(
228+
Ā -> conj!(gemm!('N', 'C', α, B, C̄, β, conj!(Ā))),
228229
@thunk(conj(gemm('N', 'C', α, B, C̄))),
229-
Ā -> conj!(gemm!('N', 'C', α, B, C̄, β, conj!(Ā)))
230230
)
231231
∂B = InplaceableThunk(
232+
B̄ -> gemm!('N', 'N', α', conj(A), C̄, β, B̄),
232233
@thunk(gemm('N', 'N', α', conj(A), C̄)),
233-
B̄ -> gemm!('N', 'N', α', conj(A), C̄, β, B̄)
234234
)
235235
elseif uppercase(tB) === 'C'
236236
∂A = InplaceableThunk(
237+
Ā -> gemm!('T', 'T', α', B, C̄, β, Ā),
237238
@thunk(gemm('T', 'T', α', B, C̄)),
238-
Ā -> gemm!('T', 'T', α', B, C̄, β, Ā)
239239
)
240240
∂B = InplaceableThunk(
241+
B̄ -> gemm!('C', 'T', α, C̄, A, β, B̄),
241242
@thunk(gemm('C', 'T', α, C̄, A)),
242-
B̄ -> gemm!('C', 'T', α, C̄, A, β, B̄)
243243
)
244244
else # uppercase(tB) === 'T'
245245
∂A = InplaceableThunk(
246+
Ā -> gemm!('C', 'T', α', B, C̄, β, Ā),
246247
@thunk(gemm('C', 'T', α', B, C̄)),
247-
Ā -> gemm!('C', 'T', α', B, C̄, β, Ā)
248248
)
249249
∂B = InplaceableThunk(
250+
B̄ -> gemm!('T', 'C', α', C̄, A, β, B̄),
250251
@thunk(gemm('T', 'C', α', C̄, A)),
251-
B̄ -> gemm!('T', 'C', α', C̄, A, β, B̄)
252252
)
253253
end
254254
end

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ function rrule(::typeof(dot), x::AbstractArray, y::AbstractArray)
1010
function dot_pullback(Ω̄)
1111
ΔΩ = unthunk(Ω̄)
1212
xthunk = InplaceableThunk(
13-
@thunk(reshape(y .* ΔΩ', axes(x))),
1413
dx -> dx .+= reshape(y, axes(x)) .* ΔΩ',
14+
@thunk(reshape(y .* ΔΩ', axes(x))),
1515
)
1616
ythunk = InplaceableThunk(
17-
@thunk(reshape(x .* ΔΩ, axes(y))),
1817
dy -> dy .+= reshape(x, axes(y)) .* ΔΩ,
18+
@thunk(reshape(x .* ΔΩ, axes(y))),
1919
)
2020
return (NoTangent(), xthunk, ythunk)
2121
end

0 commit comments

Comments
 (0)