Skip to content

Commit 6e89952

Browse files
avik-palmofeing
andauthored
feat: implement a separate TracedRNumber (#161)
* feat: TracedRScalar * feat: partial progress on getting scalars to work * refactor: Scalar --> Number * fix: batching * fix: promote_rule and introduce union over primitive types * chore: apply formatting * feat: type-restrict arrays * refactor: move scalar ops to a separate file * feat: support Base.float * fix: import ordering * feat: handle `broadcast_preserving_zero_d` in a generic fashion * refactor: move code a bit * test: more test fixes * chore: apply formatting * fix: setindex with scalars * fix: scalar broadcasting case * feat: support BFloat16 from Core (if available) * test: more native lux functionality unblocked * refactor: use a union type for traced types * fix: check for reactant primitives * fix: missing import * fix: correct semantics for Colon mapreduce * fix: trace_type * fix: minor fixes * feat: support logsoftmax * fix: bool promote rule * fix: broadcasting of closures * refactor: use TracedTypes * Fix type of `preserved_args` * Rename `TracedTypes` to `TracedType` * small testset rename * fix: special handling for concatenation of numbers * Reenable tests * Rename `ReactantPrimitives` to `ReactantPrimitive` --------- Co-authored-by: Sergio Sánchez Ramírez <[email protected]> Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
1 parent f2c0e8a commit 6e89952

File tree

13 files changed

+557
-252
lines changed

13 files changed

+557
-252
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,24 @@
11
module ReactantNNlibExt
22

33
using NNlib
4-
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR
4+
using Reactant:
5+
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
56

67
for (jlop, hloop) in (
78
(:(NNlib.tanh_fast), :tanh),
89
(:(NNlib.sigmoid_fast), :logistic),
910
(:(NNlib.sigmoid), :logistic),
1011
)
11-
@eval function $(jlop)(x::TracedRArray{T,0}) where {T}
12-
return TracedRArray{T,0}(
12+
@eval function $(jlop)(x::TracedRNumber{T}) where {T}
13+
return TracedRNumber{T}(
1314
(),
1415
Reactant.MLIR.IR.result(
1516
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
1617
),
17-
(),
1818
)
1919
end
2020
end
2121

22-
# Don't confuse our poor scalar arrays, we no like numbers we like 0D arrays
23-
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, ))
24-
@eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
25-
return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
26-
end
27-
end
28-
2922
# TODO handle non finite cases
3023
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
3124
max_ = NNlib.fast_maximum(x; dims)
@@ -39,6 +32,20 @@ function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where
3932
return out ./= tmp
4033
end
4134

35+
function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
36+
max_ = NNlib.fast_maximum(x; dims)
37+
# if all(isfinite, max_)
38+
@fastmath out .= x .- max_
39+
# else
40+
# _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
41+
# @. out = ifelse(
42+
# isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
43+
# )
44+
# end
45+
@fastmath log_ = log.(sum(exp, out; dims))
46+
return out .-= log_
47+
end
48+
4249
function NNlib.conv(
4350
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
4451
) where {T,N}

src/Compiler.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ import ..Reactant:
66
XLA,
77
ConcreteRArray,
88
TracedRArray,
9+
TracedRNumber,
910
OrderedIdDict,
1011
make_tracer,
1112
TracedToConcrete,
12-
append_path
13+
append_path,
14+
TracedType
1315

1416
@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
1517

@@ -286,10 +288,10 @@ function compile_mlir!(mod, f, args; optimize=true)
286288
)
287289
end
288290

289-
preserved_args = Tuple{TracedRArray,Int}[]
291+
preserved_args = Tuple{TracedType,Int}[]
290292
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
291293
nresults = MLIR.IR.Value[]
292-
linear_results2 = TracedRArray[]
294+
linear_results2 = TracedType[]
293295
for (i, op) in enumerate(results)
294296
if !MLIR.IR.is_block_arg(op)
295297
push!(nresults, op)
@@ -573,7 +575,7 @@ end
573575
function compile_xla(f, args; client=nothing)
574576
# register MLIR dialects
575577
ctx = MLIR.IR.Context()
576-
Base.append!(Reactant.registry[]; context=ctx)
578+
append!(Reactant.registry[]; context=ctx)
577579
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
578580

579581
return MLIR.IR.context!(ctx) do

src/ConcreteRArray.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,6 @@ function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T}
7474
return to_float(x)
7575
end
7676

77-
function Base.promote_rule(::Type{<:RArray{T1,0}}, ::Type{T2}) where {T1,T2}
78-
return Base.promote_rule(T1, T2)
79-
end
80-
8177
for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^))
8278
@eval begin
8379
function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U}
@@ -158,7 +154,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
158154
end
159155

160156
function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
161-
Base.setindex!(a, v, args...)
157+
setindex!(a, v, args...)
162158
return nothing
163159
end
164160

src/Reactant.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,45 @@ include("OrderedIdDict.jl")
77

88
using Enzyme
99

10-
abstract type RArray{T,N} <: AbstractArray{T,N} end
10+
@static if isdefined(Core, :BFloat16)
11+
const ReactantPrimitive = Union{
12+
Bool,
13+
Int8,
14+
UInt8,
15+
Int16,
16+
UInt16,
17+
Int32,
18+
UInt32,
19+
Int64,
20+
UInt64,
21+
Float16,
22+
Core.BFloat16,
23+
Float32,
24+
Float64,
25+
Complex{Float32},
26+
Complex{Float64},
27+
}
28+
else
29+
const ReactantPrimitive = Union{
30+
Bool,
31+
Int8,
32+
UInt8,
33+
Int16,
34+
UInt16,
35+
Int32,
36+
UInt32,
37+
Int64,
38+
UInt64,
39+
Float16,
40+
Float32,
41+
Float64,
42+
Complex{Float32},
43+
Complex{Float64},
44+
}
45+
end
46+
47+
abstract type RArray{T<:ReactantPrimitive,N} <: AbstractArray{T,N} end
48+
abstract type RNumber{T<:ReactantPrimitive} <: Number end
1149

1250
function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
1351
return reshape(A, Base._reshape_uncolon(A, dims))
@@ -45,8 +83,13 @@ include("mlir/MLIR.jl")
4583
include("XLA.jl")
4684
include("Interpreter.jl")
4785
include("utils.jl")
86+
4887
include("ConcreteRArray.jl")
88+
include("TracedRNumber.jl")
4989
include("TracedRArray.jl")
90+
91+
const TracedType = Union{TracedRArray,TracedRNumber}
92+
5093
include("Tracing.jl")
5194
include("Compiler.jl")
5295

0 commit comments

Comments
 (0)