Skip to content

Commit 8ee32c7

Browse files
[mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately (#148415)
Store metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of `IRRewrite` objects. Therefore, metadata about unresolved materializations can no longer be retrieved from `UnresolvedMaterializationRewrite` objects. This commit also removes a pointer indirection and may slightly improve the performance of the existing driver.
1 parent cbdc185 commit 8ee32c7

File tree

1 file changed

+43
-43
lines changed

1 file changed

+43
-43
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -789,26 +789,13 @@ enum MaterializationKind {
789789
Source
790790
};
791791

792-
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
793-
/// op. Unresolved materializations are erased at the end of the dialect
794-
/// conversion.
795-
class UnresolvedMaterializationRewrite : public OperationRewrite {
792+
/// Helper class that stores metadata about an unresolved materialization.
793+
class UnresolvedMaterializationInfo {
796794
public:
797-
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
798-
UnrealizedConversionCastOp op,
799-
const TypeConverter *converter,
800-
MaterializationKind kind, Type originalType,
801-
ValueVector mappedValues);
802-
803-
static bool classof(const IRRewrite *rewrite) {
804-
return rewrite->getKind() == Kind::UnresolvedMaterialization;
805-
}
806-
807-
void rollback() override;
808-
809-
UnrealizedConversionCastOp getOperation() const {
810-
return cast<UnrealizedConversionCastOp>(op);
811-
}
795+
UnresolvedMaterializationInfo() = default;
796+
UnresolvedMaterializationInfo(const TypeConverter *converter,
797+
MaterializationKind kind, Type originalType)
798+
: converterAndKind(converter, kind), originalType(originalType) {}
812799

813800
/// Return the type converter of this materialization (which may be null).
814801
const TypeConverter *getConverter() const {
@@ -832,7 +819,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
832819
/// The original type of the SSA value. Only used for target
833820
/// materializations.
834821
Type originalType;
822+
};
823+
824+
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
825+
/// op. Unresolved materializations fold away or are replaced with
826+
/// source/target materializations at the end of the dialect conversion.
827+
class UnresolvedMaterializationRewrite : public OperationRewrite {
828+
public:
829+
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
830+
UnrealizedConversionCastOp op,
831+
ValueVector mappedValues)
832+
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
833+
mappedValues(std::move(mappedValues)) {}
834+
835+
static bool classof(const IRRewrite *rewrite) {
836+
return rewrite->getKind() == Kind::UnresolvedMaterialization;
837+
}
838+
839+
void rollback() override;
835840

841+
UnrealizedConversionCastOp getOperation() const {
842+
return cast<UnrealizedConversionCastOp>(op);
843+
}
844+
845+
private:
836846
/// The values in the conversion value mapping that are being replaced by the
837847
/// results of this unresolved materialization.
838848
ValueVector mappedValues;
@@ -1088,9 +1098,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10881098
/// by the current pattern.
10891099
SetVector<Block *> patternInsertedBlocks;
10901100

1091-
/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1092-
/// to the corresponding rewrite objects.
1093-
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
1101+
/// A mapping for looking up metadata of unresolved materializations.
1102+
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
10941103
unresolvedMaterializations;
10951104

10961105
/// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1219,6 @@ void CreateOperationRewrite::rollback() {
12101219
op->erase();
12111220
}
12121221

1213-
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1214-
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1215-
const TypeConverter *converter, MaterializationKind kind, Type originalType,
1216-
ValueVector mappedValues)
1217-
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1218-
converterAndKind(converter, kind), originalType(originalType),
1219-
mappedValues(std::move(mappedValues)) {
1220-
assert((!originalType || kind == MaterializationKind::Target) &&
1221-
"original type is valid only for target materializations");
1222-
rewriterImpl.unresolvedMaterializations[op] = this;
1223-
}
1224-
12251222
void UnresolvedMaterializationRewrite::rollback() {
12261223
if (!mappedValues.empty())
12271224
rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1507,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15101507
mapping.map(valuesToMap, convertOp.getResults());
15111508
if (castOp)
15121509
*castOp = convertOp;
1513-
appendRewrite<UnresolvedMaterializationRewrite>(
1514-
convertOp, converter, kind, originalType, std::move(valuesToMap));
1510+
unresolvedMaterializations[convertOp] =
1511+
UnresolvedMaterializationInfo(converter, kind, originalType);
1512+
appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1513+
std::move(valuesToMap));
15151514
return convertOp.getResults();
15161515
}
15171516

@@ -2678,21 +2677,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
26782677

26792678
static LogicalResult
26802679
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2681-
UnresolvedMaterializationRewrite *rewrite) {
2682-
UnrealizedConversionCastOp op = rewrite->getOperation();
2680+
UnrealizedConversionCastOp op,
2681+
const UnresolvedMaterializationInfo &info) {
26832682
assert(!op.use_empty() &&
26842683
"expected that dead materializations have already been DCE'd");
26852684
Operation::operand_range inputOperands = op.getOperands();
26862685

26872686
// Try to materialize the conversion.
2688-
if (const TypeConverter *converter = rewrite->getConverter()) {
2687+
if (const TypeConverter *converter = info.getConverter()) {
26892688
rewriter.setInsertionPoint(op);
26902689
SmallVector<Value> newMaterialization;
2691-
switch (rewrite->getMaterializationKind()) {
2690+
switch (info.getMaterializationKind()) {
26922691
case MaterializationKind::Target:
26932692
newMaterialization = converter->materializeTargetConversion(
26942693
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2695-
rewrite->getOriginalType());
2694+
info.getOriginalType());
26962695
break;
26972696
case MaterializationKind::Source:
26982697
assert(op->getNumResults() == 1 && "expected single result");
@@ -2767,7 +2766,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27672766

27682767
// Gather all unresolved materializations.
27692768
SmallVector<UnrealizedConversionCastOp> allCastOps;
2770-
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2769+
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
27712770
&materializations = rewriterImpl.unresolvedMaterializations;
27722771
for (auto it : materializations)
27732772
allCastOps.push_back(it.first);
@@ -2784,7 +2783,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27842783
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
27852784
auto it = materializations.find(castOp);
27862785
assert(it != materializations.end() && "inconsistent state");
2787-
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2786+
if (failed(
2787+
legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
27882788
return failure();
27892789
}
27902790
}

0 commit comments

Comments
 (0)