@@ -112,22 +112,45 @@ for (jlop, hloop, hlocomp) in (
112
112
(:(Base.:(<= )), :compare , " LE" ),
113
113
(:(Base.:(< )), :compare , " LT" ),
114
114
)
115
- @eval function $ (jlop)(
116
- @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: TracedRNumber{T} )
117
- ) where {T}
118
- return TracedRNumber {Bool} (
119
- (),
120
- MLIR. IR. result (
121
- MLIR. Dialects. stablehlo.$ (hloop)(
122
- lhs. mlir_data,
123
- rhs. mlir_data;
124
- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
125
- MLIR. IR. context (), $ hlocomp
115
+ @eval begin
116
+ function $ (jlop)(
117
+ @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: TracedRNumber{T} )
118
+ ) where {T}
119
+ return TracedRNumber {Bool} (
120
+ (),
121
+ MLIR. IR. result (
122
+ MLIR. Dialects. stablehlo.$ (hloop)(
123
+ lhs. mlir_data,
124
+ rhs. mlir_data;
125
+ comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
126
+ MLIR. IR. context (), $ hlocomp
127
+ ),
126
128
),
129
+ 1 ,
127
130
),
128
- 1 ,
129
- ),
130
- )
131
+ )
132
+ end
133
+
134
+ function $ (jlop)(
135
+ @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs)
136
+ ) where {T}
137
+ return $ (jlop)(lhs, promote_to (lhs, rhs))
138
+ end
139
+
140
+ function $ (jlop)(
141
+ @nospecialize (lhs), @nospecialize (rhs:: TracedRNumber{T} )
142
+ ) where {T}
143
+ return $ (jlop)(promote_to (rhs, lhs), rhs)
144
+ end
145
+
146
+ function $ (jlop)(
147
+ @nospecialize (lhs:: TracedRNumber{T1} ), @nospecialize (rhs:: TracedRNumber{T2} )
148
+ ) where {T1,T2}
149
+ commonTy = TracedRNumber{Base. promote_type (T1, T2)}
150
+ lhs = promote_to (commonTy, lhs)
151
+ rhs = promote_to (commonTy, rhs)
152
+ return $ (jlop)(lhs, rhs)
153
+ end
131
154
end
132
155
end
133
156
@@ -169,6 +192,9 @@ for (jlop, hloop) in (
169
192
end
170
193
end
171
194
195
+ # XXX : Enzyme-MLIR doesn't have `abs` adjoint defined
196
+ Base. abs2 (x:: TracedRNumber{<:Real} ) = x^ 2
197
+
172
198
struct TypeCast{T<: ReactantPrimitives } <: Function end
173
199
174
200
(:: TypeCast{T} )(x:: TracedRNumber{T2} ) where {T,T2} = promote_to (TracedRNumber{T}, x)
0 commit comments