Skip to content

Commit 4709537

Browse files
committed
refactor: move scalar ops to a separate file
1 parent 240eccf commit 4709537

File tree

3 files changed

+169
-171
lines changed

3 files changed

+169
-171
lines changed

src/Reactant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ include("Interpreter.jl")
6666
include("utils.jl")
6767
include("ConcreteRArray.jl")
6868
include("TracedRArray.jl")
69+
include("TracedRNumber.jl")
6970
include("Tracing.jl")
7071
include("Compiler.jl")
7172

src/TracedRArray.jl

Lines changed: 7 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,6 @@ end
1919

2020
TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x
2121

22-
mutable struct TracedRNumber{T} <: RNumber{T}
23-
paths::Tuple
24-
mlir_data::Union{Nothing,MLIR.IR.Value}
25-
26-
function TracedRNumber{T}(
27-
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
28-
) where {T}
29-
if !isnothing(mlir_data)
30-
@assert size(MLIR.IR.type(mlir_data)) == ()
31-
end
32-
return new{T}(paths, mlir_data)
33-
end
34-
end
35-
36-
Base.eltype(::Type{TracedRNumber{T}}) where {T} = T
37-
3822
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
3923
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
4024
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
@@ -55,15 +39,6 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
5539
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
5640
end
5741

