Skip to content

Commit 0b868ed

Browse files
committed
feat: partial progress on getting scalars to work
1 parent 10d364c commit 0b868ed

File tree

4 files changed

+111
-108
lines changed

4 files changed

+111
-108
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,24 @@
11
module ReactantNNlibExt
22

33
using NNlib
4-
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR
4+
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR,
5+
TracedRScalar
56

67
for (jlop, hloop) in (
78
(:(NNlib.tanh_fast), :tanh),
89
(:(NNlib.sigmoid_fast), :logistic),
910
(:(NNlib.sigmoid), :logistic),
1011
)
11-
@eval function $(jlop)(x::TracedRArray{T,0}) where {T}
12-
return TracedRArray{T,0}(
12+
@eval function $(jlop)(x::TracedRScalar{T}) where {T}
13+
return TracedRScalar{T}(
1314
(),
1415
Reactant.MLIR.IR.result(
1516
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
1617
),
17-
(),
1818
)
1919
end
2020
end
2121

22-
# Don't confuse our poor scalar arrays, we no like numbers we like 0D arrays
23-
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, ))
24-
@eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
25-
return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
26-
end
27-
end
28-
2922
# TODO handle non finite cases
3023
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
3124
max_ = NNlib.fast_maximum(x; dims)

src/TracedRArray.jl

Lines changed: 63 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ end
1919

2020
TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x
2121

22+
function Base.setproperty!(x::TracedRArray, f::Symbol, v)
23+
if f === :mlir_data && !isnothing(v)
24+
@assert size(MLIR.IR.type(v)) == size(x)
25+
end
26+
return setfield!(x, f, v)
27+
end
28+
2229
mutable struct TracedRScalar{T} <: RScalar{T}
2330
paths::Tuple
2431
mlir_data::Union{Nothing,MLIR.IR.Value}
@@ -33,6 +40,15 @@ mutable struct TracedRScalar{T} <: RScalar{T}
3340
end
3441
end
3542

43+
function Base.setproperty!(x::TracedRScalar, f::Symbol, v)
44+
if f === :mlir_data && !isnothing(v)
45+
@assert size(MLIR.IR.type(v)) == ()
46+
end
47+
return setfield!(x, f, v)
48+
end
49+
50+
Base.eltype(::Type{TracedRScalar{T}}) where {T} = T
51+
3652
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
3753
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
3854
const AnyTracedRVector{T} = AnyTracedRArray{T,1}
@@ -59,7 +75,7 @@ Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
5975
Base.one(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, one(T))
6076

6177
function Base.convert(::Type{<:TracedRScalar{T}}, x::Number) where {T}
62-
return promote_to(TracedRArray{T,0}, T(x))
78+
return promote_to(TracedRScalar{T}, T(x))
6379
end
6480

