Skip to content

Commit 981c93a

Browse files
committed
feat: TracedRScalar
1 parent deefd18 commit 981c93a

File tree

3 files changed

+66
-15
lines changed

3 files changed

+66
-15
lines changed

src/Reactant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ include("OrderedIdDict.jl")
88
using Enzyme
99

1010
abstract type RArray{T,N} <: AbstractArray{T,N} end
11+
abstract type RScalar{T} <: Number end
1112

1213
function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
1314
return reshape(A, Base._reshape_uncolon(A, dims))

src/TracedRArray.jl

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,22 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
1717
end
1818
end
1919

20+
mutable struct TracedRScalar{T} <: RScalar{T}
21+
paths::Tuple
22+
mlir_data::Union{Nothing,MLIR.IR.Value}
23+
24+
function TracedRScalar{T}(
25+
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
26+
) where {T}
27+
if !isnothing(mlir_data)
28+
@assert size(MLIR.IR.type(mlir_data)) == ()
29+
end
30+
return new{T}(paths, mlir_data)
31+
end
32+
end
33+
2034
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
2135
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
22-
const AnyTracedRScalar{T} = AnyTracedRArray{T,0}
2336
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
2437
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2}
2538
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
@@ -38,12 +51,12 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
3851
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
3952
end
4053

41-
Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
54+
Base.getindex(a::TracedRScalar{T}) where {T} = a
4255

43-
Base.zero(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T,0}, zero(T))
44-
Base.one(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T,0}, one(T))
56+
Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
57+
Base.one(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, one(T))
4558

46-
function Base.convert(::Type{<:AnyTracedRScalar{T}}, x::Number) where {T}
59+
function Base.convert(::Type{<:TracedRScalar{T}}, x::Number) where {T}
4760
return promote_to(TracedRArray{T,0}, T(x))
4861
end
4962

@@ -71,7 +84,7 @@ and require expensive copies and synchronization each time and therefore should
7184
),
7285
1,
7386
)
74-
return TracedRArray{T,0}((), res2, ())
87+
return TracedRScalar{T}((), res2)
7588
end
7689

7790
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
@@ -131,7 +144,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
131144
# return print(io, X.mlir_data, ")")
132145
end
133146

134-
Base.only(A::AnyTracedRScalar{T}) where {T} = A
147+
function Base.show(io::IOty, X::TracedRScalar{T}) where {T,IOty<:Union{IO,IOContext}}
148+
return print(io, "TracedRScalar{", T, "}(", X.paths, ")")
149+
end
150+
151+
Base.only(A::TracedRScalar{T}) where {T} = A
135152

136153
function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
137154
if prod(dims) != prod(size(A))
@@ -205,9 +222,7 @@ end
205222

206223
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
207224
if isa(rhs, TracedRArray)
208-
if typeof(rhs) == TracedRArray{T,N}
209-
return rhs
210-
end
225+
rhs isa TracedRArray{T,N} && return rhs
211226
return TracedRArray{T,N}(
212227
(),
213228
MLIR.IR.result(
@@ -220,11 +235,8 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
220235
)
221236
end
222237
if isa(rhs, Number)
223-
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRArray{T,N}, size(rhs)))
224-
ta = TracedRArray{T,N}(
225-
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
226-
)
227-
return ta
238+
throw(ArgumentError("Cannot promote number to `TracedRArray`. Use \
239+
`TracedRScalar` instead."))
228240
end
229241
T0 = eltype(rhs)
230242
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
@@ -236,9 +248,41 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
236248
)
237249
end
238250

251+
function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
252+
if isa(rhs, TracedRScalar)
253+
rhs isa TracedRScalar{T} && return rhs
254+
return TracedRScalar{T}(
255+
(),
256+
MLIR.IR.result(
257+
MLIR.Dialects.stablehlo.convert(
258+
rhs.mlir_data; result=mlir_type(TracedRScalar{T})
259+
),
260+
1,
261+
),
262+
)
263+
end
264+
if isa(rhs, Number)
265+
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRScalar{T}))
266+
return TracedRScalar{T}(
267+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
268+
)
269+
end
270+
T0 = eltype(rhs)
271+
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
272+
return promote_to(
273+
TracedRScalar{T},
274+
TracedRScalar{T0}(
275+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
276+
),
277+
)
278+
end
279+
239280
function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
240281
return promote_to(TracedRArray{T,N}, rhs)
241282
end
283+
function promote_to(::TracedRScalar{T}, rhs) where {T}
284+
return promote_to(TracedRScalar{T}, rhs)
285+
end
242286

243287
for (jlop, hloop) in (
244288
(:(Base.min), :minimum),

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@ function mlir_type(x::RArray{T,N}) where {T,N}
22
return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T))
33
end
44

5+
mlir_type(::RScalar{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T))
6+
57
function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N}
68
@assert length(shape) == N
79
return MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
810
end
911

12+
function mlir_type(::Type{<:RScalar{T}}) where {T}
13+
return MLIR.IR.TensorType((), MLIR.IR.Type(T))
14+
end
15+
1016
function transpose_ty(mlirty)
1117
return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty))
1218
end

0 commit comments

Comments
 (0)