Skip to content

[mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately #148415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 15, 2025

Conversation

matthias-springer
Copy link
Member

Store metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of IRRewrite objects. Therefore, metadata about unresolved materializations can no longer be retrieved from UnresolvedMaterializationRewrite objects.

This commit also removes a pointer indirection and may slightly improve the performance of the existing driver.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jul 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 13, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Store metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of IRRewrite objects. Therefore, metadata about unresolved materializations can no longer be retrieved from UnresolvedMaterializationRewrite objects.

This commit also removes a pointer indirection and may slightly improve the performance of the existing driver.


Full diff: https://github.com/llvm/llvm-project/pull/148415.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-46)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 437dbcfea5288..42abc152981e6 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -789,38 +789,22 @@ enum MaterializationKind {
   Source
 };
 
-/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
-/// op. Unresolved materializations are erased at the end of the dialect
-/// conversion.
-class UnresolvedMaterializationRewrite : public OperationRewrite {
+/// Helper class that stores metadata about an unresolved materialization.
+class UnresolvedMaterializationInfo {
 public:
-  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                                   UnrealizedConversionCastOp op,
-                                   const TypeConverter *converter,
-                                   MaterializationKind kind, Type originalType,
-                                   ValueVector mappedValues);
-
-  static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::UnresolvedMaterialization;
-  }
-
-  void rollback() override;
-
-  UnrealizedConversionCastOp getOperation() const {
-    return cast<UnrealizedConversionCastOp>(op);
-  }
+  UnresolvedMaterializationInfo() = default;
+  UnresolvedMaterializationInfo(const TypeConverter *converter,
+                                MaterializationKind kind, Type originalType)
+      : converterAndKind(converter, kind), originalType(originalType) {}
 
-  /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
     return converterAndKind.getPointer();
   }
 
-  /// Return the kind of this materialization.
   MaterializationKind getMaterializationKind() const {
     return converterAndKind.getInt();
   }
 
-  /// Return the original type of the SSA value.
   Type getOriginalType() const { return originalType; }
 
 private:
@@ -832,7 +816,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// The original type of the SSA value. Only used for target
   /// materializations.
   Type originalType;
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations fold away or are replaced with
+/// source/target materializations at the end of the dialect conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                                   UnrealizedConversionCastOp op,
+                                   ValueVector mappedValues)
+      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+        mappedValues(std::move(mappedValues)) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::UnresolvedMaterialization;
+  }
+
+  void rollback() override;
 
+  UnrealizedConversionCastOp getOperation() const {
+    return cast<UnrealizedConversionCastOp>(op);
+  }
+
+private:
   /// The values in the conversion value mapping that are being replaced by the
   /// results of this unresolved materialization.
   ValueVector mappedValues;
@@ -1088,9 +1095,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// by the current pattern.
   SetVector<Block *> patternInsertedBlocks;
 
-  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
-  /// to the corresponding rewrite objects.
-  DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+  /// A mapping for looking up metadata of unresolved materializations.
+  DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       unresolvedMaterializations;
 
   /// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1216,6 @@ void CreateOperationRewrite::rollback() {
   op->erase();
 }
 
-UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
-    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
-    const TypeConverter *converter, MaterializationKind kind, Type originalType,
-    ValueVector mappedValues)
-    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-      converterAndKind(converter, kind), originalType(originalType),
-      mappedValues(std::move(mappedValues)) {
-  assert((!originalType || kind == MaterializationKind::Target) &&
-         "original type is valid only for target materializations");
-  rewriterImpl.unresolvedMaterializations[op] = this;
-}
-
 void UnresolvedMaterializationRewrite::rollback() {
   if (!mappedValues.empty())
     rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1504,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     mapping.map(valuesToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
-  appendRewrite<UnresolvedMaterializationRewrite>(
-      convertOp, converter, kind, originalType, std::move(valuesToMap));
+  unresolvedMaterializations[convertOp] =
+      UnresolvedMaterializationInfo(converter, kind, originalType);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+                                                  std::move(valuesToMap));
   return convertOp.getResults();
 }
 
