@@ -789,26 +789,13 @@ enum MaterializationKind {
789
789
Source
790
790
};
791
791
792
- // / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
793
- // / op. Unresolved materializations are erased at the end of the dialect
794
- // / conversion.
795
- class UnresolvedMaterializationRewrite : public OperationRewrite {
792
+ // / Helper class that stores metadata about an unresolved materialization.
793
+ class UnresolvedMaterializationInfo {
796
794
public:
797
- UnresolvedMaterializationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
798
- UnrealizedConversionCastOp op,
799
- const TypeConverter *converter,
800
- MaterializationKind kind, Type originalType,
801
- ValueVector mappedValues);
802
-
803
- static bool classof (const IRRewrite *rewrite) {
804
- return rewrite->getKind () == Kind::UnresolvedMaterialization;
805
- }
806
-
807
- void rollback () override ;
808
-
809
- UnrealizedConversionCastOp getOperation () const {
810
- return cast<UnrealizedConversionCastOp>(op);
811
- }
795
+ UnresolvedMaterializationInfo () = default ;
796
+ UnresolvedMaterializationInfo (const TypeConverter *converter,
797
+ MaterializationKind kind, Type originalType)
798
+ : converterAndKind(converter, kind), originalType(originalType) {}
812
799
813
800
// / Return the type converter of this materialization (which may be null).
814
801
const TypeConverter *getConverter () const {
@@ -832,7 +819,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
832
819
// / The original type of the SSA value. Only used for target
833
820
// / materializations.
834
821
Type originalType;
822
+ };
823
+
824
+ // / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
825
+ // / op. Unresolved materializations fold away or are replaced with
826
+ // / source/target materializations at the end of the dialect conversion.
827
+ class UnresolvedMaterializationRewrite : public OperationRewrite {
828
+ public:
829
+ UnresolvedMaterializationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
830
+ UnrealizedConversionCastOp op,
831
+ ValueVector mappedValues)
832
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
833
+ mappedValues (std::move(mappedValues)) {}
834
+
835
+ static bool classof (const IRRewrite *rewrite) {
836
+ return rewrite->getKind () == Kind::UnresolvedMaterialization;
837
+ }
838
+
839
+ void rollback () override ;
835
840
841
+ UnrealizedConversionCastOp getOperation () const {
842
+ return cast<UnrealizedConversionCastOp>(op);
843
+ }
844
+
845
+ private:
836
846
// / The values in the conversion value mapping that are being replaced by the
837
847
// / results of this unresolved materialization.
838
848
ValueVector mappedValues;
@@ -1088,9 +1098,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
1088
1098
// / by the current pattern.
1089
1099
SetVector<Block *> patternInsertedBlocks;
1090
1100
1091
- // / A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1092
- // / to the corresponding rewrite objects.
1093
- DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
1101
+ // / A mapping for looking up metadata of unresolved materializations.
1102
+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
1094
1103
unresolvedMaterializations;
1095
1104
1096
1105
// / The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1219,6 @@ void CreateOperationRewrite::rollback() {
1210
1219
op->erase ();
1211
1220
}
1212
1221
1213
- UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite (
1214
- ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1215
- const TypeConverter *converter, MaterializationKind kind, Type originalType,
1216
- ValueVector mappedValues)
1217
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1218
- converterAndKind(converter, kind), originalType(originalType),
1219
- mappedValues(std::move(mappedValues)) {
1220
- assert ((!originalType || kind == MaterializationKind::Target) &&
1221
- " original type is valid only for target materializations" );
1222
- rewriterImpl.unresolvedMaterializations [op] = this ;
1223
- }
1224
-
1225
1222
void UnresolvedMaterializationRewrite::rollback () {
1226
1223
if (!mappedValues.empty ())
1227
1224
rewriterImpl.mapping .erase (mappedValues);
@@ -1510,8 +1507,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1510
1507
mapping.map (valuesToMap, convertOp.getResults ());
1511
1508
if (castOp)
1512
1509
*castOp = convertOp;
1513
- appendRewrite<UnresolvedMaterializationRewrite>(
1514
- convertOp, converter, kind, originalType, std::move (valuesToMap));
1510
+ unresolvedMaterializations[convertOp] =
1511
+ UnresolvedMaterializationInfo (converter, kind, originalType);
1512
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1513
+ std::move (valuesToMap));
1515
1514
return convertOp.getResults ();
1516
1515
}
1517
1516
@@ -2679,21 +2678,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2679
2678
2680
2679
static LogicalResult
2681
2680
legalizeUnresolvedMaterialization (RewriterBase &rewriter,
2682
- UnresolvedMaterializationRewrite *rewrite) {
2683
- UnrealizedConversionCastOp op = rewrite-> getOperation ();
2681
+ UnrealizedConversionCastOp op,
2682
+ const UnresolvedMaterializationInfo &info) {
2684
2683
assert (!op.use_empty () &&
2685
2684
" expected that dead materializations have already been DCE'd" );
2686
2685
Operation::operand_range inputOperands = op.getOperands ();
2687
2686
2688
2687
// Try to materialize the conversion.
2689
- if (const TypeConverter *converter = rewrite-> getConverter ()) {
2688
+ if (const TypeConverter *converter = info. getConverter ()) {
2690
2689
rewriter.setInsertionPoint (op);
2691
2690
SmallVector<Value> newMaterialization;
2692
- switch (rewrite-> getMaterializationKind ()) {
2691
+ switch (info. getMaterializationKind ()) {
2693
2692
case MaterializationKind::Target:
2694
2693
newMaterialization = converter->materializeTargetConversion (
2695
2694
rewriter, op->getLoc (), op.getResultTypes (), inputOperands,
2696
- rewrite-> getOriginalType ());
2695
+ info. getOriginalType ());
2697
2696
break ;
2698
2697
case MaterializationKind::Source:
2699
2698
assert (op->getNumResults () == 1 && " expected single result" );
@@ -2768,7 +2767,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2768
2767
2769
2768
// Gather all unresolved materializations.
2770
2769
SmallVector<UnrealizedConversionCastOp> allCastOps;
2771
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * >
2770
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo >
2772
2771
&materializations = rewriterImpl.unresolvedMaterializations ;
2773
2772
for (auto it : materializations)
2774
2773
allCastOps.push_back (it.first );
@@ -2785,7 +2784,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2785
2784
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2786
2785
auto it = materializations.find (castOp);
2787
2786
assert (it != materializations.end () && " inconsistent state" );
2788
- if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
2787
+ if (failed (
2788
+ legalizeUnresolvedMaterialization (rewriter, castOp, it->second )))
2789
2789
return failure ();
2790
2790
}
2791
2791
}
0 commit comments