@@ -761,32 +761,87 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
761
761
return dest
762
762
end
763
763
764
- function Base . _cat (dims :: Val{D} , A :: TracedRArray{T,N} , Bs :: TracedRArray... ) where {T,N,D}
765
- @assert D isa Integer " Support for non-integer dimensions is not implemented yet. "
764
+ dispatch_val (x) = x
765
+ dispatch_val ( :: Val{D} ) where {D} = D
766
766
767
- # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
768
- A = maybe_expand_dims (A, dims)
769
- Bs = maybe_expand_dims .(Bs, (dims,))
767
+ @inline function Base. _typed_vcat (
768
+ :: Type{T} , X:: Base.AbstractVecOrTuple{<:TracedRArray}
769
+ ) where {T}
770
+ return Base. _cat_t (Val (1 ), T, X... )
771
+ end
772
+ @inline function Base. _typed_hcat (
773
+ :: Type{T} , X:: Base.AbstractVecOrTuple{<:TracedRArray}
774
+ ) where {T}
775
+ return Base. _cat_t (Val (2 ), T, X... )
776
+ end
777
+
778
+ # `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
779
+ # generic implementation uses `typed_hcat` and `typed_vcat` which is alright
780
+ @inline function Base. typed_hvcat (
781
+ :: Type{T} , rows:: Tuple{Vararg{Int}} , as:: TracedRArray...
782
+ ) where {T}
783
+ return invoke (
784
+ Base. typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
785
+ )
786
+ end
787
+
788
+ function Base. _typed_hvncat (
789
+ T:: Type , dims:: NTuple{N,Int} , row_first:: Bool , as:: TracedRArray...
790
+ ) where {N}
791
+ As = if row_first
792
+ perm = [2 , 1 , 3 : N... ]
793
+ dims = [dims[2 ], dims[1 ], dims[3 : end ]. .. ]
794
+ permutedims (reshape (collect (as), dims... ), perm)
795
+ else
796
+ reshape (collect (as), dims)
797
+ end
798
+
799
+ for d in 1 : N
800
+ Bs = Array {Any,N - d} (undef, size (As)[2 : end ]. .. )
801
+
802
+ for (i, col) in
803
+ zip (eachindex (Bs), eachslice (As; dims= Tuple (2 : ndims (As)), drop= true ))
804
+ # TODO row_first affects the flattening?
805
+ Bs[i] = Base. _cat_t (d, T, col... )
806
+ end
807
+
808
+ As = Bs
809
+ end
810
+
811
+ return only (As)
812
+ end
813
+
814
+ function Base. _cat_t (dims, :: Type{T} , X:: TracedRArray... ) where {T}
815
+ dims = dispatch_val (dims)
816
+ @assert dims isa Integer " Support for non-integer dimensions is not implemented yet."
817
+
818
+ # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
819
+ X = maybe_expand_dims .(X, (dims,))
770
820
771
821
catdims = Base. dims2cat (dims)
772
- shape = Base. cat_size_shape (catdims, A, Bs... )
773
- RT = Base. promote_eltype (A, Bs... )
774
- Res = TracedRArray {RT,length(shape)} (
822
+ shape = Base. cat_size_shape (catdims, X... )
823
+ RT = Base. promote_eltype (T, X... )
824
+
825
+ # convert to the target eltype
826
+ X = map (Base. Fix1 (promote_to, TracedRArray{RT,length (shape)}), X)
827
+
828
+ return TracedRArray {RT,length(shape)} (
775
829
(),
776
830
MLIR. IR. result (
831
+ # TODO maybe we should do some conversion?
777
832
MLIR. Dialects. stablehlo. concatenate (
778
- [A . mlir_data, [B . mlir_data for B in Bs] . .. ] ;
833
+ collect ( get_mlir_data .(X)) ;
779
834
result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
780
- dimension= D - 1 , # stablehlo expects this to be zero-indexed
835
+ dimension= dims - 1 , # stablehlo expects this to be zero-indexed
781
836
),
782
837
1 ,
783
838
),
784
839
shape,
785
840
)
786
- return Res
787
841
end
788
842
789
- function maybe_expand_dims (x:: AbstractArray{T,N} , :: Val{D} ) where {T,N,D}
790
- D ≤ N && return x
791
- return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , Val (D)))
843
+ function maybe_expand_dims (x:: AbstractArray{T,N} , dims) where {T,N}
844
+ dims = dispatch_val (dims)
845
+ dims ≤ N && return x
846
+ return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , dims))
792
847
end
0 commit comments