Skip to content

Commit dd0e4c0

Browse files
Merge pull request #292 from SciML/ap/reactant
feat: reactant support
2 parents 105eeaf + 21890c3 commit dd0e4c0

File tree

7 files changed

+51
-7
lines changed

7 files changed

+51
-7
lines changed

Project.toml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ComponentArrays"
22
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
33
authors = ["Jonnie Diegelman <[email protected]>"]
4-
version = "0.15.20"
4+
version = "0.15.21"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -18,6 +18,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1818
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1919
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
2020
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
21+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
2122
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2223
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2324
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -28,6 +29,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2829
ComponentArraysGPUArraysExt = "GPUArrays"
2930
ComponentArraysKernelAbstractionsExt = "KernelAbstractions"
3031
ComponentArraysOptimisersExt = "Optimisers"
32+
ComponentArraysReactantExt = "Reactant"
3133
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
3234
ComponentArraysReverseDiffExt = "ReverseDiff"
3335
ComponentArraysSciMLBaseExt = "SciMLBase"
@@ -36,20 +38,21 @@ ComponentArraysZygoteExt = "Zygote"
3638

3739
[compat]
3840
Adapt = "4.1"
39-
ArrayInterface = "7.10"
40-
ChainRulesCore = "1.24"
41+
ArrayInterface = "7.17.1"
42+
ChainRulesCore = "1.25"
4143
ConstructionBase = "1"
4244
ForwardDiff = "0.10.36"
4345
Functors = "0.4.12, 0.5"
44-
GPUArrays = "10, 11"
46+
GPUArrays = "10.3.1, 11"
4547
KernelAbstractions = "0.9.29"
4648
LinearAlgebra = "1.10"
4749
Optimisers = "0.3, 0.4"
50+
Reactant = "0.2.15"
4851
RecursiveArrayTools = "3.8"
4952
ReverseDiff = "1.15"
5053
SciMLBase = "2"
5154
StaticArrayInterface = "1"
5255
StaticArraysCore = "1.4"
53-
Tracker = "0.2.34"
56+
Tracker = "0.2.37"
5457
Zygote = "0.6.70, 0.7"
5558
julia = "1.10"

ext/ComponentArraysReactantExt.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module ComponentArraysReactantExt
2+
3+
using ArrayInterface: ArrayInterface
4+
using ComponentArrays, Reactant
5+
6+
const TracedComponentVector{T} = ComponentVector{
7+
Reactant.TracedRNumber{T},<:Reactant.TracedRArray{T}
8+
} where {T}
9+
10+
# Reactant is good at memory management but not great at handling wrapped types. So we avoid
11+
# wrapping types into SubArrays and let Reactant optimize out intermediate allocations.
12+
13+
@inline function Base.getproperty(x::TracedComponentVector{T}, s::Symbol) where {T}
14+
return getproperty(x, Val(s))
15+
end
16+
17+
@inline function Base.getproperty(x::TracedComponentVector{T}, v::Val) where {T}
18+
return ComponentArrays._getindex(Base.getindex, x, v)
19+
end
20+
21+
function ArrayInterface.restructure(x::ComponentVector, y::TracedComponentVector)
22+
getaxes(x) == getaxes(y) || error("Axes must match")
23+
return y
24+
end
25+
26+
end

src/componentarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ end
6666
Adapt.adapt_storage(::Type{ComponentArray{T,N,A,Ax}}, xs::AT) where {T,N,A,Ax,AT<:AbstractArray} =
6767
Adapt.adapt_storage(A, xs)
6868

69+
Adapt.parent_type(::Type{ComponentArray{T,N,A,Ax}}) where {T,N,A,Ax} = A
70+
6971
# Entry from NamedTuple, Dict, or kwargs
7072
ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...)
7173
ComponentArray{T}(::NamedTuple{(), Tuple{}}) where T = ComponentArray(T[], (FlatAxis(),))

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1313
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
14+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1415
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1516
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/gpu_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using JLArrays
1+
using JLArrays, LinearAlgebra
22

33
JLArrays.allowscalar(false)
44

@@ -11,7 +11,7 @@ jlca = ComponentArray(jla, Axis(a=1:2, b=3:4))
1111
@test getdata(map(identity, jlca)) isa JLArray
1212
@test all(==(0), map(-, jlca, jla))
1313
@test all(map(-, jlca, jlca) .== 0)
14-
@test all(==(0), map(-, jla, jlca))
14+
@test all(==(0), map(-, jla, jlca)) broken=(pkgversion(JLArrays.GPUArrays) v"11")
1515

1616
@test any(==(1), jlca)
1717
@test count(>(2), jlca) == 2

test/reactant_tests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
using Reactant, ComponentArrays
2+
3+
x = ComponentArray(; a = rand(4), b = rand(2))
4+
x_ra = Reactant.to_rarray(x)
5+
6+
fn(x) = x.a .+ sum(abs2, x.b) .+ 1
7+
8+
@test @jit(fn(x_ra)) fn(x)

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,3 +732,7 @@ end
732732
@testset "GPU" begin
733733
include("gpu_tests.jl")
734734
end
735+
736+
@testset "Reactant" begin
737+
include("reactant_tests.jl")
738+
end

0 commit comments

Comments
 (0)