Skip to content

Commit ee83280

Browse files
committed
Add scatter
1 parent e055009 commit ee83280

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

src/rulesets/Base/indexing.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,12 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUni
180180
view(dx, inds...) .+= dy
181181
return dx
182182
end
183+
183184
function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
185+
# TODO we want this
186+
# @atomic dx[inds...] .+= dy
187+
# return dx
188+
184189
dx_cpu = adapt(Array, dx)
185190
view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy)
186191
copyto!(dx, dx_cpu)

t.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using ChainRules
2+
using GPUArrays
3+
using Zygote
4+
using AMDGPU
5+
using KernelAbstractions
6+
using KernelAbstractions: @atomic
7+
8+
"""
    _accum!(dest, val, ids...)

Atomically add `val` to the element of `dest` addressed by `ids...`.

The update goes through `KernelAbstractions.@atomic`, so several work-items may
accumulate into the same element concurrently.
"""
function _accum!(dest, val, ids...)
    # TODO support passing `op`
    return @atomic dest[ids...] += val
end
12+
13+
"""
    _scatter!(i, dest, src, idims, Is...)

Body executed for linear work-item index `i` by the `scatter!` kernel.

`i` is converted to a Cartesian index over `idims`; its `j`-th component selects
an entry of `Is[j]`, producing one target index per indexed dimension. The value
`dest[i]` is then accumulated into `src` at those indices via `_accum!`.

Note the argument roles: `dest` is the array that is *read* (`dest[i]`) and `src`
is the array that is *updated*.
"""
@generated function _scatter!(i, dest, src, idims, Is::Vararg{Any, N}) where N
    # Build one local per indexed dimension (I_1, …, I_N) together with the
    # statement that loads it, so the generated code is fully unrolled over N.
    index_vars = Symbol[]
    index_loads = Expr[]
    for j in 1:N
        var = Symbol(:I_, j)
        push!(index_vars, var)
        push!(index_loads, :($var = @inbounds((Is[$j])[is[$j]])))
    end
    return quote
        is = @inbounds CartesianIndices(idims)[i]
        $(index_loads...)
        dv = dest[i]
        _accum!(src, dv, $(index_vars...))
    end
end
21+
22+
# KernelAbstractions kernel: each work-item takes its global index and forwards it,
# together with all kernel arguments, to `_scatter!`, which performs the update.
@kernel function scatter!(dest, src, idims, Is::Vararg{Any, N}) where N
    i = @index(Global)
    _scatter!(i, dest, src, idims, Is...)
end
25+
26+
# Scratch driver for the `scatter!` kernel: accumulates the values of `y` into `x`
# at the positions selected by `ids`, then prints `y` and the affected slice of `x`.
#
# Throws `DimensionMismatch` if `y` does not hold exactly one value per position
# addressed by `ids`.
function main()
    x = ROCArray(zeros(Float32, 16, 4, 2, 3))
    ids = ([4, 1, 4, 3, 2, 1], 1, :, 3)

    gids = GPUArrays.to_indices(x, ids)
    idims = map(length, gids)

    # `y` must have the shape of `x[ids...]` (6×2 here: the two scalar indices drop
    # their dimensions). The kernel launches one work-item per element of `y` and
    # decodes its linear index through `CartesianIndices(idims)` under `@inbounds`,
    # so a larger `y` would make the extra work-items index out of range.
    y = ROCArray(ones(Float32, 6, 2))
    length(y) == prod(idims) || throw(DimensionMismatch(
        "y has $(length(y)) elements but ids address $(prod(idims)) positions"))

    Is = map(AMDGPU.Adapt.adapt(GPUArrays.ToGPU(y)), gids)

    kab = get_backend(x)
    # The kernel's first argument is the array that is read and its second the array
    # that is accumulated into, hence `(y, x, ...)`.
    scatter!(kab, 256)(y, x, idims, Is...; ndrange=length(y))
    @show y
    @show Array(x)[:, 1, 1, 3]

    # @show x[ids...]
    # x[ids...] .+= y
    # return

    # Δ = ROCArray(ones(Float32, 1))

    # y, back = Zygote.pullback(x) do x
    #     # xd = x[[4, 3, 2, 1], :, 1, [3, 1]]
    #     xd = x[]
    #     sum(xd; dims=(1:ndims(xd)...,))
    # end
    # println("===============")
    # back(Δ)
    return
end
main()

0 commit comments

Comments
 (0)