@@ -17,9 +17,22 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
17
17
end
18
18
end
19
19
20
+ mutable struct TracedRScalar{T} <: RScalar{T}
21
+ paths:: Tuple
22
+ mlir_data:: Union{Nothing,MLIR.IR.Value}
23
+
24
+ function TracedRScalar {T} (
25
+ paths:: Tuple , mlir_data:: Union{Nothing,MLIR.IR.Value}
26
+ ) where {T}
27
+ if ! isnothing (mlir_data)
28
+ @assert size (MLIR. IR. type (mlir_data)) == ()
29
+ end
30
+ return new {T} (paths, mlir_data)
31
+ end
32
+ end
33
+
20
34
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
21
35
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
22
- const AnyTracedRScalar{T} = AnyTracedRArray{T,0 }
23
36
const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
24
37
const AnyTracedRMatrix{T} = AnyTracedRArray{T,2 }
25
38
const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}}
@@ -38,12 +51,12 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
38
51
return get_ancestor_indices (parent (x), Base. reindex (parentindices (x), indices)... )
39
52
end
40
53
41
- Base. getindex (a:: AnyTracedRScalar {T} ) where {T} = a
54
+ Base. getindex (a:: TracedRScalar {T} ) where {T} = a
42
55
43
- Base. zero (:: AnyTracedRScalar {T} ) where {T} = promote_to (TracedRArray{T, 0 }, zero (T))
44
- Base. one (:: AnyTracedRScalar {T} ) where {T} = promote_to (TracedRArray{T, 0 }, one (T))
56
+ Base. zero (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar{T }, zero (T))
57
+ Base. one (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar{T }, one (T))
45
58
46
- function Base. convert (:: Type{<:AnyTracedRScalar {T}} , x:: Number ) where {T}
59
+ function Base. convert (:: Type{<:TracedRScalar {T}} , x:: Number ) where {T}
47
60
return promote_to (TracedRArray{T,0 }, T (x))
48
61
end
49
62
@@ -71,7 +84,7 @@ and require expensive copies and synchronization each time and therefore should
71
84
),
72
85
1 ,
73
86
)
74
- return TracedRArray {T,0 } ((), res2, () )
87
+ return TracedRScalar {T } ((), res2)
75
88
end
76
89
77
90
function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
@@ -131,7 +144,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
131
144
# return print(io, X.mlir_data, ")")
132
145
end
133
146
134
- Base. only (A:: AnyTracedRScalar{T} ) where {T} = A
147
+ function Base. show (io:: IOty , X:: TracedRScalar{T} ) where {T,IOty<: Union{IO,IOContext} }
148
+ return print (io, " TracedRScalar{" , T, " }(" , X. paths, " )" )
149
+ end
150
+
151
+ Base. only (A:: TracedRScalar{T} ) where {T} = A
135
152
136
153
function Base. reshape (A:: AnyTracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
137
154
if prod (dims) != prod (size (A))
205
222
206
223
function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
207
224
if isa (rhs, TracedRArray)
208
- if typeof (rhs) == TracedRArray{T,N}
209
- return rhs
210
- end
225
+ rhs isa TracedRArray{T,N} && return rhs
211
226
return TracedRArray {T,N} (
212
227
(),
213
228
MLIR. IR. result (
@@ -220,11 +235,8 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
220
235
)
221
236
end
222
237
if isa (rhs, Number)
223
- attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRArray{T,N}, size (rhs)))
224
- ta = TracedRArray {T,N} (
225
- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (rhs)
226
- )
227
- return ta
238
+ throw (ArgumentError (" Cannot promote number to `TracedRArray`. Use \
239
+ `TracedRScalar` instead." ))
228
240
end
229
241
T0 = eltype (rhs)
230
242
attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
@@ -236,9 +248,41 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
236
248
)
237
249
end
238
250
251
+ function promote_to (:: Type{TracedRScalar{T}} , rhs) where {T}
252
+ if isa (rhs, TracedRScalar)
253
+ rhs isa TracedRScalar{T} && return rhs
254
+ return TracedRScalar {T} (
255
+ (),
256
+ MLIR. IR. result (
257
+ MLIR. Dialects. stablehlo. convert (
258
+ rhs. mlir_data; result= mlir_type (TracedRScalar{T})
259
+ ),
260
+ 1 ,
261
+ ),
262
+ )
263
+ end
264
+ if isa (rhs, Number)
265
+ attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRScalar{T}))
266
+ return TracedRScalar {T} (
267
+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
268
+ )
269
+ end
270
+ T0 = eltype (rhs)
271
+ attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
272
+ return promote_to (
273
+ TracedRScalar{T},
274
+ TracedRScalar {T0} (
275
+ (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
276
+ ),
277
+ )
278
+ end
279
+
239
280
function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
240
281
return promote_to (TracedRArray{T,N}, rhs)
241
282
end
283
+ function promote_to (:: TracedRScalar{T} , rhs) where {T}
284
+ return promote_to (TracedRScalar{T}, rhs)
285
+ end
242
286
243
287
for (jlop, hloop) in (
244
288
(:(Base. min), :minimum ),
0 commit comments