Skip to content

Commit f2c0e8a

Browse files
mofeingavik-palgithub-actions[bot]
authored
Generalize Base._cat to non-Val, typed Base._cat_t and implement typed_hcat, typed_vcat, typed_hvcat, typed_hvncat (#163)
* Remove `Val` constraint on `Base._cat` signature * Remove `Val` constraint on `maybe_expand_dims` * fix: update src/TracedRArray.jl * Generalize `Base._cat` implementation on `TracedRArray` to typed `Base._cat_t` * Update src/TracedRArray.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix collection type passed to `stablehlo.concatenate` * Test `cat` methods * Test result eltype on `*cat` methods * Fix conversion of integer arrays to `ConcreteRArray`s * Format code Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix `_typed_cat`, `_typed_hcat`, `typed_hvcat` dispatches * Fix `hvcat` * Convert to target eltype before cat * Fix `typed_hcat` tests * Test `typed_hvncat` on vectors * Refactor tests * Add more test cases * Refactor tests * Fix typo --------- Co-authored-by: Avik Pal <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 9904590 commit f2c0e8a

File tree

3 files changed

+137
-29
lines changed

3 files changed

+137
-29
lines changed

src/TracedRArray.jl

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -761,32 +761,87 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
761761
return dest
762762
end
763763

764-
function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
765-
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
764+
dispatch_val(x) = x
765+
dispatch_val(::Val{D}) where {D} = D
766766

767-
# MLIR expects the dimension `D` to be ≤ the rank of the input tensors
768-
A = maybe_expand_dims(A, dims)
769-
Bs = maybe_expand_dims.(Bs, (dims,))
767+
@inline function Base._typed_vcat(
768+
::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
769+
) where {T}
770+
return Base._cat_t(Val(1), T, X...)
771+
end
772+
@inline function Base._typed_hcat(
773+
::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
774+
) where {T}
775+
return Base._cat_t(Val(2), T, X...)
776+
end
777+
778+
# `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
779+
# generic implementation uses `typed_hcat` and `typed_vcat` which is alright
780+
@inline function Base.typed_hvcat(
781+
::Type{T}, rows::Tuple{Vararg{Int}}, as::TracedRArray...
782+
) where {T}
783+
return invoke(
784+
Base.typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
785+
)
786+
end
787+
788+
function Base._typed_hvncat(
789+
T::Type, dims::NTuple{N,Int}, row_first::Bool, as::TracedRArray...
790+
) where {N}
791+
As = if row_first
792+
perm = [2, 1, 3:N...]
793+
dims = [dims[2], dims[1], dims[3:end]...]
794+
permutedims(reshape(collect(as), dims...), perm)
795+
else
796+
reshape(collect(as), dims)
797+
end
798+
799+
for d in 1:N
800+
Bs = Array{Any,N - d}(undef, size(As)[2:end]...)
801+
802+
for (i, col) in
803+
zip(eachindex(Bs), eachslice(As; dims=Tuple(2:ndims(As)), drop=true))
804+
# TODO row_first affects the flattening?
805+
Bs[i] = Base._cat_t(d, T, col...)
806+
end
807+
808+
As = Bs
809+
end
810+
811+
return only(As)
812+
end
813+
814+
function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
815+
dims = dispatch_val(dims)
816+
@assert dims isa Integer "Support for non-integer dimensions is not implemented yet."
817+
818+
# MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
819+
X = maybe_expand_dims.(X, (dims,))
770820

771821
catdims = Base.dims2cat(dims)
772-
shape = Base.cat_size_shape(catdims, A, Bs...)
773-
RT = Base.promote_eltype(A, Bs...)
774-
Res = TracedRArray{RT,length(shape)}(
822+
shape = Base.cat_size_shape(catdims, X...)
823+
RT = Base.promote_eltype(T, X...)
824+
825+
# convert to the target eltype
826+
X = map(Base.Fix1(promote_to, TracedRArray{RT,length(shape)}), X)
827+
828+
return TracedRArray{RT,length(shape)}(
775829
(),
776830
MLIR.IR.result(
831+
# TODO maybe we should do some conversion?
777832
MLIR.Dialects.stablehlo.concatenate(
778-
[A.mlir_data, [B.mlir_data for B in Bs]...];
833+
collect(get_mlir_data.(X));
779834
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
780-
dimension=D - 1, # stablehlo expects this to be zero-indexed
835+
dimension=dims - 1, # stablehlo expects this to be zero-indexed
781836
),
782837
1,
783838
),
784839
shape,
785840
)
786-
return Res
787841
end
788842

789-
function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D}
790-
D N && return x
791-
return reshape(x, ntuple(i -> i N ? size(x, i) : 1, Val(D)))
843+
function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N}
844+
dims = dispatch_val(dims)
845+
dims N && return x
846+
return reshape(x, ntuple(i -> i N ? size(x, i) : 1, dims))
792847
end

