@@ -13,80 +13,6 @@ using namespace mlir::triton::gpu;
1313using namespace mlir ::triton::gpu::intel;
1414
1515namespace {
16- SmallVector<Value> convertMxfp4x2ToBf16x2 (RewriterBase &rewriter, Location loc,
17- ArrayRef<Value> values) {
18- auto b = TritonLLVMOpBuilder (loc, rewriter);
19- SmallVector<Value> results;
20- for (auto v : values) {
21- auto em0 = b.and_ (v, b.i8_val (0x7 ));
22- auto em1 = b.and_ (v, b.i8_val (0x70 ));
23- Value v0 =
24- b.or_ (b.shl (b.zext (i16_ty, em0), b.i16_val (6 )),
25- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x8 ))), b.i16_val (12 )));
26- Value v1 =
27- b.or_ (b.shl (b.zext (i16_ty, em1), b.i16_val (2 )),
28- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x80 ))), b.i16_val (8 )));
29- // Three cases:
30- // 1) x is normal and non-zero: Correct bias
31- v0 = b.select (b.icmp_ne (b.and_ (em0, b.i8_val (0x6 )), b.i8_val (0 )),
32- b.add (v0, b.i16_val ((127 - 1 ) << 7 )), v0);
33- v1 = b.select (b.icmp_ne (b.and_ (em1, b.i8_val (0x60 )), b.i8_val (0 )),
34- b.add (v1, b.i16_val ((127 - 1 ) << 7 )), v1);
35- // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
36- // bf16
37- v0 = b.bitcast (
38- b.select (b.icmp_eq (em0, b.i8_val (0x1 )),
39- b.or_ (b.i16_val (16128 ), b.and_ (v0, b.i16_val (0x8000 ))), v0),
40- bf16_ty);
41- v1 = b.bitcast (
42- b.select (b.icmp_eq (em1, b.i8_val (0x10 )),
43- b.or_ (b.i16_val (16128 ), b.and_ (v1, b.i16_val (0x8000 ))), v1),
44- bf16_ty);
45- // 3) x is zero, nothing to do
46- results.push_back (v0);
47- results.push_back (v1);
48- }
49- return results;
50- }
51-
52- SmallVector<Value> convertMxfp4x2ToFp16x2 (RewriterBase &rewriter, Location loc,
53- ArrayRef<Value> values) {
54- auto b = TritonLLVMOpBuilder (loc, rewriter);
55- SmallVector<Value> results;
56- for (auto v : values) {
57- auto em0 = b.and_ (v, b.i8_val (0x7 ));
58- auto em1 = b.and_ (v, b.i8_val (0x70 ));
59- // FP16 bits: sign = 1, exponent = 5, mantissa = 10
60- Value v0 =
61- b.or_ (b.shl (b.zext (i16_ty, em0), b.i16_val (10 - 1 )),
62- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x8 ))), b.i16_val (12 )));
63- Value v1 =
64- b.or_ (b.shl (b.zext (i16_ty, em1), b.i16_val (10 - 1 - 4 )),
65- b.shl (b.zext (i16_ty, b.and_ (v, b.i8_val (0x80 ))), b.i16_val (8 )));
66-
67- // Three cases:
68- // 1) x is normal and non-zero: Correct bias
69- v0 = b.select (b.icmp_ne (b.and_ (em0, b.i8_val (0x6 )), b.i8_val (0 )),
70- b.add (v0, b.i16_val ((15 - 1 ) << 10 )), v0);
71- v1 = b.select (b.icmp_ne (b.and_ (em1, b.i8_val (0x60 )), b.i8_val (0 )),
72- b.add (v1, b.i16_val ((15 - 1 ) << 10 )), v1);
73-
74- // 2) x is subnormal (x == 0bs001 where s is the sign): Map to fp16 +-0.5
75- v0 = b.bitcast (
76- b.select (b.icmp_eq (em0, b.i8_val (0x1 )),
77- b.or_ (b.i16_val (0x3800 ), b.and_ (v0, b.i16_val (0x8000 ))), v0),
78- f16_ty);
79- v1 = b.bitcast (
80- b.select (b.icmp_eq (em1, b.i8_val (0x10 )),
81- b.or_ (b.i16_val (0x3800 ), b.and_ (v1, b.i16_val (0x8000 ))), v1),
82- f16_ty);
83- // 3) x is zero, nothing to do
84- results.push_back (v0);
85- results.push_back (v1);
86- }
87- return results;
88- }
89-
9016class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern <Fp4ToFpOp> {
9117public:
9218 Fp4ToFpOpPattern (LLVMTypeConverter &typeConverter, PatternBenefit benefit)
@@ -96,21 +22,48 @@ class Fp4ToFpOpPattern : public ConvertOpToLLVMPattern<Fp4ToFpOp> {
9622 matchAndRewrite (Fp4ToFpOp op, OpAdaptor adaptor,
9723 ConversionPatternRewriter &rewriter) const override {
9824 Location loc = op.getLoc ();
99- auto *ctx = op.getContext ();
10025 Type elemType = op.getType ().getElementType ();
10126 assert (elemType == f16_ty || elemType == bf16_ty);
102- bool toFp16 = elemType == f16_ty;
10327
10428 SmallVector<Value> xVals =
10529 unpackLLElements (loc, adaptor.getSrc (), rewriter);
106- xVals = toFp16 ? convertMxfp4x2ToFp16x2 (rewriter, loc, xVals)
107- : convertMxfp4x2ToBf16x2 (rewriter, loc, xVals);
108-
30+ xVals = convertMxfp4x2ToFloat (rewriter, loc, xVals,
31+ elemType == f16_ty ? f16_ty : bf16_ty);
10932 Value result =
11033 packLLElements (loc, getTypeConverter (), xVals, rewriter, op.getType ());
11134 rewriter.replaceOp (op, result);
11235 return success ();
11336 }
37+
38+ private:
39+ static SmallVector<Value> convertMxfp4x2ToFloat (RewriterBase &rewriter, Location loc,
40+ ArrayRef<Value> values,
41+ FloatType floatTy) {
42+ Value table;
43+ { // Create a constant vector containing all the possible values
44+ auto vecTy = VectorType::get ({16 }, floatTy);
45+ SmallVector<Attribute, 16 > values;
46+ for (double v : {0 ., 0.5 , 1 ., 1.5 , 2 ., 3 ., 4 ., 6 ., -0 ., -0.5 , -1 ., -1.5 ,
47+ -2 ., -3 ., -4 ., -6 .})
48+ values.push_back (rewriter.getFloatAttr (floatTy, v));
49+ table = rewriter.create <LLVM::ConstantOp>(
50+ loc, vecTy, DenseElementsAttr::get (vecTy, values));
51+ }
52+
53+ TritonLLVMOpBuilder b (loc, rewriter);
54+ Value i8_4 = b.i8_val (4 );
55+ Value i8_15 = b.i8_val (15 );
56+ SmallVector<Value> results;
57+ results.reserve (values.size () * 2 );
58+ for (Value v : values) {
59+ // The first and last 4 bits are the values indices in the table
60+ Value idx1 = b.and_ (v, i8_15);
61+ Value idx2 = b.lshr (v, i8_4);
62+ results.push_back (b.extract_element (table, idx1));
63+ results.push_back (b.extract_element (table, idx2));
64+ }
65+ return results;
66+ }
11467};
11568} // anonymous namespace
11669
0 commit comments