@@ -26,11 +26,11 @@ function Base.setproperty!(x::TracedRArray, f::Symbol, v)
26
26
return setfield! (x, f, v)
27
27
end
28
28
29
- mutable struct TracedRScalar {T} <: RScalar {T}
29
+ mutable struct TracedRNumber {T} <: RNumber {T}
30
30
paths:: Tuple
31
31
mlir_data:: Union{Nothing,MLIR.IR.Value}
32
32
33
- function TracedRScalar {T} (
33
+ function TracedRNumber {T} (
34
34
paths:: Tuple , mlir_data:: Union{Nothing,MLIR.IR.Value}
35
35
) where {T}
36
36
if ! isnothing (mlir_data)
@@ -40,14 +40,14 @@ mutable struct TracedRScalar{T} <: RScalar{T}
40
40
end
41
41
end
42
42
43
- function Base. setproperty! (x:: TracedRScalar , f:: Symbol , v)
43
+ function Base. setproperty! (x:: TracedRNumber , f:: Symbol , v)
44
44
if f === :mlir_data && ! isnothing (v)
45
45
@assert size (MLIR. IR. type (v)) == ()
46
46
end
47
47
return setfield! (x, f, v)
48
48
end
49
49
50
- Base. eltype (:: Type{TracedRScalar {T}} ) where {T} = T
50
+ Base. eltype (:: Type{TracedRNumber {T}} ) where {T} = T
51
51
52
52
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
53
53
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
@@ -69,13 +69,13 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
69
69
return get_ancestor_indices (parent (x), Base. reindex (parentindices (x), indices)... )
70
70
end
71
71
72
- Base. getindex (a:: TracedRScalar {T} ) where {T} = a
72
+ Base. getindex (a:: TracedRNumber {T} ) where {T} = a
73
73
74
- Base. zero (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar {T}, zero (T))
75
- Base. one (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar {T}, one (T))
74
+ Base. zero (:: TracedRNumber {T} ) where {T} = promote_to (TracedRNumber {T}, zero (T))
75
+ Base. one (:: TracedRNumber {T} ) where {T} = promote_to (TracedRNumber {T}, one (T))
76
76
77
- function Base. convert (:: Type{<:TracedRScalar {T}} , x:: Number ) where {T}
78
- return promote_to (TracedRScalar {T}, T (x))
77
+ function Base. convert (:: Type{<:TracedRNumber {T}} , x:: Number ) where {T}
78
+ return promote_to (TracedRNumber {T}, T (x))
79
79
end
80
80
81
81
function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
@@ -102,7 +102,7 @@ and require expensive copies and synchronization each time and therefore should
102
102
),
103
103
1 ,
104
104
)
105
- return TracedRScalar {T} ((), res2)
105
+ return TracedRNumber {T} ((), res2)
106
106
end
107
107
108
108
function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
@@ -137,7 +137,7 @@ function Base.setindex!(
137
137
a:: TracedRArray{T,N} , v, indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
138
138
) where {T,N}
139
139
indices = [
140
- (promote_to (TracedRScalar {Int}, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
140
+ (promote_to (TracedRNumber {Int}, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
141
141
i in indices
142
142
]
143
143
v = promote_to (TracedRArray{T,N}, v)
@@ -162,11 +162,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
162
162
# return print(io, X.mlir_data, ")")
163
163
end
164
164
165
- function Base. show (io:: IOty , X:: TracedRScalar {T} ) where {T,IOty<: Union{IO,IOContext} }
166
- return print (io, " TracedRScalar {" , T, " }(" , X. paths, " )" )
165
+ function Base. show (io:: IOty , X:: TracedRNumber {T} ) where {T,IOty<: Union{IO,IOContext} }
166
+ return print (io, " TracedRNumber {" , T, " }(" , X. paths, " )" )
167
167
end
168
168
169
- Base. only (A:: TracedRScalar {T} ) where {T} = A
169
+ Base. only (A:: TracedRNumber {T} ) where {T} = A
170
170
171
171
function Base. reshape (A:: AnyTracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
172
172
if prod (dims) != prod (size (A))
@@ -238,12 +238,12 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
238
238
return TracedRArray{Base. promote_type (T, S),N}
239
239
end
240
240
241
- function Base. promote_rule (:: Type{T} , :: Type{TracedRScalar {S}} ) where {T,S}
242
- return TracedRScalar {Base. promote_type (T, S)}
241
+ function Base. promote_rule (:: Type{T} , :: Type{TracedRNumber {S}} ) where {T,S}
242
+ return TracedRNumber {Base. promote_type (T, S)}
243
243
end
244
244
245
- function Base. convert (:: Type{TracedRScalar {T}} , x:: Number ) where {T}
246
- return promote_to (TracedRScalar {T}, x)
245
+ function Base. convert (:: Type{TracedRNumber {T}} , x:: Number ) where {T}
246
+ return promote_to (TracedRNumber {T}, x)
247
247
end
248
248
249
249
function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
@@ -262,7 +262,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
262
262
end
263
263
if isa (rhs, Number)
264
264
throw (ArgumentError (" Cannot promote number to `TracedRArray`. Use \
265
- `TracedRScalar ` instead." ))
265
+ `TracedRNumber ` instead." ))
266
266
end
267
267
T0 = eltype (rhs)
268
268
attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
@@ -274,37 +274,37 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
274
274
)
275
275
end
276
276
277
- function promote_to (:: Type{TracedRScalar {T}} , rhs) where {T}
278
- if isa (rhs, TracedRScalar )
279
- rhs isa TracedRScalar {T} && return rhs
280
- return TracedRScalar {T} (
277
+ function promote_to (:: Type{TracedRNumber {T}} , rhs) where {T}
278
+ if isa (rhs, TracedRNumber )
279
+ rhs isa TracedRNumber {T} && return rhs
280
+ return TracedRNumber {T} (
281
281
(),
282
282
MLIR. IR. result (
283
283
MLIR. Dialects. stablehlo. convert (
284
- rhs. mlir_data; result= mlir_type (TracedRScalar {T})
284
+ rhs. mlir_data; result= mlir_type (TracedRNumber {T})
285
285
),
286
286
1 ,
287
287
),
288
288
)
289
289
end
290
290
if isa (rhs, Number)
291
- attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRScalar {T}))
292
- return TracedRScalar {T} (
291
+ attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRNumber {T}))
292
+ return TracedRNumber {T} (
293
293
(), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
294
294
)
295
295
end
296
296
T0 = eltype (rhs)
297
297
attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
298
298
return promote_to (
299
- TracedRScalar {T},
300
- TracedRScalar {T0} (
299
+ TracedRNumber {T},
300
+ TracedRNumber {T0} (
301
301
(), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
302
302
),
303
303
)
304
304
end
305
305
306
306
promote_to (:: TracedRArray{T,N} , rhs) where {T,N} = promote_to (TracedRArray{T,N}, rhs)
307
- promote_to (:: TracedRScalar {T} , rhs) where {T} = promote_to (TracedRScalar {T}, rhs)
307
+ promote_to (:: TracedRNumber {T} , rhs) where {T} = promote_to (TracedRNumber {T}, rhs)
308
308
309
309
for (jlop, hloop) in (
310
310
(:(Base. min), :minimum ),
@@ -316,7 +316,7 @@ for (jlop, hloop) in (
316
316
(:(Base.:^ ), :power ),
317
317
)
318
318
@eval function $ (jlop)(
319
- @nospecialize (lhs:: TracedRScalar {T} ), @nospecialize (rhs:: TracedRScalar {T} )
319
+ @nospecialize (lhs:: TracedRNumber {T} ), @nospecialize (rhs:: TracedRNumber {T} )
320
320
) where {T}
321
321
return TracedRArray {T} (
322
322
(),
@@ -328,22 +328,22 @@ for (jlop, hloop) in (
328
328
end
329
329
330
330
function Base. ifelse (
331
- @nospecialize (pred:: TracedRScalar {Bool} ),
332
- @nospecialize (x:: TracedRScalar {T1} ),
333
- @nospecialize (y:: TracedRScalar {T2} )
331
+ @nospecialize (pred:: TracedRNumber {Bool} ),
332
+ @nospecialize (x:: TracedRNumber {T1} ),
333
+ @nospecialize (y:: TracedRNumber {T2} )
334
334
) where {T1,T2}
335
- return TracedRScalar {promote_type(T1, T2)} (
335
+ return TracedRNumber {promote_type(T1, T2)} (
336
336
(),
337
337
MLIR. IR. result (
338
338
MLIR. Dialects. stablehlo. select (pred. mlir_data, x. mlir_data, y. mlir_data), 1
339
339
),
340
340
)
341
341
end
342
342
343
- Base. abs2 (x:: Reactant.TracedRScalar {T} ) where {T} = x * conj (x)
343
+ Base. abs2 (x:: Reactant.TracedRNumber {T} ) where {T} = x * conj (x)
344
344
345
345
function Base. literal_pow (
346
- :: Base.RefValue{typeof(^)} , x:: TracedRScalar {T} , :: Base.RefValue{Val{P}}
346
+ :: Base.RefValue{typeof(^)} , x:: TracedRNumber {T} , :: Base.RefValue{Val{P}}
347
347
) where {T,P}
348
348
return Base. literal_pow (^ , x, Val (P))
349
349
end
@@ -360,8 +360,8 @@ for (jlop, hloop) in (
360
360
(:(Base. log), :log ),
361
361
(:(Base. sqrt), :sqrt ),
362
362
)
363
- @eval function $ (jlop)(@nospecialize (lhs:: TracedRScalar {T} )) where {T}
364
- return TracedRScalar {T} (
363
+ @eval function $ (jlop)(@nospecialize (lhs:: TracedRNumber {T} )) where {T}
364
+ return TracedRNumber {T} (
365
365
(), MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 )
366
366
)
367
367
end
@@ -467,9 +467,9 @@ for (jlop, hloop, hlocomp, merge) in (
467
467
(:(Base.:(< )), :compare , " LT" , nothing ),
468
468
)
469
469
@eval function $ (jlop)(
470
- @nospecialize (lhs:: TracedRScalar {T} ), @nospecialize (rhs:: TracedRScalar {T} )
470
+ @nospecialize (lhs:: TracedRNumber {T} ), @nospecialize (rhs:: TracedRNumber {T} )
471
471
) where {T}
472
- return TracedRScalar {Bool} (
472
+ return TracedRNumber {Bool} (
473
473
(),
474
474
MLIR. IR. result (
475
475
MLIR. Dialects. stablehlo.$ (hloop)(
@@ -571,7 +571,7 @@ function Base.mapreduce(
571
571
fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in in_tys])
572
572
573
573
args = (
574
- TracedRScalar {T} ((), MLIR. IR. argument (fnbody, i), ()) for
574
+ TracedRNumber {T} ((), MLIR. IR. argument (fnbody, i), ()) for
575
575
(i, ty) in enumerate (in_tys)
576
576
)
577
577
0 commit comments