Skip to content

Commit 2a9e284

Browse files
anmyachev and Mogball authored
[LAYOUTS] Don't hoist into ifs outside of loops (#5801) (#4332)
git cherry-pick 032fa41
---------
Signed-off-by: Anatoly Myachev <[email protected]>
Co-authored-by: Jeff Niu <[email protected]>
2 parents cdc090a + 278a05b commit 2a9e284

File tree

3 files changed

+57
-125
lines changed

3 files changed

+57
-125
lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,27 +1458,19 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
14581458
// These are the conditional edges above which conversions should be hoisted.
14591459
// The value represents the `scf.if` op result and the operand represents the
14601460
// edge into one of the branches.
1461-
SmallVector<std::pair<OpResult, OpOperand *>> hoistAbove;
1461+
SmallVector<std::pair<Value, OpOperand *>> hoistAbove;
14621462

14631463
// The list of `scf.if` op results in the slice that are not rematerializable.
14641464
// Hoisting is terminated at these values.
14651465
SmallVector<OpResult> terminals;
14661466

1467-
// Process the whole backward slice in subslices that stop at each condtional.
1468-
// This is so we can apply more specific rules about when to hoist.
1469-
struct Subslice {
1470-
OpResult v;
1471-
OpOperand *edge;
1472-
SetVector<Value> slice;
1473-
DenseMap<Value, Attribute> layout;
1474-
};
1475-
SmallVector<Subslice> subslices;
1476-
1477-
// Check a value in the subslice.
1478-
auto visitValue = [&](OpResult v) {
1467+
// This loop recurses through the subslices of the backwards dependencies, so
1468+
// re-query the size of `slice`.
1469+
for (unsigned i = 0; i != slice.size(); ++i) {
1470+
Value v = slice[i];
14791471
auto ifOp = v.getDefiningOp<scf::IfOp>();
14801472
if (!ifOp)
1481-
return;
1473+
continue;
14821474

14831475
Attribute rootLayout = layout.at(v);
14841476
unsigned resIdx = cast<OpResult>(v).getResultNumber();
@@ -1507,66 +1499,41 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15071499
slice.insert(elseSlice.begin(), elseSlice.end());
15081500
layout.insert(thenLayout.begin(), thenLayout.end());
15091501
layout.insert(elseLayout.begin(), elseLayout.end());
1510-
return;
1502+
continue;
15111503
}
15121504

15131505
// If propagation across both edges failed, then this conditional
15141506
// terminates backwards rematerialization.
15151507
if (failed(thenResult) && failed(elseResult)) {
1516-
terminals.push_back(v);
1517-
return;
1508+
terminals.push_back(cast<OpResult>(v));
1509+
continue;
1510+
}
1511+
1512+
// Only hoist into conditionals inside loops. The assumption is that an if
1513+
// inside a loop executes fewer than the total number of loop iterations,
1514+
// making this hoist profitable.
1515+
if (!isa<scf::ForOp>(ifOp->getParentOp())) {
1516+
terminals.push_back(cast<OpResult>(v));
1517+
continue;
15181518
}
15191519

15201520
// The layout conversion can be rematerialized along one edge but not the
15211521
// other. We can hoist the conversion into the other branch. Push this
15221522
// onto `hoistAbove` and merge the rematerializable branch's slice.
15231523
if (succeeded(thenResult)) {
1524-
subslices.push_back(
1525-
{v, &elseRes, std::move(thenSlice), std::move(thenLayout)});
1524+
hoistAbove.emplace_back(v, &elseRes);
1525+
slice.insert(thenSlice.begin(), thenSlice.end());
1526+
layout.insert(thenLayout.begin(), thenLayout.end());
15261527
} else {
1527-
subslices.push_back(
1528-
{v, &thenRes, std::move(elseSlice), std::move(elseLayout)});
1529-
}
1530-
};
1531-
1532-
// Process the whole slice in subslices.
1533-
unsigned i = 0;
1534-
bool isLoneHoist = false;
1535-
do {
1536-
// Visit values in the current subslice.
1537-
for (; i != slice.size(); ++i) {
1538-
if (auto v = dyn_cast<OpResult>(slice[i]))
1539-
visitValue(v);
1540-
}
1541-
// Check the next chunk of subslices. When a condtional is marked as being
1542-
// valid to be hoisted across, we have to recurse on a new subslice rooted
1543-
// at the corresopnding yield operand.
1544-
//
1545-
// Hoist across condtionals when:
1546-
// 1. The conditional is directly inside a loop.
1547-
// 2. The whole slice contains only one conditional.
1548-
for (auto &[v, edge, subslice, layouts] : subslices) {
1549-
bool oneHoist = false;
1550-
if (isa<LoopLikeOpInterface>(v.getDefiningOp()->getParentOp()) ||
1551-
(oneHoist = subslices.size() == 1 && hoistAbove.empty())) {
1552-
isLoneHoist |= oneHoist;
1553-
hoistAbove.push_back({v, edge});
1554-
// Recurse on the subslice.
1555-
slice.insert(subslice.begin(), subslice.end());
1556-
layout.insert(layouts.begin(), layouts.end());
1557-
} else {
1558-
terminals.push_back(v);
1559-
}
1528+
hoistAbove.emplace_back(v, &thenRes);
1529+
slice.insert(elseSlice.begin(), elseSlice.end());
1530+
layout.insert(elseLayout.begin(), elseLayout.end());
15601531
}
1561-
subslices.clear();
1562-
} while (i != slice.size());
1532+
}
15631533

