Skip to content

Commit 878d359

Browse files
authored
[mlir][vector] Avoid setting padding by default to 0 in vector.transfer_read; prefer ub.poison (#146088)
Context: `vector.transfer_read` always requires a padding value. Most of its builders take no `padding` value and assume the safe value of `0`. However, this should be a conscious choice by the API user, as the implicit default makes it easy to introduce bugs. For example, while making this patch I found several places where the padding value was not being propagated when one `vector.transfer_read` was transformed into another `vector.transfer_read`. These bugs were always caused by constructors that don't require specifying the padding. Additionally, using `ub.poison` as the default value is better, as it indicates that the user "doesn't care" about the actual padding value and forces users who do care to specify the padding semantics they want. With that in mind, this patch changes the builders of `vector.transfer_read` to always take a `std::optional<Value> padding` argument. This argument must always be passed explicitly, but for convenience users can pass `std::nullopt`, which pads the transfer read with `ub.poison`. --------- Signed-off-by: Fabian Mora <[email protected]>
1 parent 6a57af8 commit 878d359

File tree

15 files changed

+107
-78
lines changed

15 files changed

+107
-78
lines changed

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,11 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
154154
Value lhs, Value rhs);
155155

156156
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
157+
158+
/// Creates an `arith.constant` operation with a zero value of type `type`. This
159+
/// method asserts if `type` is invalid for representing zero with
160+
/// `arith.constant`.
161+
Value getZeroConstant(OpBuilder &builder, Location loc, Type type);
157162
} // namespace arith
158163
} // namespace mlir
159164

mlir/include/mlir/Dialect/Vector/IR/Vector.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ def Vector_Dialect : Dialect {
2121

2222
let useDefaultAttributePrinterParser = 1;
2323
let hasConstantMaterializer = 1;
24-
let dependentDialects = ["arith::ArithDialect"];
24+
let dependentDialects = [
25+
"arith::ArithDialect",
26+
"ub::UBDialect"
27+
];
2528
}
2629

2730
// Base class for Vector dialect ops.

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,30 +1543,29 @@ def Vector_TransferReadOp :
15431543
}];
15441544

15451545
let builders = [
1546-
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
1546+
/// 1. Builder that sets padding to `padding` or poison if not provided and
1547+
/// an empty mask (variant with attrs).
15471548
OpBuilder<(ins "VectorType":$vectorType,
15481549
"Value":$source,
15491550
"ValueRange":$indices,
1551+
"std::optional<Value>":$padding,
15501552
"AffineMapAttr":$permutationMapAttr,
15511553
"ArrayAttr":$inBoundsAttr)>,
1552-
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
1554+
/// 2. Builder that sets padding to `padding` or poison if not provided and
1555+
/// an empty mask (variant without attrs).
15531556
OpBuilder<(ins "VectorType":$vectorType,
15541557
"Value":$source,
15551558
"ValueRange":$indices,
1559+
"std::optional<Value>":$padding,
15561560
"AffineMap":$permutationMap,
15571561
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
1558-
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
1562+
/// 3. Builder that sets padding to `padding` or poison if not provided and
1563+
/// permutation map to 'getMinorIdentityMap'.
15591564
OpBuilder<(ins "VectorType":$vectorType,
15601565
"Value":$source,
15611566
"ValueRange":$indices,
1562-
"Value":$padding,
1563-
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
1564-
/// 4. Builder that sets padding to zero and permutation map to
1565-
/// 'getMinorIdentityMap'.
1566-
OpBuilder<(ins "VectorType":$vectorType,
1567-
"Value":$source,
1568-
"ValueRange":$indices,
1569-
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
1567+
"std::optional<Value>":$padding,
1568+
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>
15701569
];
15711570

15721571
let extraClassDeclaration = [{

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
12571257
LLVM_DEBUG(permutationMap.print(dbgs()));
12581258

12591259
auto transfer = state.builder.create<vector::TransferReadOp>(
1260-
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
1260+
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
1261+
/*padding=*/std::nullopt, permutationMap);
12611262

12621263
// Register replacement for future uses in the scope.
12631264
state.registerOpVectorReplacement(loadOp, transfer);

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,16 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
292292
return false;
293293
}
294294

295+
Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
296+
Type type) {
297+
// TODO: Incorporate this check into `FloatAttr::get*`.
298+
assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
299+
"type doesn't have a zero representation");
300+
TypedAttr zeroAttr = builder.getZeroAttr(type);
301+
assert(zeroAttr && "unsupported type for zero attribute");
302+
return builder.create<arith::ConstantOp>(loc, zeroAttr);
303+
}
304+
295305
//===----------------------------------------------------------------------===//
296306
// AddIOp
297307
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
426426
// Create the new `transfer_read`.
427427
auto newReadOp = rewriter.create<vector::TransferReadOp>(
428428
readOp.getLoc(), collapsedVT, collapsedMem, indices,
429+
readOp.getPadding(),
429430
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
430431

431432
// Cast back to the original vector type.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,14 +1183,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11831183
auto srcRank = extractOp.getTensor().getType().getRank();
11841184
SmallVector<bool> inBounds(dstRank, true);
11851185

1186+
// Create the zero constant used as the padding value of the transfer reads.
1187+
Value padding =
1188+
arith::getZeroConstant(rewriter, loc, resultType.getElementType());
1189+
11861190
// 2a. Handle scalar broadcast access.
11871191
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
11881192
MLIRContext *ctx = rewriter.getContext();
11891193
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
11901194
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
11911195

11921196
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1193-
loc, resultType, extractOp.getTensor(), transferReadIdxs,
1197+
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
11941198
permutationMap, inBounds);
11951199

11961200
// Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1227,8 +1231,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12271231
}
12281232

12291233
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1230-
loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1231-
inBounds);
1234+
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
1235+
permutationMap, inBounds);
12321236

