Skip to content

Commit 74b0c7f

Browse files
committed
refactor: Scalar --> Number
1 parent 0b868ed commit 74b0c7f

File tree

5 files changed

+55
-55
lines changed

5 files changed

+55
-55
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ module ReactantNNlibExt
22

33
using NNlib
44
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR,
5-
TracedRScalar
5+
TracedRNumber
66

77
for (jlop, hloop) in (
88
(:(NNlib.tanh_fast), :tanh),
99
(:(NNlib.sigmoid_fast), :logistic),
1010
(:(NNlib.sigmoid), :logistic),
1111
)
12-
@eval function $(jlop)(x::TracedRScalar{T}) where {T}
13-
return TracedRScalar{T}(
12+
@eval function $(jlop)(x::TracedRNumber{T}) where {T}
13+
return TracedRNumber{T}(
1414
(),
1515
Reactant.MLIR.IR.result(
1616
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ include("OrderedIdDict.jl")
88
using Enzyme
99

1010
abstract type RArray{T,N} <: AbstractArray{T,N} end
11-
abstract type RScalar{T} <: Number end
11+
abstract type RNumber{T} <: Number end
1212

1313
function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
1414
return reshape(A, Base._reshape_uncolon(A, dims))

src/TracedRArray.jl

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ function Base.setproperty!(x::TracedRArray, f::Symbol, v)
2626
return setfield!(x, f, v)
2727
end
2828

29-
mutable struct TracedRScalar{T} <: RScalar{T}
29+
mutable struct TracedRNumber{T} <: RNumber{T}
3030
paths::Tuple
3131
mlir_data::Union{Nothing,MLIR.IR.Value}
3232

33-
function TracedRScalar{T}(
33+
function TracedRNumber{T}(
3434
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
3535
) where {T}
3636
if !isnothing(mlir_data)
@@ -40,14 +40,14 @@ mutable struct TracedRScalar{T} <: RScalar{T}
4040
end
4141
end
4242

43-
function Base.setproperty!(x::TracedRScalar, f::Symbol, v)
43+
function Base.setproperty!(x::TracedRNumber, f::Symbol, v)
4444
if f === :mlir_data && !isnothing(v)
4545
@assert size(MLIR.IR.type(v)) == ()
4646
end
4747
return setfield!(x, f, v)
4848
end
4949

50-
Base.eltype(::Type{TracedRScalar{T}}) where {T} = T
50+
Base.eltype(::Type{TracedRNumber{T}}) where {T} = T
5151

5252
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
5353
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
@@ -69,13 +69,13 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
6969
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
7070
end
7171

72-
Base.getindex(a::TracedRScalar{T}) where {T} = a
72+
Base.getindex(a::TracedRNumber{T}) where {T} = a
7373

74-
Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
75-
Base.one(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, one(T))
74+
Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T))
75+
Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T))
7676

77-
function Base.convert(::Type{<:TracedRScalar{T}}, x::Number) where {T}
78-
return promote_to(TracedRScalar{T}, T(x))
77+
function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T}
78+
return promote_to(TracedRNumber{T}, T(x))
7979
end
8080

8181
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
@@ -102,7 +102,7 @@ and require expensive copies and synchronization each time and therefore should
102102
),
103103
1,
104104
)
105-
return TracedRScalar{T}((), res2)
105+
return TracedRNumber{T}((), res2)
106106
end
107107

