@@ -789,38 +789,22 @@ enum MaterializationKind {
789
789
Source
790
790
};
791
791
792
- // / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
793
- // / op. Unresolved materializations are erased at the end of the dialect
794
- // / conversion.
795
- class UnresolvedMaterializationRewrite : public OperationRewrite {
792
+ // / Helper class that stores metadata about an unresolved materialization.
793
+ class UnresolvedMaterializationInfo {
796
794
public:
797
- UnresolvedMaterializationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
798
- UnrealizedConversionCastOp op,
799
- const TypeConverter *converter,
800
- MaterializationKind kind, Type originalType,
801
- ValueVector mappedValues);
802
-
803
- static bool classof (const IRRewrite *rewrite) {
804
- return rewrite->getKind () == Kind::UnresolvedMaterialization;
805
- }
806
-
807
- void rollback () override ;
808
-
809
- UnrealizedConversionCastOp getOperation () const {
810
- return cast<UnrealizedConversionCastOp>(op);
811
- }
795
+ UnresolvedMaterializationInfo () = default ;
796
+ UnresolvedMaterializationInfo (const TypeConverter *converter,
797
+ MaterializationKind kind, Type originalType)
798
+ : converterAndKind(converter, kind), originalType(originalType) {}
812
799
813
- // / Return the type converter of this materialization (which may be null).
814
800
const TypeConverter *getConverter () const {
815
801
return converterAndKind.getPointer ();
816
802
}
817
803
818
- // / Return the kind of this materialization.
819
804
MaterializationKind getMaterializationKind () const {
820
805
return converterAndKind.getInt ();
821
806
}
822
807
823
- // / Return the original type of the SSA value.
824
808
Type getOriginalType () const { return originalType; }
825
809
826
810
private:
@@ -832,7 +816,30 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
832
816
// / The original type of the SSA value. Only used for target
833
817
// / materializations.
834
818
Type originalType;
819
+ };
820
+
821
+ // / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
822
+ // / op. Unresolved materializations fold away or are replaced with
823
+ // / source/target materializations at the end of the dialect conversion.
824
+ class UnresolvedMaterializationRewrite : public OperationRewrite {
825
+ public:
826
+ UnresolvedMaterializationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
827
+ UnrealizedConversionCastOp op,
828
+ ValueVector mappedValues)
829
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
830
+ mappedValues (std::move(mappedValues)) {}
831
+
832
+ static bool classof (const IRRewrite *rewrite) {
833
+ return rewrite->getKind () == Kind::UnresolvedMaterialization;
834
+ }
835
+
836
+ void rollback () override ;
835
837
838
+ UnrealizedConversionCastOp getOperation () const {
839
+ return cast<UnrealizedConversionCastOp>(op);
840
+ }
841
+
842
+ private:
836
843
// / The values in the conversion value mapping that are being replaced by the
837
844
// / results of this unresolved materialization.
838
845
ValueVector mappedValues;
@@ -1088,9 +1095,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
1088
1095
// / by the current pattern.
1089
1096
SetVector<Block *> patternInsertedBlocks;
1090
1097
1091
- // / A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1092
- // / to the corresponding rewrite objects.
1093
- DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
1098
+ // / A mapping for looking up metadata of unresolved materializations.
1099
+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
1094
1100
unresolvedMaterializations;
1095
1101
1096
1102
// / The current type converter, or nullptr if no type converter is currently
@@ -1210,18 +1216,6 @@ void CreateOperationRewrite::rollback() {
1210
1216
op->erase ();
1211
1217
}
1212
1218
1213
- UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite (
1214
- ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1215
- const TypeConverter *converter, MaterializationKind kind, Type originalType,
1216
- ValueVector mappedValues)
1217
- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1218
- converterAndKind(converter, kind), originalType(originalType),
1219
- mappedValues(std::move(mappedValues)) {
1220
- assert ((!originalType || kind == MaterializationKind::Target) &&
1221
- " original type is valid only for target materializations" );
1222
- rewriterImpl.unresolvedMaterializations [op] = this ;
1223
- }
1224
-
1225
1219
void UnresolvedMaterializationRewrite::rollback () {
1226
1220
if (!mappedValues.empty ())
1227
1221
rewriterImpl.mapping .erase (mappedValues);
@@ -1510,8 +1504,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1510
1504
mapping.map (valuesToMap, convertOp.getResults ());
1511
1505
if (castOp)
1512
1506
*castOp = convertOp;
1513
- appendRewrite<UnresolvedMaterializationRewrite>(
1514
- convertOp, converter, kind, originalType, std::move (valuesToMap));
1507
+ unresolvedMaterializations[convertOp] =
1508
+ UnresolvedMaterializationInfo (converter, kind, originalType);
1509
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
1510
+ std::move (valuesToMap));
1515
1511
return convertOp.getResults ();
1516
1512
}
1517
1513
@@ -2679,21 +2675,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2679
2675
2680
2676
static LogicalResult
2681
2677
legalizeUnresolvedMaterialization (RewriterBase &rewriter,
2682
- UnresolvedMaterializationRewrite *rewrite) {
2683
- UnrealizedConversionCastOp op = rewrite-> getOperation ();
2678
+ UnrealizedConversionCastOp op,
2679
+ const UnresolvedMaterializationInfo &info) {
2684
2680
assert (!op.use_empty () &&
2685
2681
" expected that dead materializations have already been DCE'd" );
2686
2682
Operation::operand_range inputOperands = op.getOperands ();
2687
2683
2688
2684
// Try to materialize the conversion.
2689
- if (const TypeConverter *converter = rewrite-> getConverter ()) {
2685
+ if (const TypeConverter *converter = info. getConverter ()) {
2690
2686
rewriter.setInsertionPoint (op);
2691
2687
SmallVector<Value> newMaterialization;
2692
- switch (rewrite-> getMaterializationKind ()) {
2688
+ switch (info. getMaterializationKind ()) {
2693
2689
case MaterializationKind::Target:
2694
2690
newMaterialization = converter->materializeTargetConversion (
2695
2691
rewriter, op->getLoc (), op.getResultTypes (), inputOperands,
2696
- rewrite-> getOriginalType ());
2692
+ info. getOriginalType ());
2697
2693
break ;
2698
2694
case MaterializationKind::Source:
2699
2695
assert (op->getNumResults () == 1 && " expected single result" );
@@ -2768,7 +2764,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2768
2764
2769
2765
// Gather all unresolved materializations.
2770
2766
SmallVector<UnrealizedConversionCastOp> allCastOps;
2771
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite * >
2767
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo >
2772
2768
&materializations = rewriterImpl.unresolvedMaterializations ;
2773
2769
for (auto it : materializations)
2774
2770
allCastOps.push_back (it.first );
@@ -2785,7 +2781,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2785
2781
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2786
2782
auto it = materializations.find (castOp);
2787
2783
assert (it != materializations.end () && " inconsistent state" );
2788
- if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
2784
+ if (failed (
2785
+ legalizeUnresolvedMaterialization (rewriter, castOp, it->second )))
2789
2786
return failure ();
2790
2787
}
2791
2788
}
0 commit comments