15641534
// Exit early if there is nothing to do.
15651535
if (hoistAbove.empty())
15661536
return;
1567-
// Check if this is a lone hoist. There should be no other terminals.
1568-
if (isLoneHoist && !terminals.empty())
1569-
return;
15701537

15711538
// Rematerialize failed hoists right before the conditional, and hoist those
15721539
// that succeeded into the branch and then rewrite the slice.

test/TritonGPU/combine.mlir

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2874,27 +2874,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
28742874
// CHECK-LABEL: @hoist_one_conditional
28752875
tt.func @hoist_one_conditional(
28762876
%arg0: i1,
2877-
%arg1: tensor<128x32x!tt.ptr<f32>, #blocked>,
2878-
%arg2: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>,
2879-
%arg3: tensor<128x128xf32, #mma>
2880-
) -> tensor<128x128xf32, #mma> {
2877+
%arg1: tensor<128x32x!tt.ptr<f32>, #blocked>
2878+
) -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> {
28812879

2882-
// CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op
2880+
// CHECK: arith.constant {{.*}} tensor<128x32xf32, #blocked>
28832881
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked>
28842882
// CHECK: scf.if
28852883
%0 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) {
28862884
// CHECK-NEXT: [[RES:%.*]] = tt.load
28872885
%3 = tt.load %arg1 : tensor<128x32x!tt.ptr<f32>, #blocked>
2888-
// CHECK-NEXT: ttg.convert_layout [[RES]]
2889-
// CHECK-NEXT: yield
2886+
// CHECK-NEXT: yield [[RES]]
28902887
scf.yield %3 : tensor<128x32xf32, #blocked>
28912888
} else {
28922889
scf.yield %cst : tensor<128x32xf32, #blocked>
28932890
}
2894-
// CHECK-NOT: ttg.convert_layout
2895-
%1 = ttg.convert_layout %0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
2896-
%2 = tt.dot %1, %arg2, %arg3 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
2897-
tt.return %2 : tensor<128x128xf32, #mma>
2891+
// CHECK: [[TRUNC:%.*]] = arith.truncf
2892+
%1 = arith.truncf %0 : tensor<128x32xf32, #blocked> to tensor<128x32xf16, #blocked>
2893+
// CHECK-NEXT: convert_layout [[TRUNC]]
2894+
%2 = ttg.convert_layout %1 : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
2895+
tt.return %2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
28982896
}
28992897

29002898
// CHECK-LABEL: @hoist_multiple_conditional

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 24 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,27 +1458,19 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
14581458
// These are the conditional edges above which conversions should be hoisted.
14591459
// The value represents the `scf.if` op result and the operand represents the
14601460
// edge into one of the branches.
1461-
SmallVector<std::pair<OpResult, OpOperand *>> hoistAbove;
1461+
SmallVector<std::pair<Value, OpOperand *>> hoistAbove;
14621462

14631463
// The list of `scf.if` op results in the slice that are not rematerializable.
14641464
// Hoisting is terminated at these values.
14651465
SmallVector<OpResult> terminals;
14661466