12331237
LDBG("Vectorised as contiguous load: " << extractOp);
12341238
return VectorizationHookResult{VectorizationHookStatus::NewOp,
@@ -1384,7 +1388,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13841388
/// performed to the maximal common vector size implied by the `linalgOp`
13851389
/// iteration space. This eager broadcasting is introduced in the
13861390
/// permutation_map of the vector.transfer_read operations. The eager
1387-
/// broadcasting makes it trivial to detrmine where broadcast, transposes and
1391+
/// broadcasting makes it trivial to determine where broadcast, transposes and
13881392
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
13891393
/// the absence of good canonicalizations, the amount of work increases.
13901394
/// This is not deemed a problem as we expect canonicalizations and foldings to
@@ -1439,7 +1443,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14391443
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
14401444

14411445
Operation *read = rewriter.create<vector::TransferReadOp>(
1442-
loc, readType, opOperand->get(), indices, readMap);
1446+
loc, readType, opOperand->get(), indices,
1447+
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
14431448
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
14441449
Value readValue = read->getResult(0);
14451450

@@ -2641,6 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
26412646

26422647
Value readValue = rewriter.create<vector::TransferReadOp>(
26432648
loc, readType, copyOp.getSource(), indices,
2649+
/*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
26442650
rewriter.getMultiDimIdentityMap(srcType.getRank()));
26452651
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
26462652
readValue =
@@ -3487,15 +3493,18 @@ struct Conv1DGenerator
34873493
SmallVector<Value> resPadding(resShape.size(), zero);
34883494

34893495
// Read the whole lhs, rhs and res in one shot (with zero padding).
3490-
Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3491-
lhsPadding);
3496+
Value lhs = rewriter.create<vector::TransferReadOp>(
3497+
loc, lhsType, lhsShaped, lhsPadding,
3498+
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
34923499
// This is needed only for Conv.
34933500
Value rhs = nullptr;
34943501
if (oper == ConvOperationKind::Conv)
3495-
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3496-
rhsPadding);
3497-
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3498-
resPadding);
3502+
rhs = rewriter.create<vector::TransferReadOp>(
3503+
loc, rhsType, rhsShaped, rhsPadding,
3504+
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3505+
Value res = rewriter.create<vector::TransferReadOp>(
3506+
loc, resType, resShaped, resPadding,
3507+
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
34993508

35003509
// The base vectorization case for channeled convolution is input:
35013510
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3742,19 +3751,22 @@ struct Conv1DGenerator
37423751
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
37433752
// 0].
37443753
Value lhs = rewriter.create<vector::TransferReadOp>(
3745-
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3754+
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3755+
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
37463756
auto maybeMaskedLhs = maybeMaskXferOp(
37473757
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
37483758

37493759
// Read rhs slice of size {kw, c} @ [0, 0].
3750-
Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3751-
ValueRange{zero, zero});
3760+
Value rhs = rewriter.create<vector::TransferReadOp>(
3761+
loc, rhsType, rhsShaped, ValueRange{zero, zero},
3762+
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
37523763
auto maybeMaskedRhs = maybeMaskXferOp(
37533764
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
37543765

37553766
// Read res slice of size {n, w, c} @ [0, 0, 0].
37563767
Value res = rewriter.create<vector::TransferReadOp>(
3757-
loc, resType, resShaped, ValueRange{zero, zero, zero});
3768+
loc, resType, resShaped, ValueRange{zero, zero, zero},
3769+
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
37583770
auto maybeMaskedRes = maybeMaskXferOp(
37593771
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
37603772

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4261,33 +4261,39 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
42614261
/// 1. Builder that sets padding to `padding` (poison if not provided) and an empty mask (variant with attrs).
42624262
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42634263
VectorType vectorType, Value source,
4264-
ValueRange indices, AffineMapAttr permutationMapAttr,
4264+
ValueRange indices, std::optional<Value> padding,
4265+
AffineMapAttr permutationMapAttr,
42654266
/*optional*/ ArrayAttr inBoundsAttr) {
4267+
42664268
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4267-
Value padding = builder.create<arith::ConstantOp>(
4268-
result.location, elemType, builder.getZeroAttr(elemType));
4269+
if (!padding)
4270+
padding = builder.create<ub::PoisonOp>(result.location, elemType);
42694271
build(builder, result, vectorType, source, indices, permutationMapAttr,
4270-
padding, /*mask=*/Value(), inBoundsAttr);
4272+
*padding, /*mask=*/Value(), inBoundsAttr);
42714273
}
42724274

