Skip to content

Commit bb2c6e2

Browse files
authored
fix: hotfix for breaking changes in GPUArrays (#282)
1 parent b1aa851 commit bb2c6e2

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1616

1717
[weakdeps]
1818
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
19+
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1920
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2021
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2122
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -25,6 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2526

2627
[extensions]
2728
ComponentArraysGPUArraysExt = "GPUArrays"
29+
ComponentArraysKernelAbstractionsExt = "KernelAbstractions"
2830
ComponentArraysOptimisersExt = "Optimisers"
2931
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
3032
ComponentArraysReverseDiffExt = "ReverseDiff"
@@ -40,6 +42,7 @@ ConstructionBase = "1"
4042
ForwardDiff = "0.10.36"
4143
Functors = "0.4.12, 0.5"
4244
GPUArrays = "10, 11"
45+
KernelAbstractions = "0.9.29"
4346
LinearAlgebra = "1.10"
4447
Optimisers = "0.3, 0.4"
4548
RecursiveArrayTools = "3.8"

ext/ComponentArraysGPUArraysExt.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,24 @@ const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVecto
88
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax}
99
const GPUComponentVecorMat{T,Ax} = Union{GPUComponentVector{T,Ax},GPUComponentMatrix{T,Ax}}
1010

11-
GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x))
11+
@static if pkgversion(GPUArrays) < v"11"
12+
GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x))
1213

13-
function Base.fill!(A::GPUComponentArray{T}, x) where {T}
14-
length(A) == 0 && return A
15-
GPUArrays.gpu_call(A, convert(T, x)) do ctx, a, val
16-
idx = GPUArrays.@linearidx(a)
17-
@inbounds a[idx] = val
18-
return
14+
function Base.fill!(A::GPUComponentArray{T}, x) where {T}
15+
length(A) == 0 && return A
16+
GPUArrays.gpu_call(A, convert(T, x)) do ctx, a, val
17+
idx = GPUArrays.@linearidx(a)
18+
@inbounds a[idx] = val
19+
return
20+
end
21+
return A
22+
end
23+
else
24+
function Base.fill!(A::GPUComponentArray{T}, x) where {T}
25+
length(A) == 0 && return A
26+
ComponentArrays.fill_componentarray_ka!(A, x)
27+
return A
1928
end
20-
A
2129
end
2230

2331
LinearAlgebra.dot(x::GPUComponentArray, y::GPUComponentArray) = dot(getdata(x), getdata(y))
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module ComponentArraysKernelAbstractionsExt
2+
3+
using ComponentArrays: ComponentArrays, ComponentArray
4+
using KernelAbstractions: KernelAbstractions, @kernel, @index
5+
6+
KernelAbstractions.backend(x::ComponentArray) = KernelAbstractions.backend(getdata(x))
7+
8+
@kernel function ca_fill_kernel!(A, @Const(x))
9+
idx = @index(Global, Linear)
10+
@inbounds A[idx] = x
11+
end
12+
13+
function ComponentArrays.fill_componentarray_ka!(A::ComponentArray{T}, x) where {T}
14+
kernel! = ca_fill_kernel!(KernelAbstractions.get_backend(A))
15+
kernel!(A, x; ndrange=length(A))
16+
return A
17+
end
18+
19+
end

src/componentarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ ComponentArray(x::ComponentArray) = x
7878
ComponentArray{T}(x::ComponentArray) where {T} = T.(x)
7979
(CA::Type{<:ComponentArray{T,N,A,Ax}})(x::ComponentArray) where {T,N,A,Ax} = ComponentArray(T.(getdata(x)), getaxes(x))
8080

81+
function fill_componentarray_ka! end # defined in extensions
8182

8283
## Some aliases
8384
"""

0 commit comments

Comments
 (0)