Commit 6a97b56
[MLIR][AMDGPU] Redirect transfer read to masked load lowering (#146705)
This PR reworks #131803. Instead of applying the optimization to the transfer_read op, which is too high level, it redirects the pre-existing pattern onto the maskedload op. This simplifies the implementation of the lowering pattern and also allows moving the usage of the pass into a target-dependent pipeline.
Signed-off-by: jerryyin <[email protected]>
1 parent dfcef35 · commit 6a97b56
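Since the commit message mentions moving the usage of the pass into a target-dependent pipeline, here is a minimal sketch of how a downstream pipeline might schedule it. This is illustrative only and not part of the commit; it assumes the TableGen-generated createAmdgpuMaskedloadToLoadPass() constructor, and buildExampleAmdgpuPipeline() is a hypothetical hook name.

// Illustrative sketch only (not from this commit): scheduling the new pass
// in a target-dependent pipeline. createAmdgpuMaskedloadToLoadPass() is
// assumed to be generated from the Passes.td entry; the pipeline hook below
// is hypothetical.
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

static void buildExampleAmdgpuPipeline(mlir::OpPassManager &pm) {
  // Only meaningful on AMDGPU targets where fat raw buffer loads exist,
  // hence a target-dependent pipeline rather than a generic one.
  pm.addPass(mlir::amdgpu::createAmdgpuMaskedloadToLoadPass());
}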

File tree: 6 files changed (+199, −297 lines)


mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h

Lines changed: 3 additions & 3 deletions

@@ -23,7 +23,7 @@ namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
 #define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
-#define GEN_PASS_DECL_AMDGPUTRANSFERREADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -35,8 +35,8 @@ void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
 void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);
 
-void populateAmdgpuTransferReadToLoadPatterns(RewritePatternSet &patterns,
-                                              PatternBenefit benefit = 1);
+void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
+                                            PatternBenefit benefit = 1);
 
 } // namespace amdgpu
 } // namespace mlir
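As a usage note (not part of the diff), a downstream transform can also pull in just the rewrite pattern through the populate entry point declared above. A minimal sketch, mirroring what the standalone pass added later in this commit does; the helper name is illustrative:

// Sketch: applying the maskedload-to-load pattern from custom code.
// Assumes only the populateAmdgpuMaskedloadToLoadPatterns declaration shown
// above plus the greedy rewrite driver; runMaskedloadToLoad is hypothetical.
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runMaskedloadToLoad(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(patterns);
  // Greedily rewrites every matching vector.maskedload under `root`.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}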

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions

@@ -51,8 +51,8 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
   ];
 }
 
-def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
-  let summary = "Lower the operations from the vector transfer_read to vector load";
+def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
+  let summary = "Lower the operations from the vector maskedload to vector load";
   let description = [{
     This pass creates a transfer read op lowering optimization. The lowering
     will produce a conditional check at runtime. If within bounds, a vector

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
   ResolveStridedMetadata.cpp
-  TransferReadToLoad.cpp
+  MaskedloadToLoad.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp (new file)

Lines changed: 167 additions & 0 deletions

@@ -0,0 +1,167 @@
+//===- MaskedloadToLoad.cpp - Lowers maskedload to load -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace mlir::amdgpu
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// This pattern supports lowering of: `vector.maskedload` to `vector.load`
+/// and `arith.select` if the memref is in buffer address space.
+static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
+                                           vector::MaskedLoadOp maskedOp) {
+  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+  if (!memRefType)
+    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+
+  Attribute addrSpace = memRefType.getMemorySpace();
+  if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
+    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+
+  if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
+      amdgpu::AddressSpace::FatRawBuffer)
+    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+
+  return success();
+}
+
+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+                                           vector::MaskedLoadOp maskedOp) {
+  VectorType vectorType = maskedOp.getVectorType();
+  Value load = builder.create<vector::LoadOp>(
+      loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
+  Value res = builder.create<arith::SelectOp>(
+      loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
+  return res;
+}
+
+static constexpr char kMaskedloadNeedsMask[] =
+    "amdgpu.buffer_maskedload_needs_mask";
+
+namespace {
+
+struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
+                                PatternRewriter &rewriter) const override {
+    if (maskedOp->hasAttr(kMaskedloadNeedsMask))
+      return failure();
+
+    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+      return failure();
+    }
+
+    Location loc = maskedOp.getLoc();
+    Value src = maskedOp.getBase();
+
+    VectorType vectorType = maskedOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    SmallVector<OpFoldResult> indices = maskedOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    OpFoldResult linearizedIndices;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);
+
+    // delta = bufferSize - linearizedOffset
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    Value totalSize = getValueOrCreateConstantIndexOp(
+        rewriter, loc, linearizedInfo.linearizedSize);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) check if delta < vectorSize
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+
+    // 2) check if (delta % elements_per_word != 0)
+    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+        loc, llvm::divideCeil(32, elementBitWidth));
+    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne,
+        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    // We take the fallback of the maskedload default lowering only if it is
+    // both out-of-bounds and not word aligned. The fallback ensures correct
+    // results when loading at the boundary of the buffer, since buffer load
+    // returns inconsistent zeros for the whole word when the boundary is crossed.
+    Value ifCondition =
+        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*maskedOp.getOperation());
+      read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(maskedOp, ifOp);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+}
+
+struct AmdgpuMaskedloadToLoadPass final
+    : amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuMaskedloadToLoadPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
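To make the guard condition above concrete, here is a small host-side sketch (not part of the commit) of the same arithmetic the pattern emits as IR; the function name and plain-integer types are illustrative only.

// Host-side illustration of the runtime predicate built by MaskedLoadLowering:
// the masked-load fallback is kept only when the remaining elements are both
// fewer than the vector length and not word aligned.
#include <cstdint>

static bool needsMaskedLoadFallback(int64_t bufferSize, int64_t linearIndex,
                                    int64_t vectorSize,
                                    int64_t elementBitWidth) {
  const int64_t delta = bufferSize - linearIndex;        // elements left in buffer
  const int64_t elementsPerWord =
      (32 + elementBitWidth - 1) / elementBitWidth;      // divideCeil(32, bitwidth)
  const bool isOutOfBounds = delta < vectorSize;         // 1) delta < vectorSize
  const bool isNotWordAligned =
      (delta % elementsPerWord) != 0;                    // 2) delta % elementsPerWord != 0
  return isOutOfBounds && isNotWordAligned;
}

// Example: vector<4xf16> gives vectorSize = 4, elementBitWidth = 16, so
// elementsPerWord = 2. With 3 elements left, delta = 3 < 4 and 3 % 2 == 1,
// so the original vector.maskedload is kept (tagged with
// amdgpu.buffer_maskedload_needs_mask so the pattern does not revisit it).
// With 4 or more elements left, or with delta = 2 (word aligned), the rewrite
// emits the faster vector.load + arith.select path instead.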
