-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately #148415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately #148415
Conversation
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesStore metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of This commit also removes a pointer indirection and may slightly improve the performance of the existing driver. Full diff: https://github.com/llvm/llvm-project/pull/148415.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 437dbcfea5288..42abc152981e6 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -789,38 +789,22 @@ enum MaterializationKind {
Source
};
-/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
-/// op. Unresolved materializations are erased at the end of the dialect
-/// conversion.
-class UnresolvedMaterializationRewrite : public OperationRewrite {
+/// Helper class that stores metadata about an unresolved materialization.
+class UnresolvedMaterializationInfo {
public:
- UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- UnrealizedConversionCastOp op,
- const TypeConverter *converter,
- MaterializationKind kind, Type originalType,
- ValueVector mappedValues);
-
- static bool classof(const IRRewrite *rewrite) {
- return rewrite->getKind() == Kind::UnresolvedMaterialization;
- }
-
- void rollback() override;
-
- UnrealizedConversionCastOp getOperation() const {
- return cast<UnrealizedConversionCastOp>(op);
- }
+ UnresolvedMaterializationInfo() = default;
+ UnresolvedMaterializationInfo(const TypeConverter *converter,
+ MaterializationKind kind, Type originalType)
+ : converterAndKind(converter, kind), originalType(originalType) {}
- /// Return the type converter of this materialization (which may be null).
const TypeConverter *getConverter() const {
return converterAndKind.getPointer();
}
- /// Return the kind of this materialization.
MaterializationKind getMaterializationKind() const {
return converterAndKind.getInt();
}
- /// Return the original type of the SSA value.
Type getOriginalType() const { return originalType; }
private:
@@ -832,7 +816,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
/// The original type of the SSA value. Only used for target
/// materializations.
Type originalType;
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations fold away or are replaced with
+/// source/target materializations at the end of the dialect conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+ UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op,
+ ValueVector mappedValues)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ mappedValues(std::move(mappedValues)) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::UnresolvedMaterialization;
+ }
+
+ void rollback() override;
+ UnrealizedConversionCastOp getOperation() const {
+ return cast<UnrealizedConversionCastOp>(op);
+ }
+
+private:
/// The values in the conversion value mapping that are being replaced by the
/// results of this unresolved materialization.
ValueVector mappedValues;
@@ -1088,9 +1095,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
- /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
- /// to the corresponding rewrite objects.
- DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ /// A mapping for looking up metadata of unresolved materializations.
+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1216,6 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
-UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
- ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
- const TypeConverter *converter, MaterializationKind kind, Type originalType,
- ValueVector mappedValues)
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind), originalType(originalType),
- mappedValues(std::move(mappedValues)) {
- assert((!originalType || kind == MaterializationKind::Target) &&
- "original type is valid only for target materializations");
- rewriterImpl.unresolvedMaterializations[op] = this;
-}
-
void UnresolvedMaterializationRewrite::rollback() {
if (!mappedValues.empty())
rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1504,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
*castOp = convertOp;
- appendRewrite<UnresolvedMaterializationRewrite>(
- convertOp, converter, kind, originalType, std::move(valuesToMap));
+ unresolvedMaterializations[convertOp] =
+ UnresolvedMaterializationInfo(converter, kind, originalType);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
return convertOp.getResults();
}
@@ -2679,21 +2675,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
static LogicalResult
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
- UnresolvedMaterializationRewrite *rewrite) {
- UnrealizedConversionCastOp op = rewrite->getOperation();
+ UnrealizedConversionCastOp op,
+ const UnresolvedMaterializationInfo &info) {
assert(!op.use_empty() &&
"expected that dead materializations have already been DCE'd");
Operation::operand_range inputOperands = op.getOperands();
// Try to materialize the conversion.
- if (const TypeConverter *converter = rewrite->getConverter()) {
+ if (const TypeConverter *converter = info.getConverter()) {
rewriter.setInsertionPoint(op);
SmallVector<Value> newMaterialization;
- switch (rewrite->getMaterializationKind()) {
+ switch (info.getMaterializationKind()) {
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
- rewrite->getOriginalType());
+ info.getOriginalType());
break;
case MaterializationKind::Source:
assert(op->getNumResults() == 1 && "expected single result");
@@ -2768,7 +2764,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&materializations = rewriterImpl.unresolvedMaterializations;
for (auto it : materializations)
allCastOps.push_back(it.first);
@@ -2785,7 +2781,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = materializations.find(castOp);
assert(it != materializations.end() && "inconsistent state");
- if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+ if (failed(
+ legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
return failure();
}
}
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesStore metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of This commit also removes a pointer indirection and may slightly improve the performance of the existing driver. Full diff: https://github.com/llvm/llvm-project/pull/148415.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 437dbcfea5288..42abc152981e6 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -789,38 +789,22 @@ enum MaterializationKind {
Source
};
-/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
-/// op. Unresolved materializations are erased at the end of the dialect
-/// conversion.
-class UnresolvedMaterializationRewrite : public OperationRewrite {
+/// Helper class that stores metadata about an unresolved materialization.
+class UnresolvedMaterializationInfo {
public:
- UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- UnrealizedConversionCastOp op,
- const TypeConverter *converter,
- MaterializationKind kind, Type originalType,
- ValueVector mappedValues);
-
- static bool classof(const IRRewrite *rewrite) {
- return rewrite->getKind() == Kind::UnresolvedMaterialization;
- }
-
- void rollback() override;
-
- UnrealizedConversionCastOp getOperation() const {
- return cast<UnrealizedConversionCastOp>(op);
- }
+ UnresolvedMaterializationInfo() = default;
+ UnresolvedMaterializationInfo(const TypeConverter *converter,
+ MaterializationKind kind, Type originalType)
+ : converterAndKind(converter, kind), originalType(originalType) {}
- /// Return the type converter of this materialization (which may be null).
const TypeConverter *getConverter() const {
return converterAndKind.getPointer();
}
- /// Return the kind of this materialization.
MaterializationKind getMaterializationKind() const {
return converterAndKind.getInt();
}
- /// Return the original type of the SSA value.
Type getOriginalType() const { return originalType; }
private:
@@ -832,7 +816,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
/// The original type of the SSA value. Only used for target
/// materializations.
Type originalType;
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations fold away or are replaced with
+/// source/target materializations at the end of the dialect conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+ UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op,
+ ValueVector mappedValues)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ mappedValues(std::move(mappedValues)) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::UnresolvedMaterialization;
+ }
+
+ void rollback() override;
+ UnrealizedConversionCastOp getOperation() const {
+ return cast<UnrealizedConversionCastOp>(op);
+ }
+
+private:
/// The values in the conversion value mapping that are being replaced by the
/// results of this unresolved materialization.
ValueVector mappedValues;
@@ -1088,9 +1095,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
- /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
- /// to the corresponding rewrite objects.
- DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ /// A mapping for looking up metadata of unresolved materializations.
+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1216,6 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
-UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
- ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
- const TypeConverter *converter, MaterializationKind kind, Type originalType,
- ValueVector mappedValues)
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind), originalType(originalType),
- mappedValues(std::move(mappedValues)) {
- assert((!originalType || kind == MaterializationKind::Target) &&
- "original type is valid only for target materializations");
- rewriterImpl.unresolvedMaterializations[op] = this;
-}
-
void UnresolvedMaterializationRewrite::rollback() {
if (!mappedValues.empty())
rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1504,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
*castOp = convertOp;
- appendRewrite<UnresolvedMaterializationRewrite>(
- convertOp, converter, kind, originalType, std::move(valuesToMap));
+ unresolvedMaterializations[convertOp] =
+ UnresolvedMaterializationInfo(converter, kind, originalType);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
return convertOp.getResults();
}
@@ -2679,21 +2675,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
static LogicalResult
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
- UnresolvedMaterializationRewrite *rewrite) {
- UnrealizedConversionCastOp op = rewrite->getOperation();
+ UnrealizedConversionCastOp op,
+ const UnresolvedMaterializationInfo &info) {
assert(!op.use_empty() &&
"expected that dead materializations have already been DCE'd");
Operation::operand_range inputOperands = op.getOperands();
// Try to materialize the conversion.
- if (const TypeConverter *converter = rewrite->getConverter()) {
+ if (const TypeConverter *converter = info.getConverter()) {
rewriter.setInsertionPoint(op);
SmallVector<Value> newMaterialization;
- switch (rewrite->getMaterializationKind()) {
+ switch (info.getMaterializationKind()) {
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
- rewrite->getOriginalType());
+ info.getOriginalType());
break;
case MaterializationKind::Source:
assert(op->getNumResults() == 1 && "expected single result");
@@ -2768,7 +2764,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&materializations = rewriterImpl.unresolvedMaterializations;
for (auto it : materializations)
allCastOps.push_back(it.first);
@@ -2785,7 +2781,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = materializations.find(castOp);
assert(it != materializations.end() && "inconsistent state");
- if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+ if (failed(
+ legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
return failure();
}
}
|
7cedbb9
to
2ce59d7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Store metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of
IRRewrite
objects. Therefore, metadata about unresolved materializations can no longer be retrieved fromUnresolvedMaterializationRewrite
objects.This commit also removes a pointer indirection and may slightly improve the performance of the existing driver.