Skip to content

Commit 75175e7

Browse files
authored
[flang][cuda] Inline this_thread_block() calls (#146144)
1 parent 9a93de5 commit 75175e7

File tree

4 files changed

+95
-0
lines changed

4 files changed

+95
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ struct IntrinsicLibrary {
443443
fir::ExtendedValue genTranspose(mlir::Type,
444444
llvm::ArrayRef<fir::ExtendedValue>);
445445
mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
446+
mlir::Value genThisThreadBlock(mlir::Type, llvm::ArrayRef<mlir::Value>);
446447
mlir::Value genThisWarp(mlir::Type, llvm::ArrayRef<mlir::Value>);
447448
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
448449
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,7 @@ static constexpr IntrinsicHandler handlers[]{
933933
/*isElemental=*/false},
934934
{"tand", &I::genTand},
935935
{"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
936+
{"this_thread_block", &I::genThisThreadBlock, {}, /*isElemental=*/false},
936937
{"this_warp", &I::genThisWarp, {}, /*isElemental=*/false},
937938
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
938939
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
@@ -8195,6 +8196,60 @@ mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
81958196
return res;
81968197
}
81978198

8199+
// THIS_THREAD_BLOCK
8200+
mlir::Value
8201+
IntrinsicLibrary::genThisThreadBlock(mlir::Type resultType,
8202+
llvm::ArrayRef<mlir::Value> args) {
8203+
assert(args.size() == 0);
8204+
auto recTy = mlir::cast<fir::RecordType>(resultType);
8205+
assert(recTy && "RecordType expepected");
8206+
mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
8207+
mlir::Type i32Ty = builder.getI32Type();
8208+
8209+
// this_thread_block%size = blockDim.z * blockDim.y * blockDim.x;
8210+
mlir::Value blockDimX = builder.create<mlir::NVVM::BlockDimXOp>(loc, i32Ty);
8211+
mlir::Value blockDimY = builder.create<mlir::NVVM::BlockDimYOp>(loc, i32Ty);
8212+
mlir::Value blockDimZ = builder.create<mlir::NVVM::BlockDimZOp>(loc, i32Ty);
8213+
mlir::Value size =
8214+
builder.create<mlir::arith::MulIOp>(loc, blockDimZ, blockDimY);
8215+
size = builder.create<mlir::arith::MulIOp>(loc, size, blockDimX);
8216+
8217+
// this_thread_block%rank = ((threadIdx.z * blockDim.y) * blockDim.x) +
8218+
// (threadIdx.y * blockDim.x) + threadIdx.x + 1;
8219+
mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
8220+
mlir::Value threadIdY = builder.create<mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
8221+
mlir::Value threadIdZ = builder.create<mlir::NVVM::ThreadIdZOp>(loc, i32Ty);
8222+
mlir::Value r1 =
8223+
builder.create<mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
8224+
mlir::Value r2 = builder.create<mlir::arith::MulIOp>(loc, r1, blockDimX);
8225+
mlir::Value r3 =
8226+
builder.create<mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
8227+
mlir::Value r2r3 = builder.create<mlir::arith::AddIOp>(loc, r2, r3);
8228+
mlir::Value rank = builder.create<mlir::arith::AddIOp>(loc, r2r3, threadIdX);
8229+
mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
8230+
rank = builder.create<mlir::arith::AddIOp>(loc, rank, one);
8231+
8232+
auto sizeFieldName = recTy.getTypeList()[1].first;
8233+
mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
8234+
mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
8235+
mlir::Value sizeFieldIndex = builder.create<fir::FieldIndexOp>(
8236+
loc, fieldIndexType, sizeFieldName, recTy,
8237+
/*typeParams=*/mlir::ValueRange{});
8238+
mlir::Value sizeCoord = builder.create<fir::CoordinateOp>(
8239+
loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
8240+
builder.create<fir::StoreOp>(loc, size, sizeCoord);
8241+
8242+
auto rankFieldName = recTy.getTypeList()[2].first;
8243+
mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
8244+
mlir::Value rankFieldIndex = builder.create<fir::FieldIndexOp>(
8245+
loc, fieldIndexType, rankFieldName, recTy,
8246+
/*typeParams=*/mlir::ValueRange{});
8247+
mlir::Value rankCoord = builder.create<fir::CoordinateOp>(
8248+
loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
8249+
builder.create<fir::StoreOp>(loc, rank, rankCoord);
8250+
return res;
8251+
}
8252+
81988253
// THIS_WARP
81998254
mlir::Value IntrinsicLibrary::genThisWarp(mlir::Type resultType,
82008255
llvm::ArrayRef<mlir::Value> args) {

flang/module/cooperative_groups.f90

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,26 @@ module cooperative_groups
2626
integer(4) :: rank
2727
end type coalesced_group
2828

29+
! Derived type used as the result type of this_thread_block().
type :: thread_group
30+
! Device pointer component; private, so not accessible outside this module.
type(c_devptr), private :: handle
31+
! Set by the inline lowering to blockDim%z * blockDim%y * blockDim%x.
integer(4) :: size
32+
! Set by the inline lowering to the linearized thread index in the block,
! plus one.
integer(4) :: rank
33+
end type thread_group
34+
2935
! Explicit interface for the device function this_grid(), which takes no
! arguments and returns a value of type grid_group.
interface
3036
attributes(device) function this_grid()
3137
! Make the host-scope types (grid_group) visible inside the interface body.
import
3238
type(grid_group) :: this_grid
3339
end function
3440
end interface
3541

42+
! Explicit interface for the device function this_thread_block(), which takes
! no arguments and returns a value of type thread_group.
interface
43+
attributes(device) function this_thread_block()
44+
! Make the host-scope types (thread_group) visible inside the interface body.
import
45+
type(thread_group) :: this_thread_block
46+
end function
47+
end interface
48+
3649
interface this_warp
3750
attributes(device) function this_warp()
3851
import

flang/test/Lower/CUDA/cuda-cooperative.cuf

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,30 @@ end subroutine
7070
! CHECK: %[[AND:.*]] = arith.andi %[[THREAD_ID]], %[[C31]] : i32
7171
! CHECK: %[[RANK:.*]] = arith.addi %[[AND]], %[[C1]] : i32
7272
! CHECK: %[[RANK_COORD:.*]] = fir.coordinate_of %{{.*}}, rank : (!fir.ref<!fir.type<_QMcooperative_groupsTcoalesced_group{_QMcooperative_groupsTcoalesced_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
73+
74+
! Kernel that assigns the result of this_thread_block() to a local
! thread_group variable; the CHECK lines that follow verify the operations
! generated for that call.
attributes(grid_global) subroutine t1()
75+
use cooperative_groups
76+
type(thread_group) :: gg
77+
gg = this_thread_block()
78+
end subroutine
79+
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
80+
! CHECK: %[[THREAD_GROUP:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
81+
! CHECK: %[[NTID_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
82+
! CHECK: %[[NTID_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32
83+
! CHECK: %[[NTID_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32
84+
! CHECK: %[[SIZE_ZY:.*]] = arith.muli %[[NTID_Z]], %[[NTID_Y]] : i32
85+
! CHECK: %[[SIZE:.*]] = arith.muli %[[SIZE_ZY]], %[[NTID_X]] : i32
86+
! CHECK: %[[TID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32
87+
! CHECK: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
88+
! CHECK: %[[TID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32
89+
! CHECK: %[[RANK_ZY:.*]] = arith.muli %[[TID_Z]], %[[NTID_Y]] : i32
90+
! CHECK: %[[RANK_ZYX:.*]] = arith.muli %[[RANK_ZY]], %[[NTID_X]] : i32
91+
! CHECK: %[[RANK_YX:.*]] = arith.muli %[[TID_Y]], %[[NTID_X]] : i32
92+
! CHECK: %[[RANK_SUM1:.*]] = arith.addi %[[RANK_ZYX]], %[[RANK_YX]] : i32
93+
! CHECK: %[[RANK_SUM2:.*]] = arith.addi %[[RANK_SUM1]], %[[TID_X]] : i32
94+
! CHECK: %[[C1:.*]] = arith.constant 1 : i32
95+
! CHECK: %[[RANK:.*]] = arith.addi %[[RANK_SUM2]], %[[C1]] : i32
96+
! CHECK: %[[SIZE_COORD:.*]] = fir.coordinate_of %[[THREAD_GROUP]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
97+
! CHECK: fir.store %[[SIZE]] to %[[SIZE_COORD]] : !fir.ref<i32>
98+
! CHECK: %[[RANK_COORD:.*]] = fir.coordinate_of %[[THREAD_GROUP]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
7399
! CHECK: fir.store %[[RANK]] to %[[RANK_COORD]] : !fir.ref<i32>

0 commit comments

Comments
 (0)