
Commit 0dc523c

Groverkss authored and lialan committed
Revert "[mlir] Return vectorized values instead of replacing (llvm#144158)"
This reverts commit 4d21da0.
1 parent a611aae commit 0dc523c

File tree (3 files changed, +78 −86 lines):
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp


mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 11 additions & 17 deletions
@@ -855,23 +855,17 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 /// to work (these are checked by the vectorizer itself).
 bool hasVectorizationImpl(Operation *);
 
-/// Transformation information returned after vectorizing.
-struct VectorizationResult {
-  /// Results of the vectorization transform to replace the original operation.
-  SmallVector<Value> replacements;
-};
-/// Returns a `VectorizationResult` containing the results of the vectorized op,
-/// or failure if the transformation fails. If provided, `inputVectorSizes` are
-/// used to vectorize this operation. `inputVectorSizes` must match the rank of
-/// the iteration space of the operation and the input vector sizes must be
-/// greater than or equal to their counterpart iteration space sizes, if static.
-/// `inputVectorShapes` also allows the vectorization of operations with dynamic
-/// shapes.
-FailureOr<VectorizationResult>
-vectorize(RewriterBase &rewriter, Operation *op,
-          ArrayRef<int64_t> inputVectorSizes = {},
-          ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+/// Emit a suitable vector form for an operation. If provided,
+/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
+/// must match the rank of the iteration space of the operation and the sizes
+/// must be smaller or equal than their counterpart interation space sizes, if
+/// static. `inputVectorShapes` also allows the vectorization of operations with
+/// dynamic shapes.
+LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
+                        ArrayRef<int64_t> inputVectorSizes = {},
+                        ArrayRef<bool> inputScalableVecDims = {},
+                        bool vectorizeNDExtract = false,
+                        bool flatten1DDepthwiseConv = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
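For readers unfamiliar with this entry point, here is a minimal usage sketch of the restored signature; `rewriter` and `genericOp` are assumed to be in scope, and the sizes are illustrative rather than taken from this commit.

// Hedged usage sketch (not part of this commit): with the restored signature,
// linalg::vectorize both rewrites and replaces/erases `genericOp` itself.
// The vector sizes must match the rank of the op's iteration space.
SmallVector<int64_t> vectorSizes = {8, 16};
SmallVector<bool> scalableVecDims = {false, false};
if (failed(linalg::vectorize(rewriter, genericOp, vectorSizes, scalableVecDims,
                             /*vectorizeNDExtract=*/false,
                             /*flatten1DDepthwiseConv=*/false)))
  return failure();
// Do not touch `genericOp` here: it has already been replaced or erased.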

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 7 additions & 13 deletions
@@ -3823,14 +3823,9 @@ struct VectorizationPattern : public RewritePattern {
     if (!linalg::hasVectorizationImpl(op))
       return rewriter.notifyMatchFailure(op,
                                          "Unsupported Op, cannot vectorize");
-    FailureOr<VectorizationResult> vectorResults =
-        vectorize(rewriter, op, /*inputVectorSizes=*/{},
-                  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
-                  flatten1DDepthwiseConv);
-    if (failed(vectorResults))
-      return failure();
-    rewriter.replaceOp(op, vectorResults->replacements);
-    return success();
+    return vectorize(rewriter, op, /*inputVectorSizes=*/{},
+                     /*inputScalableVecDims=*/{}, vectorizeNDExtract,
+                     flatten1DDepthwiseConv);
   }
 
 private:
@@ -3919,14 +3914,13 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
-    FailureOr<VectorizationResult> vectorResults =
-        linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                          getVectorizeNdExtract().value_or(false));
-    if (failed(vectorResults)) {
+
+    if (failed(linalg::vectorize(rewriter, target, vectorSizes,
+                                 getScalableSizes(),
+                                 getVectorizeNdExtract().value_or(false)))) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
     }
-    rewriter.replaceOp(target, vectorResults->replacements);
   }
 
   return DiagnosedSilenceableFailure::success();
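The call-site impact of the revert, restated as a hedged before/after sketch (identifiers are illustrative):

// Reverted API (llvm#144158): the caller received the vectorized values and
// performed the replacement itself.
//   FailureOr<VectorizationResult> results = linalg::vectorize(rewriter, target);
//   if (failed(results))
//     return failure();
//   rewriter.replaceOp(target, results->replacements);

// Restored API (this commit): linalg::vectorize performs the replacement (or
// erasure) internally, so the caller only propagates the LogicalResult.
if (failed(linalg::vectorize(rewriter, target)))
  return failure();
// `target` is no longer valid after a successful call.
return success();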

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 60 additions & 56 deletions
@@ -551,10 +551,9 @@ enum class Conv1DOpOrder {
   Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
 };
 
-/// Helper data structure to represent the result of vectorization for a single
-/// operation. In certain specific cases, like terminators, we do not want to
-/// propagate.
-enum VectorizationHookStatus {
+/// Helper data structure to represent the result of vectorization.
+/// In certain specific cases, like terminators, we do not want to propagate/
+enum VectorizationStatus {
   /// Op failed to vectorize.
   Failure = 0,
   /// Op vectorized and custom function took care of replacement logic
@@ -565,12 +564,9 @@ enum VectorizationHookStatus {
   // TODO: support values if Op vectorized to Many-Ops whose results we need to
   // aggregate for replacement.
 };
-/// VectorizationHookResult contains the vectorized op returned from a
-/// CustomVectorizationHook. This is an internal implementation detail of
-/// linalg vectorization, not to be confused with VectorizationResult.
-struct VectorizationHookResult {
+struct VectorizationResult {
   /// Return status from vectorizing the current op.
-  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
+  enum VectorizationStatus status = VectorizationStatus::Failure;
   /// New vectorized operation to replace the current op.
   /// Replacement behavior is specified by `status`.
   Operation *newOp;
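As background on the restored internal names, a sketch of how a custom hook reports its outcome through this struct; `MyTerminatorOp` is a placeholder and the body is illustrative, not code from this commit:

// Illustrative hook matching the CustomVectorizationHook shape used below.
// Failure lets another hook (or the generic path) handle the op; NewOp asks
// the driver to map the produced op; NoReplace means the hook already
// recorded results (e.g. the linalg.yield case) and nothing should be mapped.
static VectorizationResult vectorizeMyTerminator(Operation *op,
                                                 const IRMapping &bvm) {
  if (!isa<MyTerminatorOp>(op)) // MyTerminatorOp is hypothetical.
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}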
@@ -732,22 +728,22 @@ using CustomVectorizationPrecondition =
 // assuming all its vectorized operands are already in the IRMapping.
 // Return nullptr if the Operation cannot be vectorized.
 using CustomVectorizationHook =
-    std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
+    std::function<VectorizationResult(Operation *, const IRMapping &)>;
 
 /// Helper function to vectorize the terminator of a `linalgOp`. New result
 /// vector values are appended to `newResults`. Return
-/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
-/// that it should not try to map produced operations and instead return the
-/// results using the `newResults` vector making them available to the
-/// vectorization algorithm for RAUW. This function is meant to be used as a
+/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
+/// should not try to map produced operations and instead return the results
+/// using the `newResults` vector making them available to the vectorization
+/// algorithm for RAUW. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationHookResult
+static VectorizationResult
 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                      const IRMapping &bvm, VectorizationState &state,
                      LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
   if (!yieldOp)
-    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
   for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
@@ -759,20 +755,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
     newResults.push_back(newResult);
   }
 
-  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
+  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
 }
 
 /// Helper function to vectorize the index operations of a `linalgOp`. Return
-/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
-                                                    VectorizationState &state,
-                                                    Operation *op,
-                                                    LinalgOp linalgOp) {
+static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
+                                                VectorizationState &state,
+                                                Operation *op,
+                                                LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
-    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
   ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
@@ -786,7 +782,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
   if (dim == targetShape.size() - 1)
-    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
+    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
   // Otherwise permute the targetShape to move the index dimension last,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
@@ -804,7 +800,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
   std::swap(transposition.back(), transposition[dim]);
   auto transposeOp =
       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
-  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
+  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
 }
 
 /// Helper function to check if the tensor.extract can be vectorized by the
@@ -1100,15 +1096,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 }
 
 /// Helper function to vectorize the tensor.extract operations. Returns
-/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
+/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationHookResult
+static VectorizationResult
 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
-    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
@@ -1140,7 +1136,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
     LDBG("Vectorised as gather load: " << extractOp << "\n");
-    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
   }
 
   // 2. Handle:
@@ -1204,8 +1200,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
 
     LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
-    return VectorizationHookResult{VectorizationHookStatus::NewOp,
-                                   maskedReadOp};
+    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
   }
 
   // 2b. Handle contiguous access.
@@ -1231,8 +1226,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       inBounds);
 
   LDBG("Vectorised as contiguous load: " << extractOp);
-  return VectorizationHookResult{VectorizationHookStatus::NewOp,
-                                 transferReadOp};
+  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
@@ -1272,9 +1266,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
 /// This function assumes all operands of `op` have been vectorized and are in
 /// the `bvm` mapping. As a consequence, this function is meant to be called on
 /// a topologically-sorted list of ops.
-/// This function does not update `bvm` but returns a VectorizationHookStatus
-/// that instructs the caller what `bvm` update needs to occur.
-static VectorizationHookResult
+/// This function does not update `bvm` but returns a VectorizationStatus that
+/// instructs the caller what `bvm` update needs to occur.
+static VectorizationResult
 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
                LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1283,8 +1277,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 1. Try to apply any CustomVectorizationHook.
   if (!customVectorizationHooks.empty()) {
     for (auto &customFunc : customVectorizationHooks) {
-      VectorizationHookResult result = customFunc(op, bvm);
-      if (result.status == VectorizationHookStatus::Failure)
+      VectorizationResult result = customFunc(op, bvm);
+      if (result.status == VectorizationStatus::Failure)
        continue;
       return result;
     }
@@ -1293,12 +1287,11 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
   // Clone so that the constant is not confined to the linalgOp block .
   if (isa<arith::ConstantOp, func::ConstantOp>(op))
-    return VectorizationHookResult{VectorizationHookStatus::NewOp,
-                                   rewriter.clone(*op)};
+    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
 
   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
   if (!OpTrait::hasElementwiseMappableTraits(op))
-    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
 
   // 4 . Check if the operation is a reduction.
   SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1321,7 +1314,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
         reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                        reductionOperands[0].second, bvm);
     if (reduceOp)
-      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
+      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
   }
 
   // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1361,8 +1354,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
                              : resultType);
   }
   // d. Build and return the new op.
-  return VectorizationHookResult{
-      VectorizationHookStatus::NewOp,
+  return VectorizationResult{
+      VectorizationStatus::NewOp,
       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                       resultTypes, op->getAttrs())};
 }
@@ -1466,34 +1459,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   SmallVector<CustomVectorizationHook> hooks;
   // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
     return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
   };
   hooks.push_back(vectorizeYield);
 
   // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
     return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
   };
   hooks.push_back(vectorizeIndex);
 
   // 4c. Register CustomVectorizationHook for extractOp.
   CustomVectorizationHook vectorizeExtract =
-      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
+      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
   };
   hooks.push_back(vectorizeExtract);
 
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
-    VectorizationHookResult result =
+    VectorizationResult result =
         vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
-    if (result.status == VectorizationHookStatus::Failure) {
+    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
-    if (result.status == VectorizationHookStatus::NewOp) {
+    if (result.status == VectorizationStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
@@ -2530,11 +2523,17 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                      tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+/// Emit a suitable vector form for an operation. If provided,
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorShapes`
+/// also allows the vectorization of operations with dynamic shapes.
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      ArrayRef<bool> inputScalableVecDims,
                                      bool vectorizeNDExtract,
                                      bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2616,7 +2615,12 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
     return failure();
   }
 
-  return VectorizationResult{results};
+  if (!results.empty())
+    rewriter.replaceOp(op, results);
+  else
+    rewriter.eraseOp(op);
+
+  return success();
 }
 
 LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
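Finally, a hedged sketch of a driver that exercises the restored entry point over a function; the walk-and-collect structure is illustrative and not part of this commit:

// Illustrative driver (not from this commit): vectorize every op in a function
// that has a vectorization implementation, using the restored in-place API.
static LogicalResult vectorizeFunction(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  // Collect candidates first; vectorize() mutates the IR as it goes.
  SmallVector<Operation *> candidates;
  funcOp.walk([&](Operation *op) {
    if (linalg::hasVectorizationImpl(op))
      candidates.push_back(op);
  });
  for (Operation *op : candidates)
    if (failed(linalg::vectorize(rewriter, op)))
      return failure();
  return success();
}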
