Skip to content

Commit 1282ce2

Browse files
committed
fix: special handling for concatenation of numbers
1 parent 60b614b commit 1282ce2

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/TracedRNumber.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,13 @@ struct TypeCast{T<:ReactantPrimitives} <: Function end
209209
(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)
210210

211211
Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x)
212+
213+
# Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays
214+
Base.vcat(x::TracedRNumber...) = vcat(map(Base.Fix2(broadcast_to_size, (1,)), x)...)
215+
function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T}
216+
return Base.typed_vcat(T, map(Base.Fix2(broadcast_to_size, (1,)), x)...)
217+
end
218+
Base.hcat(x::TracedRNumber...) = hcat(map(Base.Fix2(broadcast_to_size, (1, 1)), x)...)
219+
function Base.typed_hcat(::Type{T}, x::TracedRNumber...) where {T}
220+
return Base.typed_hcat(T, map(Base.Fix2(broadcast_to_size, (1, 1)), x)...)
221+
end

0 commit comments

Comments
 (0)