19
19
20
20
TracedRArray {T,N} (x:: TracedRArray{T,N} ) where {T,N} = x
21
21
22
- mutable struct TracedRNumber{T} <: RNumber{T}
23
- paths:: Tuple
24
- mlir_data:: Union{Nothing,MLIR.IR.Value}
25
-
26
- function TracedRNumber {T} (
27
- paths:: Tuple , mlir_data:: Union{Nothing,MLIR.IR.Value}
28
- ) where {T}
29
- if ! isnothing (mlir_data)
30
- @assert size (MLIR. IR. type (mlir_data)) == ()
31
- end
32
- return new {T} (paths, mlir_data)
33
- end
34
- end
35
-
36
- Base. eltype (:: Type{TracedRNumber{T}} ) where {T} = T
37
-
38
22
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
39
23
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
40
24
const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
@@ -55,15 +39,6 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
55
39
return get_ancestor_indices (parent (x), Base. reindex (parentindices (x), indices)... )
56
40
end
57
41
58
- Base. getindex (a:: TracedRNumber{T} ) where {T} = a
59
-
60
- Base. zero (:: TracedRNumber{T} ) where {T} = promote_to (TracedRNumber{T}, zero (T))
61
- Base. one (:: TracedRNumber{T} ) where {T} = promote_to (TracedRNumber{T}, one (T))
62
-
63
- function Base. convert (:: Type{<:TracedRNumber{T}} , x:: Number ) where {T}
64
- return promote_to (TracedRNumber{T}, T (x))
65
- end
66
-
67
42
function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
68
43
@warn (
69
44
""" Performing scalar indexing on task $(current_task ()) .
@@ -148,12 +123,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
148
123
# return print(io, X.mlir_data, ")")
149
124
end
150
125
151
- function Base. show (io:: IOty , X:: TracedRNumber{T} ) where {T,IOty<: Union{IO,IOContext} }
152
- return print (io, " TracedRNumber{" , T, " }(" , X. paths, " )" )
153
- end
154
-
155
- Base. only (A:: TracedRNumber{T} ) where {T} = A
156
-
157
126
function Base. reshape (A:: AnyTracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
158
127
if prod (dims) != prod (size (A))
159
128
throw (
@@ -214,18 +183,6 @@ function Base.transpose(A::AnyTracedRVecOrMat)
214
183
end
215
184
Base. adjoint (A:: AnyTracedRVecOrMat{<:Real} ) = transpose (A)
216
185
217
- function Base. promote_rule (:: Type{TracedRNumber{T}} , :: Type{TracedRNumber{S}} ) where {T,S}
218
- return TracedRNumber{Base. promote_type (T, S)}
219
- end
220
-
221
- function Base. promote_rule (:: Type{T} , :: Type{TracedRNumber{S}} ) where {T,S}
222
- return TracedRNumber{Base. promote_type (T, S)}
223
- end
224
-
225
- function Base. convert (:: Type{TracedRNumber{T}} , x:: Number ) where {T}
226
- return promote_to (TracedRNumber{T}, x)
227
- end
228
-
229
186
function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
230
187
if isa (rhs, TracedRArray)
231
188
rhs isa TracedRArray{T,N} && return rhs
@@ -254,103 +211,10 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
254
211
)
255
212
end
256
213
257
- function promote_to (:: Type{TracedRNumber{T}} , rhs) where {T}
258
- if isa (rhs, TracedRNumber)
259
- rhs isa TracedRNumber{T} && return rhs
260
- return TracedRNumber {T} (
261
- (),
262
- MLIR. IR. result (
263
- MLIR. Dialects. stablehlo. convert (
264
- rhs. mlir_data; result= mlir_type (TracedRNumber{T})
265
- ),
266
- 1 ,
267
- ),
268
- )
269
- end
270
- if isa (rhs, Number)
271
- attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRNumber{T}))
272
- return TracedRNumber {T} (
273
- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
274
- )
275
- end
276
- T0 = eltype (rhs)
277
- attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
278
- return promote_to (
279
- TracedRNumber{T},
280
- TracedRNumber {T0} (
281
- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
282
- ),
283
- )
284
- end
285
-
286
214
promote_to (:: TracedRArray{T,N} , rhs) where {T,N} = promote_to (TracedRArray{T,N}, rhs)
287
- promote_to (:: TracedRNumber{T} , rhs) where {T} = promote_to (TracedRNumber{T}, rhs)
288
-
289
- for (jlop, hloop) in (
290
- (:(Base. min), :minimum ),
291
- (:(Base. max), :maximum ),
292
- (:(Base.:+ ), :add ),
293
- (:(Base.:- ), :subtract ),
294
- (:(Base.:* ), :multiply ),
295
- (:(Base.:/ ), :divide ),
296
- (:(Base.:^ ), :power ),
297
- )
298
- @eval function $ (jlop)(
299
- @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: TracedRNumber{T} )
300
- ) where {T}
301
- return TracedRNumber {T} (
302
- (),
303
- MLIR. IR. result (
304
- MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
305
- ),
306
- )
307
- end
308
- end
309
-
310
- function Base. ifelse (
311
- @nospecialize (pred:: TracedRNumber{Bool} ),
312
- @nospecialize (x:: TracedRNumber{T1} ),
313
- @nospecialize (y:: TracedRNumber{T2} )
314
- ) where {T1,T2}
315
- return TracedRNumber {promote_type(T1, T2)} (
316
- (),
317
- MLIR. IR. result (
318
- MLIR. Dialects. stablehlo. select (pred. mlir_data, x. mlir_data, y. mlir_data), 1
319
- ),
320
- )
321
- end
322
-
323
- function Base. literal_pow (
324
- :: Base.RefValue{typeof(^)} , x:: TracedRNumber{T} , :: Base.RefValue{Val{P}}
325
- ) where {T,P}
326
- return Base. literal_pow (^ , x, Val (P))
327
- end
328
-
329
- for (jlop, hloop) in (
330
- (:(Base. abs), :abs ),
331
- (:(Base.:- ), :negate ),
332
- (:(Base. sin), :sine ),
333
- (:(Base. cos), :cosine ),
334
- (:(Base. tanh), :tanh ),
335
- (:(Base. FastMath. tanh_fast), :tanh ),
336
- (:(Base. exp), :exponential ),
337
- (:(Base. FastMath. exp_fast), :exponential ),
338
- (:(Base. log), :log ),
339
- (:(Base. sqrt), :sqrt ),
340
- )
341
- @eval function $ (jlop)(@nospecialize (lhs:: TracedRNumber{T} )) where {T}
342
- return TracedRNumber {T} (
343
- (), MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 )
344
- )
345
- end
346
- end
347
215
348
216
struct TypeCast{T<: Number } <: Function end
349
217
350
- function (:: TypeCast{T} )(x:: TracedRNumber{T2} ) where {T,T2}
351
- return promote_to (TracedRNumber{T}, x)
352
- end
353
-
354
218
elem_apply (:: Type{T} , x:: TracedRArray{T} ) where {T<: Number } = x
355
219
function elem_apply (:: Type{T} , x:: TracedRArray{T2} ) where {T<: Number ,T2<: Number }
356
220
# Special Path to prevent going down a despecialized path
@@ -435,41 +299,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
435
299
return traced2_result
436
300
end
437
301
438
- for (jlop, hloop, hlocomp, merge) in (
439
- (:(Base.:(== )), :compare , " EQ" , :all ),
440
- (:(Base.:(!= )), :compare , " NE" , :any ),
441
- (:(Base.:(>= )), :compare , " GE" , nothing ),
442
- (:(Base.:(> )), :compare , " GT" , nothing ),
443
- (:(Base.:(<= )), :compare , " LE" , nothing ),
444
- (:(Base.:(< )), :compare , " LT" , nothing ),
445
- )
446
- @eval function $ (jlop)(
447
- @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: TracedRNumber{T} )
448
- ) where {T}
449
- return TracedRNumber {Bool} (
450
- (),
451
- MLIR. IR. result (
452
- MLIR. Dialects. stablehlo.$ (hloop)(
453
- lhs. mlir_data,
454
- rhs. mlir_data;
455
- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
456
- MLIR. IR. context (), $ hlocomp
457
- ),
458
- ),
459
- 1 ,
460
- ),
461
- )
462
- end
463
-
464
- if merge != = nothing
465
- @eval begin
466
- function $jlop (
467
- @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs:: TracedRArray{T,N} )
468
- ) where {T,N}
469
- elems = $ (jlop). (lhs, rhs)
470
- return N == 0 ? elems : $ (merge)(elems)
471
- end
472
- end
302
+ for (jlop, hloop, hlocomp, merge) in
303
+ ((:(Base.:(== )), :compare , " EQ" , :all ), (:(Base.:(!= )), :compare , " NE" , :any ))
304
+ @eval function $jlop (
305
+ @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs:: TracedRArray{T,N} )
306
+ ) where {T,N}
307
+ elems = $ (jlop). (lhs, rhs)
308
+ return N == 0 ? elems : $ (merge)(elems)
473
309
end
474
310
end
475
311
0 commit comments