Skip to content

Commit 36b5c57

Browse files
committed
Add extension
1 parent ee83280 commit 36b5c57

File tree

4 files changed

+59
-69
lines changed

4 files changed

+59
-69
lines changed

Project.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,28 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1818
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1919
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2020

21+
[weakdeps]
22+
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
23+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
24+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
25+
26+
[extensions]
27+
ChainRulesKernelAbstractionsExt = ["Atomix", "GPUArrays", "KernelAbstractions"]
28+
2129
[compat]
2230
Adapt = "3.4.0, 4"
31+
Atomix = "0.1"
2332
ChainRulesCore = "1.25"
2433
ChainRulesTestUtils = "1.5"
2534
Compat = "3.46, 4.2"
2635
Distributed = "1"
2736
FiniteDifferences = "0.12.20"
2837
GPUArraysCore = "0.1.0, 0.2"
38+
GPUArrays = "10, 11"
2939
IrrationalConstants = "0.1.1, 0.2"
3040
JLArrays = "0.1"
3141
JuliaInterpreter = "0.8,0.9"
42+
KernelAbstractions = "0.9"
3243
LinearAlgebra = "1"
3344
Random = "1"
3445
RealDot = "0.1"
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module ChainRulesKernelAbstractionsExt
2+
3+
import Adapt
4+
import Atomix
5+
import ChainRules
6+
import GPUArrays
7+
import KernelAbstractions as KA
8+
9+
using GPUArraysCore: AbstractGPUArray
10+
using KernelAbstractions
11+
12+
function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...)
13+
# kab = get_backend(dx)
14+
15+
# if KA.supports_atomics(kab)
16+
# gids = GPUArrays.to_indices(dx, inds)
17+
# idims = map(length, gids)
18+
# Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids)
19+
# scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy))
20+
# else
21+
dx_cpu = Adapt.adapt(Array, dx)
22+
view(dx_cpu, Adapt.adapt(Array, inds)...) .+= Adapt.adapt(Array, dy)
23+
copyto!(dx, dx_cpu)
24+
# end
25+
return dx
26+
end
27+
28+
@kernel function scatter!(op, dest, src, idims, Is::Vararg{Any, N}) where N
29+
_scatter!(@index(Global), op, dest, src, idims, Is...)
30+
end
31+
32+
@generated function _scatter!(i, op, dest, src, idims, Is::Vararg{Any, N}) where N
33+
quote
34+
is = @inbounds CartesianIndices(idims)[i]
35+
Base.Cartesian.@nexprs $N j -> I_j = @inbounds((Is[j])[is[j]])
36+
dv = src[i]
37+
Base.Cartesian.@ncall $N _accum! op dest dv j -> I_j
38+
end
39+
end
40+
41+
function _accum!(op, dest, val, ids...)
42+
Atomix.modify!(Atomix.IndexableRef(dest, (ids...,)), op, val)
43+
end
44+
45+
end

src/rulesets/Base/indexing.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ function rrule(::typeof(∇getindex), x, dy, inds...)
168168
return z, ∇getindex_pullback
169169
end
170170

171-
# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers.
172-
# To avoid this, copy everything back to the CPU.
173-
# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice:
171+
# NOTE:
172+
# Generic `∇getindex!(dx::AbstractGPUArray, dy, inds...)`
173+
# is implemented in `ext/` with a custom kernel.
174174

175175
function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...)
176176
view(dx, inds...) .+= Ref(dy)
@@ -181,17 +181,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUni
181181
return dx
182182
end
183183

184-
function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
185-
# TODO we want this
186-
# @atomic dx[inds...] .+= dy
187-
# return dx
188-
189-
dx_cpu = adapt(Array, dx)
190-
view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy)
191-
copyto!(dx, dx_cpu)
192-
return dx
193-
end
194-
195184
#####
196185
##### view
197186
#####

t.jl

Lines changed: 0 additions & 55 deletions
This file was deleted.

0 commit comments

Comments
 (0)