42734275
/// 2. Builder that sets padding to `padding` (poison if not provided) and an empty mask (variant without attrs).
42744276
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42754277
VectorType vectorType, Value source,
4276-
ValueRange indices, AffineMap permutationMap,
4278+
ValueRange indices, std::optional<Value> padding,
4279+
AffineMap permutationMap,
42774280
std::optional<ArrayRef<bool>> inBounds) {
42784281
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
42794282
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
42804283
? builder.getBoolArrayAttr(inBounds.value())
42814284
: builder.getBoolArrayAttr(
42824285
SmallVector<bool>(vectorType.getRank(), false));
4283-
build(builder, result, vectorType, source, indices, permutationMapAttr,
4284-
inBoundsAttr);
4286+
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4287+
if (!padding)
4288+
padding = builder.create<ub::PoisonOp>(result.location, elemType);
4289+
build(builder, result, vectorType, source, indices, *padding,
4290+
permutationMapAttr, inBoundsAttr);
42854291
}
42864292

42874293
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
42884294
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42894295
VectorType vectorType, Value source,
4290-
ValueRange indices, Value padding,
4296+
ValueRange indices, std::optional<Value> padding,
42914297
std::optional<ArrayRef<bool>> inBounds) {
42924298
AffineMap permutationMap = getTransferMinorIdentityMap(
42934299
llvm::cast<ShapedType>(source.getType()), vectorType);
@@ -4296,23 +4302,14 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
42964302
? builder.getBoolArrayAttr(inBounds.value())
42974303
: builder.getBoolArrayAttr(
42984304
SmallVector<bool>(vectorType.getRank(), false));
4305+
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4306+
if (!padding)
4307+
padding = builder.create<ub::PoisonOp>(result.location, elemType);
42994308
build(builder, result, vectorType, source, indices, permutationMapAttr,
4300-
padding,
4309+
*padding,
43014310
/*mask=*/Value(), inBoundsAttr);
43024311
}
43034312

4304-
/// 4. Builder that sets padding to zero and permutation map to
4305-
/// 'getMinorIdentityMap'.
4306-
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4307-
VectorType vectorType, Value source,
4308-
ValueRange indices,
4309-
std::optional<ArrayRef<bool>> inBounds) {
4310-
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4311-
Value padding = builder.create<arith::ConstantOp>(
4312-
result.location, elemType, builder.getZeroAttr(elemType));
4313-
build(builder, result, vectorType, source, indices, padding, inBounds);
4314-
}
4315-
43164313
template <typename EmitFun>
43174314
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
43184315
EmitFun emitOpError) {

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
173173
}
174174
SmallVector<bool> inBounds(indices.size(), true);
175175
return b.create<vector::TransferReadOp>(
176-
loc, cast<VectorType>(type), buffer, indices,
176+
loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
177177
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
178178
}
179179

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,8 @@ class FlattenContiguousRowMajorTransferReadPattern
660660
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
661661
vectorType.getElementType());
662662
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
663-
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
663+
loc, flatVectorType, collapsedSource, collapsedIndices,
664+
transferReadOp.getPadding(), collapsedMap);
664665
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
665666

666667
// 4. Replace the old transfer_read with the new one reading from the

0 commit comments

Comments
 (0)