58-
Base.getindex(a::TracedRNumber{T}) where {T} = a
59-
60-
Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T))
61-
Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T))
62-
63-
function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T}
64-
return promote_to(TracedRNumber{T}, T(x))
65-
end
66-
6742
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
6843
@warn(
6944
"""Performing scalar indexing on task $(current_task()).
@@ -148,12 +123,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
148123
# return print(io, X.mlir_data, ")")
149124
end
150125

151-
function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}}
152-
return print(io, "TracedRNumber{", T, "}(", X.paths, ")")
153-
end
154-
155-
Base.only(A::TracedRNumber{T}) where {T} = A
156-
157126
function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
158127
if prod(dims) != prod(size(A))
159128
throw(
@@ -214,18 +183,6 @@ function Base.transpose(A::AnyTracedRVecOrMat)
214183
end
215184
Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)
216185

217-
function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S}
218-
return TracedRNumber{Base.promote_type(T, S)}
219-
end
220-
221-
function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
222-
return TracedRNumber{Base.promote_type(T, S)}
223-
end
224-
225-
function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
226-
return promote_to(TracedRNumber{T}, x)
227-
end
228-
229186
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
230187
if isa(rhs, TracedRArray)
231188
rhs isa TracedRArray{T,N} && return rhs
@@ -254,103 +211,10 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
254211
)
255212
end
256213

257-
function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
258-
if isa(rhs, TracedRNumber)
259-
rhs isa TracedRNumber{T} && return rhs
260-
return TracedRNumber{T}(
261-
(),
262-
MLIR.IR.result(
263-
MLIR.Dialects.stablehlo.convert(
264-
rhs.mlir_data; result=mlir_type(TracedRNumber{T})
265-
),
266-
1,
267-
),
268-
)
269-
end
270-
if isa(rhs, Number)
271-
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T}))
272-
return TracedRNumber{T}(
273-
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
274-
)
275-
end
276-
T0 = eltype(rhs)
277-
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
278-
return promote_to(
279-
TracedRNumber{T},
280-
TracedRNumber{T0}(
281-
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
282-
),
283-
)
284-
end
285-
286214
promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs)
287-
promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs)
288-
289-
for (jlop, hloop) in (
290-
(:(Base.min), :minimum),
291-
(:(Base.max), :maximum),
292-
(:(Base.:+), :add),
293-
(:(Base.:-), :subtract),
294-
(:(Base.:*), :multiply),
295-
(:(Base.:/), :divide),
296-
(:(Base.:^), :power),
297-
)
298-
@eval function $(jlop)(
299-
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
300-
) where {T}
301-
return TracedRNumber{T}(
302-
(),
303-
MLIR.IR.result(
304-
MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
305-
),
306-
)
307-
end
308-
end
309-
310-
function Base.ifelse(
311-
@nospecialize(pred::TracedRNumber{Bool}),
312-
@nospecialize(x::TracedRNumber{T1}),
313-
@nospecialize(y::TracedRNumber{T2})
314-
) where {T1,T2}
315-
return TracedRNumber{promote_type(T1, T2)}(
316-
(),
317-
MLIR.IR.result(
318-
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
319-
),
320-
)
321-
end
322-
323-
function Base.literal_pow(
324-
::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}}
325-
) where {T,P}
326-
return Base.literal_pow(^, x, Val(P))
327-
end
328-
329-
for (jlop, hloop) in (
330-
(:(Base.abs), :abs),
331-
(:(Base.:-), :negate),
332-
(:(Base.sin), :sine),
333-
(:(Base.cos), :cosine),
334-
(:(Base.tanh), :tanh),
335-
(:(Base.FastMath.tanh_fast), :tanh),
336-
(:(Base.exp), :exponential),
337-
(:(Base.FastMath.exp_fast), :exponential),
338-
(:(Base.log), :log),
339-
(:(Base.sqrt), :sqrt),
340-
)
341-
@eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T}
342-
return TracedRNumber{T}(
343-
(), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1)
344-
)
345-
end
346-
end
347215

348216
struct TypeCast{T<:Number} <: Function end
349217

350-
function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
351-
return promote_to(TracedRNumber{T}, x)
352-
end
353-
354218
elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x
355219
function elem_apply(::Type{T}, x::TracedRArray{T2}) where {T<:Number,T2<:Number}
356220
# Special Path to prevent going down a despecialized path
@@ -435,41 +299,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
435299
return traced2_result
436300
end
437301

438-
for (jlop, hloop, hlocomp, merge) in (
439-
(:(Base.:(==)), :compare, "EQ", :all),
440-
(:(Base.:(!=)), :compare, "NE", :any),
441-
(:(Base.:(>=)), :compare, "GE", nothing),
442-
(:(Base.:(>)), :compare, "GT", nothing),
443-
(:(Base.:(<=)), :compare, "LE", nothing),
444-
(:(Base.:(<)), :compare, "LT", nothing),
445-
)
446-
@eval function $(jlop)(
447-
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
448-
) where {T}
449-
return TracedRNumber{Bool}(
450-
(),
451-
MLIR.IR.result(
452-
MLIR.Dialects.stablehlo.$(hloop)(
453-
lhs.mlir_data,
454-
rhs.mlir_data;
455-
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
456-
MLIR.IR.context(), $hlocomp
457-
),
458-
),
459-
1,
460-
),
461-
)
462-
end
463-
464-
if merge !== nothing
465-
@eval begin
466-
function $jlop(
467-
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N})
468-
) where {T,N}
469-
elems = $(jlop).(lhs, rhs)
470-
return N == 0 ? elems : $(merge)(elems)
471-
end
472-
end
302+
for (jlop, hloop, hlocomp, merge) in
303+
((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any))
304+
@eval function $jlop(
305+
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N})
306+
) where {T,N}
307+
elems = $(jlop).(lhs, rhs)
308+
return N == 0 ? elems : $(merge)(elems)
473309
end
474310
end
475311

src/TracedRNumber.jl

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
mutable struct TracedRNumber{T} <: RNumber{T}
2+
paths::Tuple
3+
mlir_data::Union{Nothing,MLIR.IR.Value}
4+
5+
function TracedRNumber{T}(
6+
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
7+
) where {T}
8+
if !isnothing(mlir_data)
9+
@assert size(MLIR.IR.type(mlir_data)) == ()
10+
end
11+
return new{T}(paths, mlir_data)
12+
end
13+
end
14+
15+
Base.eltype(::Type{TracedRNumber{T}}) where {T} = T
16+
17+
Base.getindex(a::TracedRNumber{T}) where {T} = a
18+
19+
Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T))
20+
Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T))
21+
22+
function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T}
23+
return promote_to(TracedRNumber{T}, T(x))
24+
end
25+
26+
function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}}
27+
return print(io, "TracedRNumber{", T, "}(", X.paths, ")")
28+
end
29+
30+
Base.only(A::TracedRNumber{T}) where {T} = A
31+
32+
function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S}
33+
return TracedRNumber{Base.promote_type(T, S)}
34+
end
35+
36+
function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
37+
return TracedRNumber{Base.promote_type(T, S)}
38+
end
39+
40+
function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
41+
return promote_to(TracedRNumber{T}, x)
42+
end
43+
44+
function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
45+
if isa(rhs, TracedRNumber)
46+
rhs isa TracedRNumber{T} && return rhs
47+
return TracedRNumber{T}(
48+
(),
49+
MLIR.IR.result(
50+
MLIR.Dialects.stablehlo.convert(
51+
rhs.mlir_data; result=mlir_type(TracedRNumber{T})
52+
),
53+
1,
54+
),
55+
)
56+
end
57+
if isa(rhs, Number)
58+
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T}))
59+
return TracedRNumber{T}(
60+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
61+
)
62+
end
63+
T0 = eltype(rhs)
64+
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
65+
return promote_to(
66+
TracedRNumber{T},
67+
TracedRNumber{T0}(
68+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
69+
),
70+
)
71+
end
72+
73+
promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs)
74+
75+
for (jlop, hloop) in (
76+
(:(Base.min), :minimum),
77+
(:(Base.max), :maximum),
78+
(:(Base.:+), :add),
79+
(:(Base.:-), :subtract),
80+
(:(Base.:*), :multiply),
81+
(:(Base.:/), :divide),
82+
(:(Base.:^), :power),
83+
)
84+
@eval function $(jlop)(
85+
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
86+
) where {T}
87+
return TracedRNumber{T}(
88+
(),
89+
MLIR.IR.result(
90+
MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
91+
),
92+
)
93+
end
94+
end
95+
96+
for (jlop, hloop, hlocomp) in (
97+
(:(Base.:(==)), :compare, "EQ"),
98+
(:(Base.:(!=)), :compare, "NE"),
99+
(:(Base.:(>=)), :compare, "GE"),
100+
(:(Base.:(>)), :compare, "GT"),
101+
(:(Base.:(<=)), :compare, "LE"),
102+
(:(Base.:(<)), :compare, "LT"),
103+
)
104+
@eval function $(jlop)(
105+
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
106+
) where {T}
107+
return TracedRNumber{Bool}(
108+
(),
109+
MLIR.IR.result(
110+
MLIR.Dialects.stablehlo.$(hloop)(
111+
lhs.mlir_data,
112+
rhs.mlir_data;
113+
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
114+
MLIR.IR.context(), $hlocomp
115+
),
116+
),
117+
1,
118+
),
119+
)
120+
end
121+
end
122+
123+
function Base.ifelse(
124+
@nospecialize(pred::TracedRNumber{Bool}),
125+
@nospecialize(x::TracedRNumber{T1}),
126+
@nospecialize(y::TracedRNumber{T2})
127+
) where {T1,T2}
128+
return TracedRNumber{promote_type(T1, T2)}(
129+
(),
130+
MLIR.IR.result(
131+
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
132+
),
133+
)
134+
end
135+
136+
function Base.literal_pow(
137+
::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}}
138+
) where {T,P}
139+
return Base.literal_pow(^, x, Val(P))
140+
end
141+
142+
for (jlop, hloop) in (
143+
(:(Base.abs), :abs),
144+
(:(Base.:-), :negate),
145+
(:(Base.sin), :sine),
146+
(:(Base.cos), :cosine),
147+
(:(Base.tanh), :tanh),
148+
(:(Base.FastMath.tanh_fast), :tanh),
149+
(:(Base.exp), :exponential),
150+
(:(Base.FastMath.exp_fast), :exponential),
151+
(:(Base.log), :log),
152+
(:(Base.sqrt), :sqrt),
153+
)
154+
@eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T}
155+
return TracedRNumber{T}(
156+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1)
157+
)
158+
end
159+
end
160+
161+
(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)

0 commit comments

Comments
 (0)