1467-
// Process the whole backward slice in subslices that stop at each condtional.
1468-
// This is so we can apply more specific rules about when to hoist.
1469-
struct Subslice {
1470-
OpResult v;
1471-
OpOperand *edge;
1472-
SetVector<Value> slice;
1473-
DenseMap<Value, Attribute> layout;
1474-
};
1475-
SmallVector<Subslice> subslices;
1476-
1477-
// Check a value in the subslice.
1478-
auto visitValue = [&](OpResult v) {
1467+
// This loop recurses through the subslices of the backwards dependencies, so
1468+
// re-query the size of `slice`.
1469+
for (unsigned i = 0; i != slice.size(); ++i) {
1470+
Value v = slice[i];
14791471
auto ifOp = v.getDefiningOp<scf::IfOp>();
14801472
if (!ifOp)
1481-
return;
1473+
continue;
14821474

14831475
Attribute rootLayout = layout.at(v);
14841476
unsigned resIdx = cast<OpResult>(v).getResultNumber();
@@ -1507,66 +1499,41 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15071499
slice.insert(elseSlice.begin(), elseSlice.end());
15081500
layout.insert(thenLayout.begin(), thenLayout.end());
15091501
layout.insert(elseLayout.begin(), elseLayout.end());
1510-
return;
1502+
continue;
15111503
}
15121504

15131505
// If propagation across both edges failed, then this conditional
15141506
// terminates backwards rematerialization.
15151507
if (failed(thenResult) && failed(elseResult)) {
1516-
terminals.push_back(v);
1517-
return;
1508+
terminals.push_back(cast<OpResult>(v));
1509+
continue;
1510+
}
1511+
1512+
// Only hoist into conditionals inside loops. The assumption is that an if
1513+
// inside a loop executes fewer than the total number of loop iterations,
1514+
// making this hoist profitable.
1515+
if (!isa<scf::ForOp>(ifOp->getParentOp())) {
1516+
terminals.push_back(cast<OpResult>(v));
1517+
continue;
15181518
}
15191519

15201520
// The layout conversion can be rematerialized along one edge but not the
15211521
// other. We can hoist the conversion into the other branch. Push this
15221522
// onto `hoistAbove` and merge the rematerializable branch's slice.
15231523
if (succeeded(thenResult)) {
1524-
subslices.push_back(
1525-
{v, &elseRes, std::move(thenSlice), std::move(thenLayout)});
1524+
hoistAbove.emplace_back(v, &elseRes);
1525+
slice.insert(thenSlice.begin(), thenSlice.end());
1526+
layout.insert(thenLayout.begin(), thenLayout.end());
15261527
} else {
1527-
subslices.push_back(
1528-
{v, &thenRes, std::move(elseSlice), std::move(elseLayout)});
1529-
}
1530-
};
1531-
1532-
// Process the whole slice in subslices.
1533-
unsigned i = 0;
1534-
bool isLoneHoist = false;
1535-
do {
1536-
// Visit values in the current subslice.
1537-
for (; i != slice.size(); ++i) {
1538-
if (auto v = dyn_cast<OpResult>(slice[i]))
1539-
visitValue(v);
1540-
}
1541-
// Check the next chunk of subslices. When a condtional is marked as being
1542-
// valid to be hoisted across, we have to recurse on a new subslice rooted
1543-
// at the corresopnding yield operand.
1544-
//
1545-
// Hoist across condtionals when:
1546-
// 1. The conditional is directly inside a loop.
1547-
// 2. The whole slice contains only one conditional.
1548-
for (auto &[v, edge, subslice, layouts] : subslices) {
1549-
bool oneHoist = false;
1550-
if (isa<LoopLikeOpInterface>(v.getDefiningOp()->getParentOp()) ||
1551-
(oneHoist = subslices.size() == 1 && hoistAbove.empty())) {
1552-
isLoneHoist |= oneHoist;
1553-
hoistAbove.push_back({v, edge});
1554-
// Recurse on the subslice.
1555-
slice.insert(subslice.begin(), subslice.end());
1556-
layout.insert(layouts.begin(), layouts.end());
1557-
} else {
1558-
terminals.push_back(v);
1559-
}
1528+
hoistAbove.emplace_back(v, &thenRes);
1529+
slice.insert(elseSlice.begin(), elseSlice.end());
1530+
layout.insert(elseLayout.begin(), elseLayout.end());
15601531
}
1561-
subslices.clear();
1562-
} while (i != slice.size());
1532+
}
15631533

15641534
// Exit early if there is nothing to do.
15651535
if (hoistAbove.empty())
15661536
return;
1567-
// Check if this is a lone hoist. There should be no other terminals.
1568-
if (isLoneHoist && !terminals.empty())
1569-
return;
15701537

15711538
// Rematerialize failed hoists right before the conditional, and hoist those
15721539
// that succeeded into the branch and then rewrite the slice.

0 commit comments

Comments
 (0)