@@ -2679,21 +2675,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
 
 static LogicalResult
 legalizeUnresolvedMaterialization(RewriterBase &rewriter,
-                                  UnresolvedMaterializationRewrite *rewrite) {
-  UnrealizedConversionCastOp op = rewrite->getOperation();
+                                  UnrealizedConversionCastOp op,
+                                  const UnresolvedMaterializationInfo &info) {
   assert(!op.use_empty() &&
          "expected that dead materializations have already been DCE'd");
   Operation::operand_range inputOperands = op.getOperands();
 
   // Try to materialize the conversion.
-  if (const TypeConverter *converter = rewrite->getConverter()) {
+  if (const TypeConverter *converter = info.getConverter()) {
     rewriter.setInsertionPoint(op);
     SmallVector<Value> newMaterialization;
-    switch (rewrite->getMaterializationKind()) {
+    switch (info.getMaterializationKind()) {
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
-          rewrite->getOriginalType());
+          info.getOriginalType());
       break;
     case MaterializationKind::Source:
       assert(op->getNumResults() == 1 && "expected single result");
@@ -2768,7 +2764,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 
   // Gather all unresolved materializations.
   SmallVector<UnrealizedConversionCastOp> allCastOps;
-  const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+  const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       &materializations = rewriterImpl.unresolvedMaterializations;
   for (auto it : materializations)
     allCastOps.push_back(it.first);
@@ -2785,7 +2781,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
     for (UnrealizedConversionCastOp castOp : remainingCastOps) {
       auto it = materializations.find(castOp);
       assert(it != materializations.end() && "inconsistent state");
-      if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+      if (failed(
+              legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
         return failure();
     }
   }

@llvmbot
Copy link
Member

llvmbot commented Jul 13, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Store metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of IRRewrite objects. Therefore, metadata about unresolved materializations can no longer be retrieved from UnresolvedMaterializationRewrite objects.

This commit also removes a pointer indirection and may slightly improve the performance of the existing driver.


Full diff: https://github.com/llvm/llvm-project/pull/148415.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-46)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 437dbcfea5288..42abc152981e6 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -789,38 +789,22 @@ enum MaterializationKind {
   Source
 };
 
-/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
-/// op. Unresolved materializations are erased at the end of the dialect
-/// conversion.
-class UnresolvedMaterializationRewrite : public OperationRewrite {
+/// Helper class that stores metadata about an unresolved materialization.
+class UnresolvedMaterializationInfo {
 public:
-  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                                   UnrealizedConversionCastOp op,
-                                   const TypeConverter *converter,
-                                   MaterializationKind kind, Type originalType,
-                                   ValueVector mappedValues);
-
-  static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::UnresolvedMaterialization;
-  }
-
-  void rollback() override;
-
-  UnrealizedConversionCastOp getOperation() const {
-    return cast<UnrealizedConversionCastOp>(op);
-  }
+  UnresolvedMaterializationInfo() = default;
+  UnresolvedMaterializationInfo(const TypeConverter *converter,
+                                MaterializationKind kind, Type originalType)
+      : converterAndKind(converter, kind), originalType(originalType) {}
 
-  /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
     return converterAndKind.getPointer();
   }
 
-  /// Return the kind of this materialization.
   MaterializationKind getMaterializationKind() const {
     return converterAndKind.getInt();
   }
 
-  /// Return the original type of the SSA value.
   Type getOriginalType() const { return originalType; }
 
 private:
@@ -832,7 +816,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// The original type of the SSA value. Only used for target
   /// materializations.
   Type originalType;
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations fold away or are replaced with
+/// source/target materializations at the end of the dialect conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                                   UnrealizedConversionCastOp op,
+                                   ValueVector mappedValues)
+      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+        mappedValues(std::move(mappedValues)) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::UnresolvedMaterialization;
+  }
+
+  void rollback() override;
 
+  UnrealizedConversionCastOp getOperation() const {
+    return cast<UnrealizedConversionCastOp>(op);
+  }
+
+private:
   /// The values in the conversion value mapping that are being replaced by the
   /// results of this unresolved materialization.
   ValueVector mappedValues;
@@ -1088,9 +1095,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// by the current pattern.
   SetVector<Block *> patternInsertedBlocks;
 
-  /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
-  /// to the corresponding rewrite objects.
-  DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+  /// A mapping for looking up metadata of unresolved materializations.
+  DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       unresolvedMaterializations;
 
   /// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1216,6 @@ void CreateOperationRewrite::rollback() {
   op->erase();
 }
 
-UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
-    ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
-    const TypeConverter *converter, MaterializationKind kind, Type originalType,
-    ValueVector mappedValues)
-    : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-      converterAndKind(converter, kind), originalType(originalType),
-      mappedValues(std::move(mappedValues)) {
-  assert((!originalType || kind == MaterializationKind::Target) &&
-         "original type is valid only for target materializations");
-  rewriterImpl.unresolvedMaterializations[op] = this;
-}
-
 void UnresolvedMaterializationRewrite::rollback() {
   if (!mappedValues.empty())
     rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1504,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     mapping.map(valuesToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
-  appendRewrite<UnresolvedMaterializationRewrite>(
-      convertOp, converter, kind, originalType, std::move(valuesToMap));
+  unresolvedMaterializations[convertOp] =
+      UnresolvedMaterializationInfo(converter, kind, originalType);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+                                                  std::move(valuesToMap));
   return convertOp.getResults();
 }
 
@@ -2679,21 +2675,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
 
 static LogicalResult
 legalizeUnresolvedMaterialization(RewriterBase &rewriter,
-                                  UnresolvedMaterializationRewrite *rewrite) {
-  UnrealizedConversionCastOp op = rewrite->getOperation();
+                                  UnrealizedConversionCastOp op,
+                                  const UnresolvedMaterializationInfo &info) {
   assert(!op.use_empty() &&
          "expected that dead materializations have already been DCE'd");
   Operation::operand_range inputOperands = op.getOperands();
 
   // Try to materialize the conversion.
-  if (const TypeConverter *converter = rewrite->getConverter()) {
+  if (const TypeConverter *converter = info.getConverter()) {
     rewriter.setInsertionPoint(op);
     SmallVector<Value> newMaterialization;
-    switch (rewrite->getMaterializationKind()) {
+    switch (info.getMaterializationKind()) {
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
-          rewrite->getOriginalType());
+          info.getOriginalType());
       break;
     case MaterializationKind::Source:
       assert(op->getNumResults() == 1 && "expected single result");
@@ -2768,7 +2764,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 
   // Gather all unresolved materializations.
   SmallVector<UnrealizedConversionCastOp> allCastOps;
-  const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
+  const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       &materializations = rewriterImpl.unresolvedMaterializations;
   for (auto it : materializations)
     allCastOps.push_back(it.first);
@@ -2785,7 +2781,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
     for (UnrealizedConversionCastOp castOp : remainingCastOps) {
       auto it = materializations.find(castOp);
       assert(it != materializations.end() && "inconsistent state");
-      if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+      if (failed(
+              legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
         return failure();
     }
   }

@matthias-springer matthias-springer force-pushed the users/matthias-springer/unresolved_mat_info branch from 7cedbb9 to 2ce59d7 Compare July 13, 2025 08:42
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer merged commit 8ee32c7 into main Jul 15, 2025
9 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/unresolved_mat_info branch July 15, 2025 08:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants