19
19
20
20
TracedRArray {T,N} (x:: TracedRArray{T,N} ) where {T,N} = x
21
21
22
+ function Base. setproperty! (x:: TracedRArray , f:: Symbol , v)
23
+ if f === :mlir_data && ! isnothing (v)
24
+ @assert size (MLIR. IR. type (v)) == size (x)
25
+ end
26
+ return setfield! (x, f, v)
27
+ end
28
+
22
29
mutable struct TracedRScalar{T} <: RScalar{T}
23
30
paths:: Tuple
24
31
mlir_data:: Union{Nothing,MLIR.IR.Value}
@@ -33,6 +40,15 @@ mutable struct TracedRScalar{T} <: RScalar{T}
33
40
end
34
41
end
35
42
43
+ function Base. setproperty! (x:: TracedRScalar , f:: Symbol , v)
44
+ if f === :mlir_data && ! isnothing (v)
45
+ @assert size (MLIR. IR. type (v)) == ()
46
+ end
47
+ return setfield! (x, f, v)
48
+ end
49
+
50
+ Base. eltype (:: Type{TracedRScalar{T}} ) where {T} = T
51
+
36
52
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
37
53
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
38
54
const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
@@ -59,7 +75,7 @@ Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
59
75
Base. one (:: TracedRScalar{T} ) where {T} = promote_to (TracedRScalar{T}, one (T))
60
76
61
77
function Base. convert (:: Type{<:TracedRScalar{T}} , x:: Number ) where {T}
62
- return promote_to (TracedRArray{T, 0 }, T (x))
78
+ return promote_to (TracedRScalar{T }, T (x))
63
79
end
64
80
65
81
function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
@@ -121,7 +137,7 @@ function Base.setindex!(
121
137
a:: TracedRArray{T,N} , v, indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
122
138
) where {T,N}
123
139
indices = [
124
- (promote_to (TracedRArray {Int, 0 }, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
140
+ (promote_to (TracedRScalar {Int}, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
125
141
i in indices
126
142
]
127
143
v = promote_to (TracedRArray{T,N}, v)
@@ -222,6 +238,14 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
222
238
return TracedRArray{Base. promote_type (T, S),N}
223
239
end
224
240
241
+ function Base. promote_rule (:: Type{T} , :: Type{TracedRScalar{S}} ) where {T,S}
242
+ return TracedRScalar{Base. promote_type (T, S)}
243
+ end
244
+
245
+ function Base. convert (:: Type{TracedRScalar{T}} , x:: Number ) where {T}
246
+ return promote_to (TracedRScalar{T}, x)
247
+ end
248
+
225
249
function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
226
250
if isa (rhs, TracedRArray)
227
251
rhs isa TracedRArray{T,N} && return rhs
@@ -279,12 +303,8 @@ function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
279
303
)
280
304
end
281
305
282
- function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
283
- return promote_to (TracedRArray{T,N}, rhs)
284
- end
285
- function promote_to (:: TracedRScalar{T} , rhs) where {T}
286
- return promote_to (TracedRScalar{T}, rhs)
287
- end
306
+ promote_to (:: TracedRArray{T,N} , rhs) where {T,N} = promote_to (TracedRArray{T,N}, rhs)
307
+ promote_to (:: TracedRScalar{T} , rhs) where {T} = promote_to (TracedRScalar{T}, rhs)
288
308
289
309
for (jlop, hloop) in (
290
310
(:(Base. min), :minimum ),
@@ -295,66 +315,35 @@ for (jlop, hloop) in (
295
315
(:(Base.:/ ), :divide ),
296
316
(:(Base.:^ ), :power ),
297
317
)
298
- @eval begin
299
- function $ (jlop)(
300
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
301
- ) where {T}
302
- return TracedRArray {T,0} (
303
- (),
304
- MLIR. IR. result (
305
- MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
306
- ),
307
- (),
308
- )
309
- end
310
-
311
- function $ (jlop)(
312
- @nospecialize (lhs:: TracedRArray{T1,0} ), @nospecialize (rhs:: TracedRArray{T2,0} )
313
- ) where {T1,T2}
314
- commonTy = TracedRArray{Base. promote_type (T1, T2),0 }
315
- lhs = promote_to (commonTy, lhs)
316
- rhs = promote_to (commonTy, rhs)
317
- return $ (jlop)(lhs, rhs)
318
- end
319
- end
320
-
321
- for otherType in (Number, Any)
322
- @eval begin
323
- function $ (jlop)(
324
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: $ (otherType))
325
- ) where {T}
326
- rhs = promote_to (lhs, rhs)
327
- return $ (jlop)(lhs, rhs)
328
- end
329
-
330
- function $ (jlop)(
331
- @nospecialize (lhs:: $ (otherType)), @nospecialize (rhs:: TracedRArray{T,0} )
332
- ) where {T}
333
- lhs = promote_to (rhs, lhs)
334
- return $ (jlop)(lhs, rhs)
335
- end
336
- end
318
+ @eval function $ (jlop)(
319
+ @nospecialize (lhs:: TracedRScalar{T} ), @nospecialize (rhs:: TracedRScalar{T} )
320
+ ) where {T}
321
+ return TracedRArray {T} (
322
+ (),
323
+ MLIR. IR. result (
324
+ MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
325
+ ),
326
+ )
337
327
end
338
328
end
339
329
340
330
function Base. ifelse (
341
- @nospecialize (pred:: TracedRArray {Bool,0 } ),
342
- @nospecialize (x:: TracedRArray {T1,0 } ),
343
- @nospecialize (y:: TracedRArray {T2,0 } )
331
+ @nospecialize (pred:: TracedRScalar {Bool} ),
332
+ @nospecialize (x:: TracedRScalar {T1} ),
333
+ @nospecialize (y:: TracedRScalar {T2} )
344
334
) where {T1,T2}
345
- return TracedRArray {promote_type(T1, T2),0 } (
335
+ return TracedRScalar {promote_type(T1, T2)} (
346
336
(),
347
337
MLIR. IR. result (
348
338
MLIR. Dialects. stablehlo. select (pred. mlir_data, x. mlir_data, y. mlir_data), 1
349
339
),
350
- size (pred),
351
340
)
352
341
end
353
342
354
- Base. abs2 (x:: Reactant.TracedRArray{T,0 } ) where {T} = x * conj (x)
343
+ Base. abs2 (x:: Reactant.TracedRScalar{T } ) where {T} = x * conj (x)
355
344
356
345
function Base. literal_pow (
357
- :: Base.RefValue{typeof(^)} , x:: TracedRArray{T,0 } , :: Base.RefValue{Val{P}}
346
+ :: Base.RefValue{typeof(^)} , x:: TracedRScalar{T } , :: Base.RefValue{Val{P}}
358
347
) where {T,P}
359
348
return Base. literal_pow (^ , x, Val (P))
360
349
end
@@ -371,14 +360,10 @@ for (jlop, hloop) in (
371
360
(:(Base. log), :log ),
372
361
(:(Base. sqrt), :sqrt ),
373
362
)
374
- @eval begin
375
- function $jlop (@nospecialize (lhs:: TracedRArray{T,0} )) where {T}
376
- return TracedRArray {T,0} (
377
- (),
378
- MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 ),
379
- size (lhs),
380
- )
381
- end
363
+ @eval function $ (jlop)(@nospecialize (lhs:: TracedRScalar{T} )) where {T}
364
+ return TracedRScalar {T} (
365
+ (), MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 )
366
+ )
382
367
end
383
368
end
384
369
@@ -445,6 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
445
430
residx = 1
446
431
447
432
for a in linear_results
433
+ @show a
448
434
if has_residx (a)
449
435
path = get_residx (a)
450
436
set! (result, path[2 : end ], MLIR. IR. result (res, residx))
@@ -480,37 +466,22 @@ for (jlop, hloop, hlocomp, merge) in (
480
466
(:(Base.:(<= )), :compare , " LE" , nothing ),
481
467
(:(Base.:(< )), :compare , " LT" , nothing ),
482
468
)
483
- @eval begin
484
- function $ (jlop)(
485
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
486
- ) where {T}
487
- return TracedRArray {Bool,0} (
488
- (),
489
- MLIR. IR. result (
490
- MLIR. Dialects. stablehlo.$ hloop (
491
- lhs. mlir_data,
492
- rhs. mlir_data;
493
- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
494
- MLIR. IR. context (), $ hlocomp
495
- ),
469
+ @eval function $ (jlop)(
470
+ @nospecialize (lhs:: TracedRScalar{T} ), @nospecialize (rhs:: TracedRScalar{T} )
471
+ ) where {T}
472
+ return TracedRScalar {Bool} (
473
+ (),
474
+ MLIR. IR. result (
475
+ MLIR. Dialects. stablehlo.$ (hloop)(
476
+ lhs. mlir_data,
477
+ rhs. mlir_data;
478
+ comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
479
+ MLIR. IR. context (), $ hlocomp
496
480
),
497
- 1 ,
498
481
),
499
- size (lhs),
500
- )
501
- end
502
-
503
- function $ (jlop)(
504
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs)
505
- ) where {T}
506
- return $ (jlop)(lhs, promote_to (lhs, rhs))
507
- end
508
-
509
- function $ (jlop)(
510
- @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,0} )
511
- ) where {T}
512
- return $ (jlop)(promote_to (rhs, lhs), rhs)
513
- end
482
+ 1 ,
483
+ ),
484
+ )
514
485
end
515
486
516
487
if merge != = nothing
@@ -600,7 +571,7 @@ function Base.mapreduce(
600
571
fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in in_tys])
601
572
602
573
args = (
603
- TracedRArray {T,0 } ((), MLIR. IR. argument (fnbody, i), ()) for
574
+ TracedRScalar {T } ((), MLIR. IR. argument (fnbody, i), ()) for
604
575
(i, ty) in enumerate (in_tys)
605
576
)
606
577
0 commit comments