Commit f763e81

Merge commit '7abb0be809e0b2c4fe734b1840750008bd590c7c'
2 parents: b483233 + 7abb0be

24 files changed: +753 −622 lines

Makefile

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@ test-distributed: all
 test-gluon: all
 	$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon
 	$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
+	$(PYTEST) --tb=short -n $(NUM_PROCS) -vs python/tutorials/gluon

 .PHONY: test-regression
 test-regression: all

include/triton/Analysis/Membar.h

Lines changed: 1 addition & 2 deletions

@@ -47,8 +47,7 @@ struct AllocationSlice {
 private:
   std::tuple<Interval<size_t>, const void *, llvm::ArrayRef<int64_t>>
   asTuple() const {
-    return std::make_tuple(allocationInterval, accessTy.getAsOpaquePointer(),
-                           subsliceOffsets);
+    return {allocationInterval, accessTy.getAsOpaquePointer(), subsliceOffsets};
   }
   // Offsets from subslice. Empty when offsets are unknown
   SmallVector<int64_t> subsliceOffsets;

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 0 additions & 4 deletions

@@ -16,10 +16,6 @@ inline bool isFp4Padded(Attribute encoding) {
   return mmaEnc && mmaEnc.getFp4Padded();
 }

-SmallVector<Value> translateTMAIndices(OpBuilder &builder, Location loc,
-                                       Attribute encoding,
-                                       SmallVector<Value> indices);
-
 gpu::CGAEncodingAttr updateCGALayoutForShape(gpu::CGAEncodingAttr cgaLayout,
                                              ArrayRef<int64_t> shape);

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 10 additions & 14 deletions

@@ -241,18 +241,14 @@ void createTMAAsyncLoad(scf::ForOp forOp, tt::DescriptorLoadOp loadOp,
                         Value alloc, Value insertIdx, Value extractIdx,
                         Value barrier, Operation *waitOp,
                         CoarseSchedule &schedule) {
-  return createTMAAsyncCopy(
-      forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, extractIdx, barrier,
-      waitOp, schedule,
-      [&](OpBuilderForStage &builder, Value tmaPtr, Value barrier, Value view,
-          Value pred) {
-        auto indices = ttng::translateTMAIndices(
-            builder, loadOp.getLoc(),
-            loadOp.getDesc().getType().getBlockType().getEncoding(),
-            loadOp.getIndices());
-        ttng::AsyncTMACopyGlobalToLocalOp::create(
-            builder, loadOp.getLoc(), tmaPtr, indices, barrier, view, pred);
-      });
+  return createTMAAsyncCopy(forOp, loadOp, loadOp.getDesc(), alloc, insertIdx,
+                            extractIdx, barrier, waitOp, schedule,
+                            [&](OpBuilderForStage &builder, Value desc,
+                                Value barrier, Value view, Value pred) {
+                              ttng::AsyncTMACopyGlobalToLocalOp::create(
+                                  builder, loadOp.getLoc(), desc,
+                                  loadOp.getIndices(), barrier, view, pred);
+                            });
 }

 void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp,

@@ -261,10 +257,10 @@ void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp,
                           CoarseSchedule &schedule) {
   return createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc,
                             insertIdx, extractIdx, barrier, waitOp, schedule,
-                            [&](OpBuilderForStage &builder, Value tmaPtr,
+                            [&](OpBuilderForStage &builder, Value desc,
                                 Value barrier, Value view, Value pred) {
                               ttng::AsyncTMAGatherOp::create(
-                                  builder, gatherOp.getLoc(), tmaPtr,
+                                  builder, gatherOp.getLoc(), desc,
                                   gatherOp.getXOffsets(), gatherOp.getYOffset(),
                                   barrier, view, pred);
                             });

lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp

Lines changed: 0 additions & 8 deletions

@@ -60,17 +60,9 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store,
   ttng::FenceAsyncSharedOp::create(builder, loc, false);
   auto desc = store.desc;
   if (auto storeOp = dyn_cast<tt::DescriptorStoreOp>(store.op)) {
-    auto indices = ttng::translateTMAIndices(
-        builder, storeOp.getLoc(),
-        storeOp.getDesc().getType().getBlockType().getEncoding(),
-        storeOp.getIndices());
     ttng::AsyncTMACopyLocalToGlobalOp::create(builder, loc, desc,
                                               storeOp.getIndices(), alloc);
   } else if (auto reduceOp = dyn_cast<tt::DescriptorReduceOp>(store.op)) {
-    auto indices = ttng::translateTMAIndices(
-        builder, reduceOp.getLoc(),
-        reduceOp.getDesc().getType().getBlockType().getEncoding(),
-        reduceOp.getIndices());
     ttng::AsyncTMAReduceOp::create(builder, loc, reduceOp.getKind(), desc,
                                    reduceOp.getIndices(), alloc);
   } else {

lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp

Lines changed: 12 additions & 20 deletions

@@ -68,13 +68,11 @@ class TMALoadLowering : public OpRewritePattern<DescriptorLoadOp> {
   LogicalResult matchAndRewrite(DescriptorLoadOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc,
+    auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc,
                           Value pred) {
-      auto indices = translateTMAIndices(
-          rewriter, op.getLoc(),
-          op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
       triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create(
-          rewriter, op.getLoc(), tmaPtr, indices, barrierAlloc, alloc, pred);
+          rewriter, op.getLoc(), desc, op.getIndices(), barrierAlloc, alloc,
+          pred);
     };
     lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter);
     return success();

@@ -86,10 +84,10 @@ struct TMAGatherLowering : public OpRewritePattern<DescriptorGatherOp> {

   LogicalResult matchAndRewrite(DescriptorGatherOp op,
                                 PatternRewriter &rewriter) const override {
-    auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc,
+    auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc,
                           Value pred) {
       triton::nvidia_gpu::AsyncTMAGatherOp::create(
-          rewriter, op.getLoc(), tmaPtr, op.getXOffsets(), op.getYOffset(),
+          rewriter, op.getLoc(), desc, op.getXOffsets(), op.getYOffset(),
           barrierAlloc, alloc, pred);
     };
     lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter);

@@ -122,12 +120,9 @@ struct TMAStoreLowering : public OpRewritePattern<DescriptorStoreOp> {

   LogicalResult matchAndRewrite(DescriptorStoreOp op,
                                 PatternRewriter &rewriter) const override {
-    auto createStore = [&](Value tmaPtr, Value alloc) {
-      auto indices = translateTMAIndices(
-          rewriter, op.getLoc(),
-          op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
+    auto createStore = [&](Value desc, Value alloc) {
       triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp::create(
-          rewriter, op.getLoc(), tmaPtr, indices, alloc);
+          rewriter, op.getLoc(), desc, op.getIndices(), alloc);
     };
     lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);
     return success();

@@ -139,12 +134,9 @@ struct TMAReduceLowering : public OpRewritePattern<DescriptorReduceOp> {

   LogicalResult matchAndRewrite(DescriptorReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    auto createStore = [&](Value tmaPtr, Value alloc) {
-      auto indices = translateTMAIndices(
-          rewriter, op.getLoc(),
-          op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
+    auto createStore = [&](Value desc, Value alloc) {
       triton::nvidia_gpu::AsyncTMAReduceOp::create(
-          rewriter, op.getLoc(), op.getKind(), tmaPtr, indices, alloc);
+          rewriter, op.getLoc(), op.getKind(), desc, op.getIndices(), alloc);
     };
     lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);
     return success();

@@ -156,9 +148,9 @@ struct TMAScatterLowering : public OpRewritePattern<DescriptorScatterOp> {

   LogicalResult matchAndRewrite(DescriptorScatterOp op,
                                 PatternRewriter &rewriter) const override {
-    auto createStore = [&](Value tmaPtr, Value alloc) {
-      triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(),
-                                                    tmaPtr, op.getXOffsets(),
+    auto createStore = [&](Value desc, Value alloc) {
+      triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), desc,
+                                                    op.getXOffsets(),
                                                     op.getYOffset(), alloc);
     };
     lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);

lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp

Lines changed: 0 additions & 10 deletions

@@ -7,16 +7,6 @@ namespace ttg = mlir::triton::gpu;

 namespace mlir::triton::nvidia_gpu {

-SmallVector<Value> translateTMAIndices(OpBuilder &builder, Location loc,
-                                       Attribute encoding,
-                                       SmallVector<Value> indices) {
-  if (isFp4Padded(encoding)) {
-    auto two = arith::ConstantIntOp::create(builder, loc, 2, 32);
-    indices.back() = arith::MulIOp::create(builder, loc, indices.back(), two);
-  }
-  return indices;
-}
-
 ttg::CGAEncodingAttr updateCGALayoutForShape(ttg::CGAEncodingAttr cgaLayout,
                                              ArrayRef<int64_t> shape) {
   auto rank = shape.size();
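
Note on the removed helper: with an fp4-padded MMA shared-memory encoding, each byte expands to two fp4 elements, so the innermost TMA index had to be scaled by 2 before issuing the copy. The call sites above now pass op.getIndices() straight through, so this translation is presumably folded into the lowering of the async TMA ops themselves (in one of the other changed files, not shown here). A minimal Python sketch of the removed behavior, for illustration only:

    # Illustrative mirror of the removed C++ helper, not the Triton API.
    def translate_tma_indices(indices, fp4_padded):
        # fp4-padded layouts hold two fp4 values per byte, so the
        # innermost box index is doubled before the TMA copy.
        if fp4_padded:
            indices = indices[:-1] + [indices[-1] * 2]
        return indices

    assert translate_tma_indices([3, 5], fp4_padded=True) == [3, 10]
    assert translate_tma_indices([3, 5], fp4_padded=False) == [3, 5]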

python/test/unit/language/test_matmul.py

Lines changed: 7 additions & 1 deletion

@@ -380,6 +380,8 @@ def test_mxfp(BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device)
             pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4 or above")
         if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64):
             pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
+        if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256:
+            pytest.skip("Config requires too much shared memory")

     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)

@@ -1204,6 +1206,8 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
             pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
         if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):
             pytest.skip("Float4 without scale is tested in test_block_scale_fp4")
+        if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256:
+            pytest.skip("Config requires too much shared memory")
     elif is_xpu():
         if not is_xpu_cri() and not (WITH_A_SCALE and WITH_B_SCALE):
             pytest.xfail("None scale has not been tested on XPU backend")

@@ -1367,7 +1371,7 @@ def batched_mxfp_matmul( #


 @pytest.mark.parametrize("BATCH_SIZE, BLOCK_BATCH_SIZE", [(1, 1), (16, 1), (16, 4)])
-@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128)])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 64, 128)])
 @pytest.mark.parametrize("NUM_STAGES", [1, 2 if is_hip() else 3])
 @pytest.mark.parametrize("NUM_WARPS", [4, 8])
 @pytest.mark.parametrize("nonKDim", ([0, 16, 32] if (is_hip_cdna() or is_hip_gfx1250()) else [0]))

@@ -1383,6 +1387,8 @@ def test_batched_mxfp(BATCH_SIZE, BLOCK_BATCH_SIZE, BLOCK_M, BLOCK_N, BLOCK_K, N
             pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4 and above")
         if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64):
             pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
+        if is_hip_cdna4() and NUM_STAGES > 1 and max(BLOCK_M, BLOCK_N) > 64:
+            pytest.skip("Config requires too much shared memory")
     elif is_xpu():
         if BLOCK_BATCH_SIZE == 4 and BLOCK_N == 64:
             pytest.skip("FIXME: #5762")

python/triton/knobs.py

Lines changed: 1 addition & 1 deletion

@@ -579,8 +579,8 @@ class amd_knobs(base_knobs):
     # We use strs so that we can have a default value based on other runtime info
     use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
     use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE")
+    use_async_copy: env_opt_bool = env_opt_bool("TRITON_HIP_USE_ASYNC_COPY")

-    use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY")
     scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS")
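
Note on the knob change: moving use_async_copy from env_bool to env_opt_bool makes it tri-state, in line with the comment above it: an unset variable now means "let the backend pick a default based on other runtime info" rather than "false". A minimal sketch of that pattern; the helper below is hypothetical, not the actual knobs machinery:

    import os
    from typing import Optional

    def read_opt_bool(name: str) -> Optional[bool]:
        # Tri-state: None when unset, so the caller can choose a
        # hardware-dependent default instead of assuming False.
        val = os.environ.get(name)
        if val is None:
            return None
        return val.strip().lower() in ("1", "true", "on", "yes")

    use_async_copy = read_opt_bool("TRITON_HIP_USE_ASYNC_COPY")
    if use_async_copy is None:
        use_async_copy = False  # placeholder; the real default may depend on the target GPU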

python/triton_kernels/triton_kernels/matmul_details/opt_flags.py

Lines changed: 1 addition & 2 deletions

@@ -211,13 +211,12 @@ def make_default_opt_flags_amd(
         num_stages = 1

     # specific configs for F16 x MXFP4 on CDNA4
-    # Note that these configs will exceed LDS usage with async copy enabled
    if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.b_mx_scale is not None:
         split_k = 1
         if m <= 1024:
             target_kernel_kwargs["waves_per_eu"] = 3
             block_n = 128
-            block_k = 256
+            block_k = 128
             num_warps = 4
         else:
             target_kernel_kwargs["waves_per_eu"] = 0