6581
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
@@ -121,7 +137,7 @@ function Base.setindex!(
121137
a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
122138
) where {T,N}
123139
indices = [
124-
(promote_to(TracedRArray{Int,0}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
140+
(promote_to(TracedRScalar{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
125141
i in indices
126142
]
127143
v = promote_to(TracedRArray{T,N}, v)
@@ -222,6 +238,14 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
222238
return TracedRArray{Base.promote_type(T, S),N}
223239
end
224240

241+
function Base.promote_rule(::Type{T}, ::Type{TracedRScalar{S}}) where {T,S}
242+
return TracedRScalar{Base.promote_type(T, S)}
243+
end
244+
245+
function Base.convert(::Type{TracedRScalar{T}}, x::Number) where {T}
246+
return promote_to(TracedRScalar{T}, x)
247+
end
248+
225249
function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
226250
if isa(rhs, TracedRArray)
227251
rhs isa TracedRArray{T,N} && return rhs
@@ -279,12 +303,8 @@ function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
279303
)
280304
end
281305

282-
function promote_to(::TracedRArray{T,N}, rhs) where {T,N}
283-
return promote_to(TracedRArray{T,N}, rhs)
284-
end
285-
function promote_to(::TracedRScalar{T}, rhs) where {T}
286-
return promote_to(TracedRScalar{T}, rhs)
287-
end
306+
promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs)
307+
promote_to(::TracedRScalar{T}, rhs) where {T} = promote_to(TracedRScalar{T}, rhs)
288308

289309
for (jlop, hloop) in (
290310
(:(Base.min), :minimum),
@@ -295,66 +315,35 @@ for (jlop, hloop) in (
295315
(:(Base.:/), :divide),
296316
(:(Base.:^), :power),
297317
)
298-
@eval begin
299-
function $(jlop)(
300-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
301-
) where {T}
302-
return TracedRArray{T,0}(
303-
(),
304-
MLIR.IR.result(
305-
MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
306-
),
307-
(),
308-
)
309-
end
310-
311-
function $(jlop)(
312-
@nospecialize(lhs::TracedRArray{T1,0}), @nospecialize(rhs::TracedRArray{T2,0})
313-
) where {T1,T2}
314-
commonTy = TracedRArray{Base.promote_type(T1, T2),0}
315-
lhs = promote_to(commonTy, lhs)
316-
rhs = promote_to(commonTy, rhs)
317-
return $(jlop)(lhs, rhs)
318-
end
319-
end
320-
321-
for otherType in (Number, Any)
322-
@eval begin
323-
function $(jlop)(
324-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::$(otherType))
325-
) where {T}
326-
rhs = promote_to(lhs, rhs)
327-
return $(jlop)(lhs, rhs)
328-
end
329-
330-
function $(jlop)(
331-
@nospecialize(lhs::$(otherType)), @nospecialize(rhs::TracedRArray{T,0})
332-
) where {T}
333-
lhs = promote_to(rhs, lhs)
334-
return $(jlop)(lhs, rhs)
335-
end
336-
end
318+
@eval function $(jlop)(
319+
@nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T})
320+
) where {T}
321+
return TracedRArray{T}(
322+
(),
323+
MLIR.IR.result(
324+
MLIR.Dialects.stablehlo.$(hloop)(lhs.mlir_data, rhs.mlir_data), 1
325+
),
326+
)
337327
end
338328
end
339329

340330
function Base.ifelse(
341-
@nospecialize(pred::TracedRArray{Bool,0}),
342-
@nospecialize(x::TracedRArray{T1,0}),
343-
@nospecialize(y::TracedRArray{T2,0})
331+
@nospecialize(pred::TracedRScalar{Bool}),
332+
@nospecialize(x::TracedRScalar{T1}),
333+
@nospecialize(y::TracedRScalar{T2})
344334
) where {T1,T2}
345-
return TracedRArray{promote_type(T1, T2),0}(
335+
return TracedRScalar{promote_type(T1, T2)}(
346336
(),
347337
MLIR.IR.result(
348338
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
349339
),
350-
size(pred),
351340
)
352341
end
353342

354-
Base.abs2(x::Reactant.TracedRArray{T,0}) where {T} = x * conj(x)
343+
Base.abs2(x::Reactant.TracedRScalar{T}) where {T} = x * conj(x)
355344

356345
function Base.literal_pow(
357-
::Base.RefValue{typeof(^)}, x::TracedRArray{T,0}, ::Base.RefValue{Val{P}}
346+
::Base.RefValue{typeof(^)}, x::TracedRScalar{T}, ::Base.RefValue{Val{P}}
358347
) where {T,P}
359348
return Base.literal_pow(^, x, Val(P))
360349
end
@@ -371,14 +360,10 @@ for (jlop, hloop) in (
371360
(:(Base.log), :log),
372361
(:(Base.sqrt), :sqrt),
373362
)
374-
@eval begin
375-
function $jlop(@nospecialize(lhs::TracedRArray{T,0})) where {T}
376-
return TracedRArray{T,0}(
377-
(),
378-
MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1),
379-
size(lhs),
380-
)
381-
end
363+
@eval function $(jlop)(@nospecialize(lhs::TracedRScalar{T})) where {T}
364+
return TracedRScalar{T}(
365+
(), MLIR.IR.result(MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1)
366+
)
382367
end
383368
end
384369

@@ -445,6 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
445430
residx = 1
446431

447432
for a in linear_results
433+
@show a
448434
if has_residx(a)
449435
path = get_residx(a)
450436
set!(result, path[2:end], MLIR.IR.result(res, residx))
@@ -480,37 +466,22 @@ for (jlop, hloop, hlocomp, merge) in (
480466
(:(Base.:(<=)), :compare, "LE", nothing),
481467
(:(Base.:(<)), :compare, "LT", nothing),
482468
)
483-
@eval begin
484-
function $(jlop)(
485-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
486-
) where {T}
487-
return TracedRArray{Bool,0}(
488-
(),
489-
MLIR.IR.result(
490-
MLIR.Dialects.stablehlo.$hloop(
491-
lhs.mlir_data,
492-
rhs.mlir_data;
493-
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
494-
MLIR.IR.context(), $hlocomp
495-
),
469+
@eval function $(jlop)(
470+
@nospecialize(lhs::TracedRScalar{T}), @nospecialize(rhs::TracedRScalar{T})
471+
) where {T}
472+
return TracedRScalar{Bool}(
473+
(),
474+
MLIR.IR.result(
475+
MLIR.Dialects.stablehlo.$(hloop)(
476+
lhs.mlir_data,
477+
rhs.mlir_data;
478+
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
479+
MLIR.IR.context(), $hlocomp
496480
),
497-
1,
498481
),
499-
size(lhs),
500-
)
501-
end
502-
503-
function $(jlop)(
504-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs)
505-
) where {T}
506-
return $(jlop)(lhs, promote_to(lhs, rhs))
507-
end
508-
509-
function $(jlop)(
510-
@nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,0})
511-
) where {T}
512-
return $(jlop)(promote_to(rhs, lhs), rhs)
513-
end
482+
1,
483+
),
484+
)
514485
end
515486

