Skip to content

Commit 40d3402

Browse files
committed
test: more test fixes
1 parent 388e6a2 commit 40d3402

File tree

2 files changed

+42
-15
lines changed

2 files changed

+42
-15
lines changed

src/Compiler.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import ..Reactant:
66
XLA,
77
ConcreteRArray,
88
TracedRArray,
9+
TracedRNumber,
910
OrderedIdDict,
1011
make_tracer,
1112
TracedToConcrete,
@@ -289,7 +290,7 @@ function compile_mlir!(mod, f, args; optimize=true)
289290
preserved_args = Tuple{TracedRArray,Int}[]
290291
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
291292
nresults = MLIR.IR.Value[]
292-
linear_results2 = TracedRArray[]
293+
linear_results2 = Union{TracedRArray,TracedRNumber}[]
293294
for (i, op) in enumerate(results)
294295
if !MLIR.IR.is_block_arg(op)
295296
push!(nresults, op)

src/TracedRNumber.jl

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,22 +112,45 @@ for (jlop, hloop, hlocomp) in (
112112
(:(Base.:(<=)), :compare, "LE"),
113113
(:(Base.:(<)), :compare, "LT"),
114114
)
115-
@eval function $(jlop)(
116-
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
117-
) where {T}
118-
return TracedRNumber{Bool}(
119-
(),
120-
MLIR.IR.result(
121-
MLIR.Dialects.stablehlo.$(hloop)(
122-
lhs.mlir_data,
123-
rhs.mlir_data;
124-
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
125-
MLIR.IR.context(), $hlocomp
115+
@eval begin
116+
function $(jlop)(
117+
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
118+
) where {T}
119+
return TracedRNumber{Bool}(
120+
(),
121+
MLIR.IR.result(
122+
MLIR.Dialects.stablehlo.$(hloop)(
123+
lhs.mlir_data,
124+
rhs.mlir_data;
125+
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
126+
MLIR.IR.context(), $hlocomp
127+
),
126128
),
129+
1,
127130
),
128-
1,
129-
),
130-
)
131+
)
132+
end
133+
134+
function $(jlop)(
135+
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)
136+
) where {T}
137+
return $(jlop)(lhs, promote_to(lhs, rhs))
138+
end
139+
140+
function $(jlop)(
141+
@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})
142+
) where {T}
143+
return $(jlop)(promote_to(rhs, lhs), rhs)
144+
end
145+
146+
function $(jlop)(
147+
@nospecialize(lhs::TracedRNumber{T1}), @nospecialize(rhs::TracedRNumber{T2})
148+
) where {T1,T2}
149+
commonTy = TracedRNumber{Base.promote_type(T1, T2)}
150+
lhs = promote_to(commonTy, lhs)
151+
rhs = promote_to(commonTy, rhs)
152+
return $(jlop)(lhs, rhs)
153+
end
131154
end
132155
end
133156

@@ -169,6 +192,9 @@ for (jlop, hloop) in (
169192
end
170193
end
171194

195+
# XXX: Enzyme-MLIR doesn't have `abs` adjoint defined
196+
Base.abs2(x::TracedRNumber{<:Real}) = x^2
197+
172198
struct TypeCast{T<:ReactantPrimitives} <: Function end
173199

174200
(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)

0 commit comments

Comments
 (0)