@@ -551,10 +551,9 @@ enum class Conv1DOpOrder {
551
551
Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
552
552
};
553
553
554
- // / Helper data structure to represent the result of vectorization for a single
555
- // / operation. In certain specific cases, like terminators, we do not want to
556
- // / propagate.
557
- enum VectorizationHookStatus {
554
+ // / Helper data structure to represent the result of vectorization.
555
+ // / In certain specific cases, like terminators, we do not want to propagate/
556
+ enum VectorizationStatus {
558
557
// / Op failed to vectorize.
559
558
Failure = 0 ,
560
559
// / Op vectorized and custom function took care of replacement logic
@@ -565,12 +564,9 @@ enum VectorizationHookStatus {
565
564
// TODO: support values if Op vectorized to Many-Ops whose results we need to
566
565
// aggregate for replacement.
567
566
};
568
- // / VectorizationHookResult contains the vectorized op returned from a
569
- // / CustomVectorizationHook. This is an internal implementation detail of
570
- // / linalg vectorization, not to be confused with VectorizationResult.
571
- struct VectorizationHookResult {
567
+ struct VectorizationResult {
572
568
// / Return status from vectorizing the current op.
573
- enum VectorizationHookStatus status = VectorizationHookStatus ::Failure;
569
+ enum VectorizationStatus status = VectorizationStatus ::Failure;
574
570
// / New vectorized operation to replace the current op.
575
571
// / Replacement behavior is specified by `status`.
576
572
Operation *newOp;
@@ -732,22 +728,22 @@ using CustomVectorizationPrecondition =
732
728
// assuming all its vectorized operands are already in the IRMapping.
733
729
// Return nullptr if the Operation cannot be vectorized.
734
730
using CustomVectorizationHook =
735
- std::function<VectorizationHookResult (Operation *, const IRMapping &)>;
731
+ std::function<VectorizationResult (Operation *, const IRMapping &)>;
736
732
737
733
// / Helper function to vectorize the terminator of a `linalgOp`. New result
738
734
// / vector values are appended to `newResults`. Return
739
- // / VectorizationHookStatus ::NoReplace to signal the vectorization algorithm
740
- // / that it should not try to map produced operations and instead return the
741
- // / results using the `newResults` vector making them available to the
742
- // / vectorization algorithm for RAUW. This function is meant to be used as a
735
+ // / VectorizationStatus ::NoReplace to signal the vectorization algorithm that it
736
+ // / should not try to map produced operations and instead return the results
737
+ // / using the `newResults` vector making them available to the vectorization
738
+ // / algorithm for RAUW. This function is meant to be used as a
743
739
// / CustomVectorizationHook.
744
- static VectorizationHookResult
740
+ static VectorizationResult
745
741
vectorizeLinalgYield (RewriterBase &rewriter, Operation *op,
746
742
const IRMapping &bvm, VectorizationState &state,
747
743
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
748
744
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
749
745
if (!yieldOp)
750
- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
746
+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
751
747
for (const auto &output : llvm::enumerate (yieldOp.getValues ())) {
752
748
// TODO: Scan for an opportunity for reuse.
753
749
// TODO: use a map.
@@ -759,20 +755,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
759
755
newResults.push_back (newResult);
760
756
}
761
757
762
- return VectorizationHookResult{VectorizationHookStatus ::NoReplace, nullptr };
758
+ return VectorizationResult{VectorizationStatus ::NoReplace, nullptr };
763
759
}
764
760
765
761
// / Helper function to vectorize the index operations of a `linalgOp`. Return
766
- // / VectorizationHookStatus ::NewOp to signal the vectorization algorithm that it
762
+ // / VectorizationStatus ::NewOp to signal the vectorization algorithm that it
767
763
// / should map the produced operations. This function is meant to be used as a
768
764
// / CustomVectorizationHook.
769
- static VectorizationHookResult vectorizeLinalgIndex (RewriterBase &rewriter,
770
- VectorizationState &state,
771
- Operation *op,
772
- LinalgOp linalgOp) {
765
+ static VectorizationResult vectorizeLinalgIndex (RewriterBase &rewriter,
766
+ VectorizationState &state,
767
+ Operation *op,
768
+ LinalgOp linalgOp) {
773
769
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
774
770
if (!indexOp)
775
- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
771
+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
776
772
auto loc = indexOp.getLoc ();
777
773
// Compute the static loop sizes of the index op.
778
774
ArrayRef<int64_t > targetShape = state.getCanonicalVecShape ();
@@ -786,7 +782,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
786
782
// dimension of the iteration space since the vectorization algorithm in this
787
783
// case can handle the broadcast.
788
784
if (dim == targetShape.size () - 1 )
789
- return VectorizationHookResult{VectorizationHookStatus ::NewOp, indexSteps};
785
+ return VectorizationResult{VectorizationStatus ::NewOp, indexSteps};
790
786
// Otherwise permute the targetShape to move the index dimension last,
791
787
// broadcast the one-dimensional index vector to the permuted shape, and
792
788
// finally transpose the broadcasted index vector to undo the permutation.
@@ -804,7 +800,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
804
800
std::swap (transposition.back (), transposition[dim]);
805
801
auto transposeOp =
806
802
rewriter.create <vector::TransposeOp>(loc, broadCastOp, transposition);
807
- return VectorizationHookResult{VectorizationHookStatus ::NewOp, transposeOp};
803
+ return VectorizationResult{VectorizationStatus ::NewOp, transposeOp};
808
804
}
809
805
810
806
// / Helper function to check if the tensor.extract can be vectorized by the
@@ -1100,15 +1096,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1100
1096
}
1101
1097
1102
1098
// / Helper function to vectorize the tensor.extract operations. Returns
1103
- // / VectorizationHookStatus ::NewOp to signal the vectorization algorithm that it
1099
+ // / VectorizationStatus ::NewOp to signal the vectorization algorithm that it
1104
1100
// / should map the produced operations. This function is meant to be used as a
1105
1101
// / CustomVectorizationHook.
1106
- static VectorizationHookResult
1102
+ static VectorizationResult
1107
1103
vectorizeTensorExtract (RewriterBase &rewriter, VectorizationState &state,
1108
1104
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1109
1105
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1110
1106
if (!extractOp)
1111
- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
1107
+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
1112
1108
auto loc = extractOp.getLoc ();
1113
1109
1114
1110
// Compute the static loop sizes of the extract op.
@@ -1140,7 +1136,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1140
1136
gatherOp = state.maskOperation (rewriter, gatherOp, linalgOp);
1141
1137
1142
1138
LDBG (" Vectorised as gather load: " << extractOp << " \n " );
1143
- return VectorizationHookResult{VectorizationHookStatus ::NewOp, gatherOp};
1139
+ return VectorizationResult{VectorizationStatus ::NewOp, gatherOp};
1144
1140
}
1145
1141
1146
1142
// 2. Handle:
@@ -1204,8 +1200,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1204
1200
mlir::vector::maskOperation (rewriter, transferReadOp, allTrue);
1205
1201
1206
1202
LDBG (" Vectorised as scalar broadcast load: " << extractOp << " \n " );
1207
- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1208
- maskedReadOp};
1203
+ return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
1209
1204
}
1210
1205
1211
1206
// 2b. Handle contiguous access.
@@ -1231,8 +1226,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1231
1226
inBounds);
1232
1227
1233
1228
LDBG (" Vectorised as contiguous load: " << extractOp);
1234
- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1235
- transferReadOp};
1229
+ return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1236
1230
}
1237
1231
1238
1232
// / Emit reduction operations if the shapes of the value to reduce is different
@@ -1272,9 +1266,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1272
1266
// / This function assumes all operands of `op` have been vectorized and are in
1273
1267
// / the `bvm` mapping. As a consequence, this function is meant to be called on
1274
1268
// / a topologically-sorted list of ops.
1275
- // / This function does not update `bvm` but returns a VectorizationHookStatus
1276
- // / that instructs the caller what `bvm` update needs to occur.
1277
- static VectorizationHookResult
1269
+ // / This function does not update `bvm` but returns a VectorizationStatus that
1270
+ // / instructs the caller what `bvm` update needs to occur.
1271
+ static VectorizationResult
1278
1272
vectorizeOneOp (RewriterBase &rewriter, VectorizationState &state,
1279
1273
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1280
1274
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1283,8 +1277,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1283
1277
// 1. Try to apply any CustomVectorizationHook.
1284
1278
if (!customVectorizationHooks.empty ()) {
1285
1279
for (auto &customFunc : customVectorizationHooks) {
1286
- VectorizationHookResult result = customFunc (op, bvm);
1287
- if (result.status == VectorizationHookStatus ::Failure)
1280
+ VectorizationResult result = customFunc (op, bvm);
1281
+ if (result.status == VectorizationStatus ::Failure)
1288
1282
continue ;
1289
1283
return result;
1290
1284
}
@@ -1293,12 +1287,11 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1293
1287
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
1294
1288
// Clone so that the constant is not confined to the linalgOp block .
1295
1289
if (isa<arith::ConstantOp, func::ConstantOp>(op))
1296
- return VectorizationHookResult{VectorizationHookStatus::NewOp,
1297
- rewriter.clone (*op)};
1290
+ return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone (*op)};
1298
1291
1299
1292
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
1300
1293
if (!OpTrait::hasElementwiseMappableTraits (op))
1301
- return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
1294
+ return VectorizationResult{VectorizationStatus ::Failure, nullptr };
1302
1295
1303
1296
// 4 . Check if the operation is a reduction.
1304
1297
SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1321,7 +1314,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1321
1314
reduceIfNeeded (rewriter, linalgOp, op, reductionOperands[0 ].first ,
1322
1315
reductionOperands[0 ].second , bvm);
1323
1316
if (reduceOp)
1324
- return VectorizationHookResult{VectorizationHookStatus ::NewOp, reduceOp};
1317
+ return VectorizationResult{VectorizationStatus ::NewOp, reduceOp};
1325
1318
}
1326
1319
1327
1320
// 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1361,8 +1354,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1361
1354
: resultType);
1362
1355
}
1363
1356
// d. Build and return the new op.
1364
- return VectorizationHookResult {
1365
- VectorizationHookStatus ::NewOp,
1357
+ return VectorizationResult {
1358
+ VectorizationStatus ::NewOp,
1366
1359
rewriter.create (op->getLoc (), op->getName ().getIdentifier (), vecOperands,
1367
1360
resultTypes, op->getAttrs ())};
1368
1361
}
@@ -1466,34 +1459,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1466
1459
SmallVector<CustomVectorizationHook> hooks;
1467
1460
// 4a. Register CustomVectorizationHook for yieldOp.
1468
1461
CustomVectorizationHook vectorizeYield =
1469
- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1462
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1470
1463
return vectorizeLinalgYield (rewriter, op, bvm, state, linalgOp, newResults);
1471
1464
};
1472
1465
hooks.push_back (vectorizeYield);
1473
1466
1474
1467
// 4b. Register CustomVectorizationHook for indexOp.
1475
1468
CustomVectorizationHook vectorizeIndex =
1476
- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1469
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1477
1470
return vectorizeLinalgIndex (rewriter, state, op, linalgOp);
1478
1471
};
1479
1472
hooks.push_back (vectorizeIndex);
1480
1473
1481
1474
// 4c. Register CustomVectorizationHook for extractOp.
1482
1475
CustomVectorizationHook vectorizeExtract =
1483
- [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1476
+ [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1484
1477
return vectorizeTensorExtract (rewriter, state, op, linalgOp, bvm);
1485
1478
};
1486
1479
hooks.push_back (vectorizeExtract);
1487
1480
1488
1481
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1489
1482
for (Operation &op : block->getOperations ()) {
1490
- VectorizationHookResult result =
1483
+ VectorizationResult result =
1491
1484
vectorizeOneOp (rewriter, state, linalgOp, &op, bvm, hooks);
1492
- if (result.status == VectorizationHookStatus ::Failure) {
1485
+ if (result.status == VectorizationStatus ::Failure) {
1493
1486
LDBG (" failed to vectorize: " << op << " \n " );
1494
1487
return failure ();
1495
1488
}
1496
- if (result.status == VectorizationHookStatus ::NewOp) {
1489
+ if (result.status == VectorizationStatus ::NewOp) {
1497
1490
Operation *maybeMaskedOp =
1498
1491
state.maskOperation (rewriter, result.newOp , linalgOp);
1499
1492
LDBG (" New vector op: " << *maybeMaskedOp << " \n " );
@@ -2530,11 +2523,17 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2530
2523
tensor::InsertSliceOp>(op);
2531
2524
}
2532
2525
2533
- FailureOr<VectorizationResult>
2534
- mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2535
- ArrayRef<int64_t > inputVectorSizes,
2536
- ArrayRef<bool > inputScalableVecDims,
2537
- bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2526
+ // / Emit a suitable vector form for an operation. If provided,
2527
+ // / `inputVectorSizes` are used to vectorize this operation.
2528
+ // / `inputVectorSizes` must match the rank of the iteration space of the
2529
+ // / operation and the input vector sizes must be greater than or equal to
2530
+ // / their counterpart iteration space sizes, if static. `inputVectorShapes`
2531
+ // / also allows the vectorization of operations with dynamic shapes.
2532
+ LogicalResult mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2533
+ ArrayRef<int64_t > inputVectorSizes,
2534
+ ArrayRef<bool > inputScalableVecDims,
2535
+ bool vectorizeNDExtract,
2536
+ bool flatten1DDepthwiseConv) {
2538
2537
LDBG (" Attempting to vectorize:\n " << *op << " \n " );
2539
2538
LDBG (" Input vector sizes: " );
2540
2539
LLVM_DEBUG (llvm::interleaveComma (inputVectorSizes, llvm::dbgs ()));
@@ -2616,7 +2615,12 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2616
2615
return failure ();
2617
2616
}
2618
2617
2619
- return VectorizationResult{results};
2618
+ if (!results.empty ())
2619
+ rewriter.replaceOp (op, results);
2620
+ else
2621
+ rewriter.eraseOp (op);
2622
+
2623
+ return success ();
2620
2624
}
2621
2625
2622
2626
LogicalResult mlir::linalg::vectorizeCopy (RewriterBase &rewriter,
0 commit comments