Skip to content

Commit 2ce59d7

Browse files
[mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately
1 parent 58c0bd1 commit 2ce59d7

File tree

1 file changed

+43
-43
lines changed

1 file changed

+43
-43
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -789,26 +789,13 @@ enum MaterializationKind {
789789
Source
790790
};
791791

792-
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
793-
/// op. Unresolved materializations are erased at the end of the dialect
794-
/// conversion.
795-
class UnresolvedMaterializationRewrite : public OperationRewrite {
792+
/// Helper class that stores metadata about an unresolved materialization.
793+
class UnresolvedMaterializationInfo {
796794
public:
797-
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
798-
UnrealizedConversionCastOp op,
799-
const TypeConverter *converter,
800-
MaterializationKind kind, Type originalType,
801-
ValueVector mappedValues);
802-
803-
static bool classof(const IRRewrite *rewrite) {
804-
return rewrite->getKind() == Kind::UnresolvedMaterialization;
805-
}
806-
807-
void rollback() override;
808-
809-
UnrealizedConversionCastOp getOperation() const {
810-
return cast<UnrealizedConversionCastOp>(op);
811-
}
795+
UnresolvedMaterializationInfo() = default;
796+
UnresolvedMaterializationInfo(const TypeConverter *converter,
797+
MaterializationKind kind, Type originalType)
798+
: converterAndKind(converter, kind), originalType(originalType) {}
812799

813800
/// Return the type converter of this materialization (which may be null).
814801
const TypeConverter *getConverter() const {
@@ -832,7 +819,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
832819
/// The original type of the SSA value. Only used for target
833820
/// materializations.
834821
Type originalType;
822+
};
823+
824+
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
825+
/// op. Unresolved materializations fold away or are replaced with
826+
/// source/target materializations at the end of the dialect conversion.
827+
class UnresolvedMaterializationRewrite : public OperationRewrite {
828+
public:
829+
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
830+
UnrealizedConversionCastOp op,
831+
ValueVector mappedValues)
832+
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
833+
mappedValues(std::move(mappedValues)) {}
834+
835+
static bool classof(const IRRewrite *rewrite) {
836+
return rewrite->getKind() == Kind::UnresolvedMaterialization;
837+
}
838+
839+
void rollback() override;
835840

841+
UnrealizedConversionCastOp getOperation() const {
842+
return cast<UnrealizedConversionCastOp>(op);
843+
}
844+
845+
private:
836846
/// The values in the conversion value mapping that are being replaced by the
837847
/// results of this unresolved materialization.
838848
ValueVector mappedValues;
@@ -1088,9 +1098,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10881098
/// by the current pattern.
10891099
SetVector<Block *> patternInsertedBlocks;
10901100

1091-
/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1092-
/// to the corresponding rewrite objects.
1093-
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
1101+
/// A mapping for looking up metadata of unresolved materializations.
1102+
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
10941103
unresolvedMaterializations;
10951104

10961105
/// The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1219,6 @@ void CreateOperationRewrite::rollback() {
12101219
op->erase();
12111220
}
12121221

1213-
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
1214-
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1215-
const TypeConverter *converter, MaterializationKind kind, Type originalType,
1216-
ValueVector mappedValues)
1217-
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1218-
converterAndKind(converter, kind), originalType(originalType),
1219-
mappedValues(std::move(mappedValues)) {
1220-
assert((!originalType || kind == MaterializationKind::Target) &&
1221-
"original type is valid only for target materializations");
1222-
rewriterImpl.unresolvedMaterializations[op] = this;
1223-
}
1224-
12251222
void UnresolvedMaterializationRewrite::rollback() {
12261223
if (!mappedValues.empty())
12271224
rewriterImpl.mapping.erase(mappedValues);
@@ -1510,8 +1507,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15101507
mapping.map(valuesToMap, convertOp.getResults());
15111508
if (castOp)
15121509
*castOp = convertOp;
1513-
appendRewrite<UnresolvedMaterializationRewrite>(
1514-
convertOp, converter, kind, originalType, std::move(valuesToMap));
1510+
unresolvedMaterializations[convertOp] =
1511+
UnresolvedMaterializationInfo(converter, kind, originalType);
1512+
appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1513+
std::move(valuesToMap));
15151514
return convertOp.getResults();
15161515
}
15171516

@@ -2679,21 +2678,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
26792678

26802679
static LogicalResult
26812680
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
2682-
UnresolvedMaterializationRewrite *rewrite) {
2683-
UnrealizedConversionCastOp op = rewrite->getOperation();
2681+
UnrealizedConversionCastOp op,
2682+
const UnresolvedMaterializationInfo &info) {
26842683
assert(!op.use_empty() &&
26852684
"expected that dead materializations have already been DCE'd");
26862685
Operation::operand_range inputOperands = op.getOperands();
26872686

26882687
// Try to materialize the conversion.
2689-
if (const TypeConverter *converter = rewrite->getConverter()) {
2688+
if (const TypeConverter *converter = info.getConverter()) {
26902689
rewriter.setInsertionPoint(op);
26912690
SmallVector<Value> newMaterialization;
2692-
switch (rewrite->getMaterializationKind()) {
2691+
switch (info.getMaterializationKind()) {
26932692
case MaterializationKind::Target:
26942693
newMaterialization = converter->materializeTargetConversion(
26952694
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
2696-
rewrite->getOriginalType());
2695+
info.getOriginalType());
26972696
break;
26982697
case MaterializationKind::Source:
26992698
assert(op->getNumResults() == 1 && "expected single result");
@@ -2768,7 +2767,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27682767

27692768
// Gather all unresolved materializations.
27702769
SmallVector<UnrealizedConversionCastOp> allCastOps;
2771-
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2770+
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
27722771
&materializations = rewriterImpl.unresolvedMaterializations;
27732772
for (auto it : materializations)
27742773
allCastOps.push_back(it.first);
@@ -2785,7 +2784,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27852784
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
27862785
auto it = materializations.find(castOp);
27872786
assert(it != materializations.end() && "inconsistent state");
2788-
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
2787+
if (failed(
2788+
legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
27892789
return failure();
27902790
}
27912791
}

0 commit comments

Comments
 (0)