516487
if merge !== nothing
@@ -600,7 +571,7 @@ function Base.mapreduce(
600571
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])
601572

602573
args = (
603-
TracedRArray{T,0}((), MLIR.IR.argument(fnbody, i), ()) for
574+
TracedRScalar{T}((), MLIR.IR.argument(fnbody, i), ()) for
604575
(i, ty) in enumerate(in_tys)
605576
)
606577

src/Tracing.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ for T in (
1616
Integer,
1717
AbstractString,
1818
RArray,
19+
RScalar,
1920
)
2021
@eval function traced_type(::Type{T}, seen, mode) where {T<:$T}
2122
return T
@@ -330,7 +331,7 @@ function make_tracer(
330331
return seen[prev]
331332
end
332333
res = if toscalar
333-
TracedRArray{T,0}((path,), nothing, ())
334+
TracedRScalar{T}((path,), nothing)
334335
elseif !isnothing(tobatch)
335336
TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch)
336337
else
@@ -352,6 +353,44 @@ function make_tracer(
352353
throw("Cannot Unknown trace mode $mode")
353354
end
354355

356+
function make_tracer(
357+
seen,
358+
@nospecialize(prev::TracedRScalar{T}),
359+
@nospecialize(path),
360+
mode;
361+
kwargs...
362+
) where {T}
363+
if mode == ConcreteToTraced
364+
throw("Cannot trace existing trace type")
365+
end
366+
if mode == TracedTrack
367+
prev.paths = (prev.paths..., path)
368+
if !haskey(seen, prev)
369+
return seen[prev] = prev
370+
end
371+
return prev
372+
end
373+
if mode == TracedSetPath
374+
if haskey(seen, prev)
375+
return seen[prev]
376+
end
377+
res = TracedRScalar{T}((path,), prev.mlir_data)
378+
seen[prev] = res
379+
return res
380+
end
381+
382+
if mode == TracedToConcrete
383+
if haskey(seen, prev)
384+
return seen[prev]::ConcreteRArray{T,0}
385+
end
386+
res = ConcreteRArray{T,0}(XLA.AsyncEmptyBuffer, size(prev))
387+
seen[prev] = res
388+
return res
389+
end
390+
391+
throw("Cannot Unknown trace mode $mode")
392+
end
393+
355394
function make_tracer(
356395
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
357396
) where {RT<:AbstractFloat}

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
4444
)
4545
end
4646

47-
linear_args = TracedRArray[]
47+
linear_args = Union{TracedRArray,TracedRScalar}[]
4848
for (k, v) in seen_args
49-
if !(v isa TracedRArray)
49+
if !(v isa TracedRArray) && !(v isa TracedRScalar)
5050
continue
5151
end
5252
push!(linear_args, v)
@@ -127,10 +127,10 @@ function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=fa
127127
)
128128
end
129129

130-
linear_results = TracedRArray[]
130+
linear_results = Union{TracedRArray,TracedRScalar}[]
131131

132132
for (k, v) in seen_results
133-
if !(v isa TracedRArray)
133+
if !(v isa TracedRArray) && !(v isa TracedRScalar)
134134
continue
135135
end
136136

0 commit comments

Comments
 (0)