108108
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
@@ -137,7 +137,7 @@ function Base.setindex!(
137137
a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
138138
) where {T,N}
139139
indices = [
140-
(promote_to(TracedRScalar{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
140+
(promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
141141
i in indices
142142
]
143143
v = promote_to(TracedRArray{T,N}, v)
@@ -162,11 +162,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
162162
# return print(io, X.mlir_data, ")")
163163
end
164164

165-
function Base.show(io::IOty, X::TracedRScalar{T}) where {T,IOty<:Union{IO,IOContext}}
166-
return print(io, "TracedRScalar{", T, "}(", X.paths, ")")
165+
function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}}
166+
return print(io, "TracedRNumber{", T, "}(", X.paths, ")")
167167
end
168168

169-
Base.only(A::TracedRScalar{T}) where {T} = A
169+
Base.only(A::TracedRNumber{T}) where {T} = A
170170

171171
function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT}
172172
if prod(dims) != prod(size(A))
@@ -238,12 +238,12 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
238238
return TracedRArray{Base.promote_type(T, S),N}
239239
end
240240

241-
function Base.promote_rule(::Type{T}, ::Type{TracedRScalar{S}}) where {T,S}
242-
return TracedRScalar{Base.promote_type(T, S)}
241+
function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
242+
return TracedRNumber{Base.promote_type(T, S)}
243243
end
244244

245-
function Base.convert(::Type{TracedRScalar{T}}, x::Number) where {T}
246-
return promote_to(TracedRScalar{T}, x)
245+
function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T}
246+
return promote_to(TracedRNumber{T}, x)
247247
end
248248

249249
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
@@ -262,7 +262,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
262262
end
263263
if isa(rhs, Number)
264264
throw(ArgumentError("Cannot promote number to `TracedRArray`. Use \
265-
`TracedRScalar` instead."))
265+
`TracedRNumber` instead."))
266266
end
267267
T0 = eltype(rhs)
268268
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
@@ -274,37 +274,37 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
274274
)
275275
end
276276

277-
function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
278-
if isa(rhs, TracedRScalar)
279-
rhs isa TracedRScalar{T} && return rhs
280-
return TracedRScalar{T}(
277+
function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
278+
if isa(rhs, TracedRNumber)
279+
rhs isa TracedRNumber{T} && return rhs
280+
return TracedRNumber{T}(
281281
(),
282282
MLIR.IR.result(
283283
MLIR.Dialects.stablehlo.convert(
284-
rhs.mlir_data; result=mlir_type(TracedRScalar{T})
284+
rhs.mlir_data; result=mlir_type(TracedRNumber{T})
285285
),
286286
1,
287287
),
288288
)
289289
end
290290
if isa(rhs, Number)
291-
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRScalar{T}))
292-
return TracedRScalar{T}(
291+
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T}))
292+
return TracedRNumber{T}(
293293
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
294294
)
295295
end
296296
T0 = eltype(rhs)
297297
attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
298298
return promote_to(
299-
TracedRScalar{T},
300-
TracedRScalar{T0}(
299+
TracedRNumber{T},
300+
TracedRNumber{T0}(
301301
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
302302
),
303303
)
304304
end
305305

306306
promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs)
307-
promote_to(::TracedRScalar{T}, rhs) where {T} = promote_to(TracedRScalar{T}, rhs)
307+
promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs)
308308

