Skip to content

Commit 4d4c70f

Browse files
committed
Cleanup
1 parent 36b5c57 commit 4d4c70f

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@ using GPUArraysCore: AbstractGPUArray
1010
using KernelAbstractions
1111

1212
function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...)
13-
# kab = get_backend(dx)
14-
15-
# if KA.supports_atomics(kab)
16-
# gids = GPUArrays.to_indices(dx, inds)
17-
# idims = map(length, gids)
18-
# Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids)
19-
# scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy))
20-
# else
13+
kab = get_backend(dx)
14+
15+
if KA.supports_atomics(kab)
16+
gids = GPUArrays.to_indices(dx, inds)
17+
idims = map(length, gids)
18+
Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids)
19+
scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy))
20+
else
2121
dx_cpu = Adapt.adapt(Array, dx)
2222
view(dx_cpu, Adapt.adapt(Array, inds)...) .+= Adapt.adapt(Array, dy)
2323
copyto!(dx, dx_cpu)
24-
# end
24+
end
2525
return dx
2626
end
2727

0 commit comments

Comments
 (0)