Skip to content

Commit 7cedbb9

Browse files
[mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately
1 parent 58c0bd1 commit 7cedbb9

File tree

1 file changed

+43
-46
lines changed

1 file changed

+43
-46
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -789,38 +789,22 @@ enum MaterializationKind {
789789
Source
790790
};
791791

792-
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
793-
/// op. Unresolved materializations are erased at the end of the dialect
794-
/// conversion.
795-
class UnresolvedMaterializationRewrite : public OperationRewrite {
792+
/// Helper class that stores metadata about an unresolved materialization.
793+
class UnresolvedMaterializationInfo {
796794
public:
797-
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
798-
UnrealizedConversionCastOp op,
799-
const TypeConverter *converter,
800-
MaterializationKind kind, Type originalType,
801-
ValueVector mappedValues);
802-
803-
static bool classof(const IRRewrite *rewrite) {
804-
return rewrite->getKind() == Kind::UnresolvedMaterialization;
805-
}
806-
807-
void rollback() override;
808-
809-
UnrealizedConversionCastOp getOperation() const {
810-
return cast<UnrealizedConversionCastOp>(op);
811-
}
795+
UnresolvedMaterializationInfo() = default;
796+
UnresolvedMaterializationInfo(const TypeConverter *converter,
797+
MaterializationKind kind, Type originalType)
798+
: converterAndKind(converter, kind), originalType(originalType) {}
812799

813-
/// Return the type converter of this materialization (which may be null).
814800
const TypeConverter *getConverter() const {
815801
return converterAndKind.getPointer();
816802
}
817803

818-
/// Return the kind of this materialization.
819804
MaterializationKind getMaterializationKind() const {
820805
return converterAndKind.getInt();
821806
}
822807

823-
/// Return the original type of the SSA value.
824808
Type getOriginalType() const { return originalType; }
825809

826810
private:
@@ -832,7 +816,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
832816
/// The original type of the SSA value. Only used for target
833817
/// materializations.
834818
Type originalType;
819+
};
820+
821+
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
822+
/// op. Unresolved materializations fold away or are replaced with
823+
/// source/target materializations at the end of the dialect conversion.
824+
class UnresolvedMaterializationRewrite : public OperationRewrite {
825+
public:
826+
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
827+
UnrealizedConversionCastOp op,
828+
ValueVector mappedValues)
829+
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
830+
mappedValues(std::move(mappedValues)) {}
831+
832+
static bool classof(const IRRewrite *rewrite) {
833+
return rewrite->getKind() == Kind::UnresolvedMaterialization;
834+
}
835+
836+
void rollback() override;
835837

838+
UnrealizedConversionCastOp getOperation() const {
839+
return cast<UnrealizedConversionCastOp>(op);
840+
}
841+
842+
private:
836843
/// The values in the conversion value mapping that are being replaced by the
837844
/// results of this unresolved materialization.
838845
ValueVector mappedValues;
@@ -1088,9 +1095,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10881095
/// by the current pattern.
10891096
SetVector<Block *> patternInsertedBlocks;
10901097

1091-
/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1092-
/// to the corresponding rewrite objects.
1093-
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
1098+
/// A mapping for looking up metadata of unresolved materializations.
1099+
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
10941100
unresolvedMaterializations;
10951101

10961102
/// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1216,6 @@ void CreateOperationRewrite::rollback() {
12101216
op->erase();
12111217
}
12121218

1213-
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1214-
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1215-
const TypeConverter *converter, MaterializationKind kind, Type originalType,
1216-
ValueVector mappedValues)
1217-
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1218-
converterAndKind(converter, kind), originalType(originalType),
1219-
mappedValues(std::move(mappedValues)) {
1220-
assert((!originalType || kind == MaterializationKind::Target) &&
1221-
"original type is valid only for target materializations");
1222-
rewriterImpl.unresolvedMaterializations[op] = this;
1223-
}
1224-
12251219
void UnresolvedMaterializationRewrite::rollback() {
12261220
if (!mappedValues.empty())
12271221
rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1504,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15101504
mapping.map(valuesToMap, convertOp.getResults());
15111505
if (castOp)
15121506
*castOp = convertOp;
1513-
appendRewrite<UnresolvedMaterializationRewrite>(
1514-
convertOp, converter, kind, originalType, std::move(valuesToMap));
1507+
unresolvedMaterializations[convertOp] =
1508+
UnresolvedMaterializationInfo(converter, kind, originalType);
1509+
appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1510+
std::move(valuesToMap));
15151511
return convertOp.getResults();
15161512
}
15171513

@@ -2679,21 +2675,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
26792675

26802676
static LogicalResult
26812677
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2682-
UnresolvedMaterializationRewrite *rewrite) {
2683-
UnrealizedConversionCastOp op = rewrite->getOperation();
2678+
UnrealizedConversionCastOp op,
2679+
const UnresolvedMaterializationInfo &info) {
26842680
assert(!op.use_empty() &&
26852681
"expected that dead materializations have already been DCE'd");
26862682
Operation::operand_range inputOperands = op.getOperands();
26872683

26882684
// Try to materialize the conversion.
2689-
if (const TypeConverter *converter = rewrite->getConverter()) {
2685+
if (const TypeConverter *converter = info.getConverter()) {
26902686
rewriter.setInsertionPoint(op);
26912687
SmallVector<Value> newMaterialization;
2692-
switch (rewrite->getMaterializationKind()) {
2688+
switch (info.getMaterializationKind()) {
26932689
case MaterializationKind::Target:
26942690
newMaterialization = converter->materializeTargetConversion(
26952691
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2696-
rewrite->getOriginalType());
2692+
info.getOriginalType());
26972693
break;
26982694
case MaterializationKind::Source:
26992695
assert(op->getNumResults() == 1 && "expected single result");
@@ -2768,7 +2764,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27682764

27692765
// Gather all unresolved materializations.
27702766
SmallVector<UnrealizedConversionCastOp> allCastOps;
2771-
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2767+
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
27722768
&materializations = rewriterImpl.unresolvedMaterializations;
27732769
for (auto it : materializations)
27742770
allCastOps.push_back(it.first);
@@ -2785,7 +2781,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27852781
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
27862782
auto it = materializations.find(castOp);
27872783
assert(it != materializations.end() && "inconsistent state");
2788-
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2784+
if (failed(
2785+
legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
27892786
return failure();
27902787
}
27912788
}

0 commit comments

Comments
 (0)