309309
for (jlop, hloop) in (
310310
(:(Base.min), :minimum),
@@ -316,7 +316,7 @@ for (jlop, hloop) in (
316316
(:(Base.:^), :power),
317317
)
318318
@eval function $(jlop)(
319-
@nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T})
319+
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
320320
) where {T}
321321
return TracedRArray{T}(
322322
(),
@@ -328,22 +328,22 @@ for (jlop, hloop) in (
328328
end
329329

330330
function Base.ifelse(
331-
@nospecialize(pred::TracedRScalar{Bool}),
332-
@nospecialize(x::TracedRScalar{T1}),
333-
@nospecialize(y::TracedRScalar{T2})
331+
@nospecialize(pred::TracedRNumber{Bool}),
332+
@nospecialize(x::TracedRNumber{T1}),
333+
@nospecialize(y::TracedRNumber{T2})
334334
) where {T1,T2}
335-
return TracedRScalar{promote_type(T1, T2)}(
335+
return TracedRNumber{promote_type(T1, T2)}(
336336
(),
337337
MLIR.IR.result(
338338
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
339339
),
340340
)
341341
end
342342

343-
Base.abs2(x::Reactant.TracedRScalar{T}) where {T} = x * conj(x)
343+
Base.abs2(x::Reactant.TracedRNumber{T}) where {T} = x * conj(x)
344344

345345
function Base.literal_pow(
346-
::Base.RefValue{typeof(^)}, x::TracedRScalar{T}, ::Base.RefValue{Val{P}}
346+
::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}}
347347
) where {T,P}
348348
return Base.literal_pow(^, x, Val(P))
349349
end
@@ -360,8 +360,8 @@ for (jlop, hloop) in (
360360
(:(Base.log), :log),
361361
(:(Base.sqrt), :sqrt),
362362
)
363-
@eval function $(jlop)(@nospecialize(lhs::TracedRScalar{T})) where {T}
364-
return TracedRScalar{T}(
363+
@eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T}
364+
return TracedRNumber{T}(
365365
(), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1)
366366
)
367367
end
@@ -467,9 +467,9 @@ for (jlop, hloop, hlocomp, merge) in (
467467
(:(Base.:(<)), :compare, "LT", nothing),
468468
)
469469
@eval function $(jlop)(
470-
@nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T})
470+
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
471471
) where {T}
472-
return TracedRScalar{Bool}(
472+
return TracedRNumber{Bool}(
473473
(),
474474
MLIR.IR.result(
475475
MLIR.Dialects.stablehlo.$(hloop)(
@@ -571,7 +571,7 @@ function Base.mapreduce(
571571
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])
572572

573573
args = (
574-
TracedRScalar{T}((), MLIR.IR.argument(fnbody, i), ()) for
574+
TracedRNumber{T}((), MLIR.IR.argument(fnbody, i), ()) for
575575
(i, ty) in enumerate(in_tys)
576576
)
577577

src/Tracing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ for T in (
1616
Integer,
1717
AbstractString,
1818
RArray,
19-
RScalar,
19+
RNumber,
2020
)
2121
@eval function traced_type(::Type{T}, seen, mode) where {T<:$T}
2222
return T
@@ -331,7 +331,7 @@ function make_tracer(
331331
return seen[prev]
332332
end
333333
res = if toscalar
334-
TracedRScalar{T}((path,), nothing)
334+
TracedRNumber{T}((path,), nothing)
335335
elseif !isnothing(tobatch)
336336
TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
337337
else
@@ -355,7 +355,7 @@ end
355355

356356
function make_tracer(
357357
seen,
358-
@nospecialize(prev::TracedRScalar{T}),
358+
@nospecialize(prev::TracedRNumber{T}),
359359
@nospecialize(path),
360360
mode;
361361
kwargs...
@@ -374,7 +374,7 @@ function make_tracer(
374374
if haskey(seen, prev)
375375
return seen[prev]
376376
end
377-
res = TracedRScalar{T}((path,), prev.mlir_data)
377+
res = TracedRNumber{T}((path,), prev.mlir_data)
378378
seen[prev] = res
379379
return res
380380
end

src/utils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ function mlir_type(x::RArray{T,N}) where {T,N}
22
return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T))
33
end
44

5-
mlir_type(::RScalar{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T))
5+
mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T))
66

77
function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N}
88
@assert length(shape) == N
99
return MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
1010
end
1111

12-
function mlir_type(::Type{<:RScalar{T}}) where {T}
12+
function mlir_type(::Type{<:RNumber{T}}) where {T}
1313
return MLIR.IR.TensorType((), MLIR.IR.Type(T))
1414
end
1515

@@ -44,9 +44,9 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
4444
)
4545
end
4646

47-
linear_args = Union{TracedRArray,TracedRScalar}[]
47+
linear_args = Union{TracedRArray,TracedRNumber}[]
4848
for (k, v) in seen_args
49-
if !(v isa TracedRArray) && !(v isa TracedRScalar)
49+
if !(v isa TracedRArray) && !(v isa TracedRNumber)
5050
continue
5151
end
5252
push!(linear_args, v)
@@ -127,10 +127,10 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
127127
)
128128
end
129129

130-
linear_results = Union{TracedRArray,TracedRScalar}[]
130+
linear_results = Union{TracedRArray,TracedRNumber}[]
131131

132132
for (k, v) in seen_results
133-
if !(v isa TracedRArray) && !(v isa TracedRScalar)
133+
if !(v isa TracedRArray) && !(v isa TracedRNumber)
134134
continue
135135
end
136136

0 commit comments

Comments
 (0)