Skip to content

Commit 41291a9

Browse files
Improved performance of the fp4tofp conversion
Use a simple lookup table instead of explicit conversion. Fixes #4298
1 parent 6b2fa6c commit 41291a9

File tree

1 file changed

+32
-79
lines changed

1 file changed

+32
-79
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/Fp4ToFpOpToLLVM.cpp

Lines changed: 32 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -13,80 +13,6 @@ using namespace mlir::triton::gpu;
1313
using namespace mlir::triton::gpu::intel;
1414

1515
namespace {
16-
SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
17-
ArrayRef<Value> values) {
18-
auto b = TritonLLVMOpBuilder(loc, rewriter);
19-
SmallVector<Value> results;
20-
for (auto v : values) {
21-
auto em0 = b.and_(v, b.i8_val(0x7));
22-
auto em1 = b.and_(v, b.i8_val(0x70));
23-
Value v0 =
24-
b.or_(b.shl(b.zext(i16_ty, em0), b.i16_val(6)),
25-
b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x8))), b.i16_val(12)));
26-
Value v1 =
27-
b.or_(b.shl(b.zext(i16_ty, em1), b.i16_val(2)),
28-
b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x80))), b.i16_val(8)));
29-
// Three cases:
30-
// 1) x is normal and non-zero: Correct bias
31-
v0 = b.select(b.icmp_ne(b.and_(em0, b.i8_val(0x6)), b.i8_val(0)),
32-
b.add(v0, b.i16_val((127 - 1) << 7)), v0);
33-
v1 = b.select(b.icmp_ne(b.and_(em1, b.i8_val(0x60)), b.i8_val(0)),
34-
b.add(v1, b.i16_val((127 - 1) << 7)), v1);
35-
// 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
36-
// bf16
37-
v0 = b.bitcast(
38-
b.select(b.icmp_eq(em0, b.i8_val(0x1)),
39-
b.or_(b.i16_val(16128), b.and_(v0, b.i16_val(0x8000))), v0),
40-
bf16_ty);
41-
v1 = b.bitcast(
42-
b.select(b.icmp_eq(em1, b.i8_val(0x10)),
43-
b.or_(b.i16_val(16128), b.and_(v1, b.i16_val(0x8000))), v1),
44-
bf16_ty);
45-
// 3) x is zero, nothing to do
46-
results.push_back(v0);
47-
results.push_back(v1);
48-
}
49-
return results;
50-
}
51-
52-
SmallVector<Value> convertMxfp4x2ToFp16x2(RewriterBase &rewriter, Location loc,
53-
ArrayRef<Value> values) {
54-
auto b = TritonLLVMOpBuilder(loc, rewriter);
55-
SmallVector<Value> results;
56-
for (auto v : values) {
57-
auto em0 = b.and_(v, b.i8_val(0x7));
58-
auto em1 = b.and_(v, b.i8_val(0x70));
59-
// FP16 bits: sign = 1, exponent = 5, mantissa = 10
60-
Value v0 =
61-
b.or_(b.shl(b.zext(i16_ty, em0), b.i16_val(10 - 1)),
62-
b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x8))), b.i16_val(12)));
63-
Value v1 =
64-
b.or_(b.shl(b.zext(i16_ty, em1), b.i16_val(10 - 1 - 4)),
65-
b.shl(b.zext(i16_ty, b.and_(v, b.i8_val(0x80))), b.i16_val(8)));
66-
67-
// Three cases:
68-
// 1) x is normal and non-zero: Correct bias
69-
v0 = b.select(b.icmp_ne(b.and_(em0, b.i8_val(0x6)), b.i8_val(0)),
70-
b.add(v0, b.i16_val((15 - 1) << 10)), v0);
71-
v1 = b.select(b.icmp_ne(b.and_(em1, b.i8_val(0x60)), b.i8_val(0)),
72-
b.add(v1, b.i16_val((15 - 1) << 10)), v1);
73-
74-
// 2) x is subnormal (x == 0bs001 where s is the sign): Map to fp16 +-0.5
75-
v0 = b.bitcast(
76-
b.select(b.icmp_eq(em0, b.i8_val(0x1)),
77-
b.or_(b.i16_val(0x3800), b.and_(v0, b.i16_val(0x8000))), v0),
78-
f16_ty);
79-
v1 = b.bitcast(
80-
b.select(b.icmp_eq(em1, b.i8_val(0x10)),
81-
b.or_(b.i16_val(0x3800), b.and_(v1, b.i16_val(0x8000))), v1),
82-
f16_ty);
83-
// 3) x is zero, nothing to do
84-
results.push_back(v0);
85-
results.push_back(v1);
86-
}
87-
return results;
88-
}
89-
9016
class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
9117
public:
9218
Fp4ToFpOpPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit)
@@ -96,21 +22,48 @@ class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
9622
matchAndRewrite(Fp4ToFpOp op, OpAdaptor adaptor,
9723
ConversionPatternRewriter &rewriter) const override {
9824
Location loc = op.getLoc();
99-
auto *ctx = op.getContext();
10025
Type elemType = op.getType().getElementType();
10126
assert(elemType == f16_ty || elemType == bf16_ty);
102-
bool toFp16 = elemType == f16_ty;
10327

10428
SmallVector<Value> xVals =
10529
unpackLLElements(loc, adaptor.getSrc(), rewriter);
106-
xVals = toFp16 ? convertMxfp4x2ToFp16x2(rewriter, loc, xVals)
107-
: convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
108-
30+
xVals = convertMxfp4x2ToFloat(rewriter, loc, xVals,
31+
elemType == f16_ty ? f16_ty : bf16_ty);
10932
Value result =
11033
packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType());
11134
rewriter.replaceOp(op, result);
11235
return success();
11336
}
37+
38+
private:
39+
/// Expands packed pairs of 4-bit float codes into `floatTy` values using a
/// 16-entry constant lookup table.
///
/// Each input is split into its low nibble (`v & 15`) and its high nibble
/// (`v >> 4`, logical shift); each nibble is used directly as an index into
/// the table. Table entries 0-7 hold {0, 0.5, 1, 1.5, 2, 3, 4, 6} and entries
/// 8-15 hold the same magnitudes negated, so bit 3 of a code acts as the sign.
///
/// The mask and shift amounts are built with `i8_val`, so the inputs are
/// presumably i8 values -- NOTE(review): confirm against the op's source type.
///
/// \param rewriter Used to create the constant table and the per-element ops.
/// \param loc      Location attached to every created op.
/// \param values   Packed inputs, two 4-bit codes per element.
/// \param floatTy  Element type of the lookup table and of the results.
/// \returns Two results per input, in order: the value for the low nibble,
///          then the value for the high nibble.
static SmallVector<Value> convertMxfp4x2ToFloat(RewriterBase &rewriter, Location loc,
                                                ArrayRef<Value> values,
                                                FloatType floatTy) {
  Value table;
  { // Create a constant vector containing all the possible values
    auto vecTy = VectorType::get({16}, floatTy);
    // Deliberately not named `values`: that would shadow the parameter.
    SmallVector<Attribute, 16> tableAttrs;
    for (double v : {0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5,
                     -2., -3., -4., -6.})
      tableAttrs.push_back(rewriter.getFloatAttr(floatTy, v));
    table = rewriter.create<LLVM::ConstantOp>(
        loc, vecTy, DenseElementsAttr::get(vecTy, tableAttrs));
  }

  TritonLLVMOpBuilder b(loc, rewriter);
  // Loop-invariant constants, created once and reused for every element.
  Value i8_4 = b.i8_val(4);
  Value i8_15 = b.i8_val(15);
  SmallVector<Value> results;
  results.reserve(values.size() * 2);
  for (Value v : values) {
    // The low and the high 4 bits are each an index into the table.
    Value idx1 = b.and_(v, i8_15);
    Value idx2 = b.lshr(v, i8_4);
    results.push_back(b.extract_element(table, idx1));
    results.push_back(b.extract_element(table, idx2));
  }
  return results;
}
11467
};
11568
} // anonymous namespace
11669

0 commit comments

Comments
 (0)