src/Tracing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ function make_tracer(
380380
if haskey(seen, prev)
381381
return seen[prev]
382382
end
383-
if mode == ArrayToConcrete && eltype(RT) <: AbstractFloat
383+
if mode == ArrayToConcrete && eltype(RT) <: Union{AbstractFloat,Integer}
384384
return seen[prev] = ConcreteRArray(prev)
385385
end
386386
TT = traced_type(eltype(RT), (), Val(mode))

test/basic.jl

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,20 +210,73 @@ end
210210
end
211211

212212
@testset "concatenation" begin
213-
x = ones(2, 4, 3)
214-
x_concrete = Reactant.to_rarray(x)
215-
216-
cat1(x) = vcat(x, x, x)
217-
cat2(x) = hcat(x, x, x)
218-
cat3(x) = cat(x, x, x; dims=Val(3))
219-
220-
cat1_compiled = @compile cat1(x_concrete)
221-
cat2_compiled = @compile cat2(x_concrete)
222-
cat3_compiled = @compile cat3(x_concrete)
223-
224-
@test cat1(x) cat1_compiled(x_concrete)
225-
@test cat2(x) cat2_compiled(x_concrete)
226-
@test cat3(x) cat3_compiled(x_concrete)
213+
@testset "$(ndims(x))-dim" for x in [
214+
fill(true),
215+
[true, false],
216+
[true false],
217+
[true true; true false],
218+
[
219+
true true true true; true true true false;;;
220+
true true false true; true true false false;;;
221+
true false true true; true false true false
222+
],
223+
]
224+
x_concrete = Reactant.to_rarray(x)
225+
226+
# NOTE [,,,] is a call to `vect`, not `*cat`
227+
# f = Reactant.compile((x_concrete,)) do x
228+
# return [x, x, x]
229+
# end
230+
# @test f(x_concrete) ≈ ones(3)
231+
232+
# vcat
233+
test_vcat(x) = [x; x; x]
234+
f = @compile test_vcat(x_concrete)
235+
@test f(x_concrete) == test_vcat(x)
236+
@test eltype(f(x_concrete)) === Bool
237+
238+
# hcat
239+
test_hcat(x) = [x x x]
240+
f = @compile test_hcat(x_concrete)
241+
@test f(x_concrete) == test_hcat(x)
242+
@test eltype(f(x_concrete)) === Bool
243+
244+
# hvcat
245+
test_hvcat(x) = [x x x; x x x]
246+
f = @compile test_hvcat(x_concrete)
247+
@test f(x_concrete) == test_hvcat(x)
248+
@test eltype(f(x_concrete)) === Bool
249+
250+
# hvncat
251+
test_hvncat(x) = [x x x; x x x;;; x x x; x x x]
252+
f = @compile test_hvncat(x_concrete)
253+
@test f(x_concrete) == test_hvncat(x)
254+
@test eltype(f(x_concrete)) === Bool
255+
256+
# typed_vcat
257+
test_typed_vcat(x) = Int[x; x; x]
258+
f = @compile test_typed_vcat(x_concrete)
259+
@test f(x_concrete) == test_typed_vcat(x)
260+
@test eltype(f(x_concrete)) === Int
261+
262+
# typed_hcat
263+
test_typed_hcat(x) = Int[x x x]
264+
f = @compile test_typed_hcat(x_concrete)
265+
@test f(x_concrete) == test_typed_hcat(x)
266+
@test eltype(f(x_concrete)) === Int
267+
268+
# typed_hvcat
269+
test_typed_hvcat(x) = Int[x x x; x x x]
270+
f = @compile test_typed_hvcat(x_concrete)
271+
@test f(x_concrete) == test_typed_hvcat(x)
272+
@test eltype(f(x_concrete)) === Int
273+
274+
# typed_hvncat
275+
test_typed_hvncat(x) = Int[x x x; x x x;;; x x x; x x x]
276+
f = @compile test_typed_hvncat(x_concrete)
277+
@test f(x_concrete) == test_typed_hvncat(x)
278+
@test eltype(f(x_concrete)) === Int
279+
end
227280
end
228281

229282
function update_on_copy(x)

0 commit comments

Comments
 (0)