Skip to content

Commit 9c5773b

Browse files
committed
Add documentation for BF16ToGpu pass.
1 parent 2556ae7 commit 9c5773b

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ repos:
66
hooks:
77
- id: end-of-file-fixer
88
- id: trailing-whitespace
9+
args: [--markdown-linebreak-ext=md]
910
- repo: https://github.com/pocc/pre-commit-hooks
1011
rev: v1.1.1
1112
hooks:

docs/Transforms/BF16ToGpu.md

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# BF16ToGPU Pass
2+
3+
4+
BF16ToGPU pass transforms gpu dialect with bf16 dtype to a form that can be lowered to
5+
SPIR-V and executed on Intel GPUs.
6+
Since bf16 is not a type directly supported by SPIR-V, bf16 values are passed as i16 values to
7+
spirv functions. This requires changing gpu.launch_func (caller) and gpu.func (callee).
8+
9+
* Caller side changes are as follows:
10+
MLIR does not support direct casting of non-scalar types, so caller-side casting from
11+
bf16 type to i16 type is done indirectly by using memref.view.
12+
gpu.alloc of bf16 type is replaced with gpu.alloc of i8 type. Then one view of bf16 and
13+
another view of i16 is created. Host side code that used original gpu.alloc is updated to
14+
use the bf16 view. And the i16 view is passed to gpu.launch_func, replacing old arguments.
15+
16+
* Callee side changes are as follows:
17+
gpu.func's bf16 arguments are replaced with i16 arguments. bf16 type usage inside the gpu kernel body
18+
can be divided into two different types. First, operations that interpret bits according to
19+
bf16 specification. Second, operations that only care about the size of the types.
20+
First type of operations include most of Arithmetic dialect and Math dialect operations
21+
except for bitcast. The operation's bf16 operand and result types are replaced with f32.
22+
bf16 operands are replaced with a sequence of an i16-to-bf16 bitcast followed by a bf16-to-f32 extf.
23+
bf16 results are replaced with a sequence of f32 to bf16 truncf followed by bf16 to i16 bitcast.
24+
The resulting code, in summary, has two parts.
25+
1) bf16 operations emulated by f32 operations with the help of widening and truncating.
26+
2) bitcast operations between bf16 and i16 type.
27+
28+
Second type of operations get bf16 operands and results type replaced with i16.
29+
30+
31+
## Example
32+
33+
```
34+
func.func @test(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
35+
%c20 = arith.constant 20 : index
36+
%c10 = arith.constant 10 : index
37+
%c1 = arith.constant 1 : index
38+
%memref = gpu.alloc host_shared () : memref<10x20xbf16>
39+
memref.copy %arg1, %memref : memref<10x20xbf16> to memref<10x20xbf16>
40+
%memref_0 = gpu.alloc host_shared () : memref<10x20xbf16>
41+
memref.copy %arg0, %memref_0 : memref<10x20xbf16> to memref<10x20xbf16>
42+
%memref_1 = gpu.alloc host_shared () : memref<10x20xbf16>
43+
gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<10x20xbf16>, %memref : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>)
44+
gpu.dealloc %memref_0 : memref<10x20xbf16>
45+
gpu.dealloc %memref : memref<10x20xbf16>
46+
return %memref_1 : memref<10x20xbf16>
47+
}
48+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>} {
49+
gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 10, 20, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
50+
%0 = gpu.block_id x
51+
%1 = gpu.block_id y
52+
%2 = memref.load %arg0[%0, %1] : memref<10x20xbf16>
53+
%3 = memref.load %arg1[%0, %1] : memref<10x20xbf16>
54+
%4 = arith.addf %2, %3 : bf16
55+
memref.store %4, %arg2[%0, %1] : memref<10x20xbf16>
56+
gpu.return
57+
}
58+
}
59+
```
60+
61+
The Pass will change the IR to:
62+
63+
```
64+
func.func @test(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
65+
%c20 = arith.constant 20 : index
66+
%c10 = arith.constant 10 : index
67+
%c1 = arith.constant 1 : index
68+
%c0 = arith.constant 0 : index
69+
%memref = gpu.alloc host_shared () : memref<400xi8>
70+
%view = memref.view %memref[%c0][] : memref<400xi8> to memref<10x20xbf16>
71+
%view_0 = memref.view %memref[%c0][] : memref<400xi8> to memref<10x20xi16>
72+
memref.copy %arg1, %view : memref<10x20xbf16> to memref<10x20xbf16>
73+
%c0_1 = arith.constant 0 : index
74+
%memref_2 = gpu.alloc host_shared () : memref<400xi8>
75+
%view_3 = memref.view %memref_2[%c0_1][] : memref<400xi8> to memref<10x20xbf16>
76+
%view_4 = memref.view %memref_2[%c0_1][] : memref<400xi8> to memref<10x20xi16>
77+
memref.copy %arg0, %view_3 : memref<10x20xbf16> to memref<10x20xbf16>
78+
%c0_5 = arith.constant 0 : index
79+
%memref_6 = gpu.alloc host_shared () : memref<400xi8>
80+
%view_7 = memref.view %memref_6[%c0_5][] : memref<400xi8> to memref<10x20xbf16>
81+
%view_8 = memref.view %memref_6[%c0_5][] : memref<400xi8> to memref<10x20xi16>
82+
gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%view_4 : memref<10x20xi16>, %view_0 : memref<10x20xi16>, %view_8 : memref<10x20xi16>)
83+
gpu.dealloc %memref_2 : memref<400xi8>
84+
gpu.dealloc %memref : memref<400xi8>
85+
return %view_7 : memref<10x20xbf16>
86+
}
87+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>} {
88+
gpu.func @test_kernel(%arg0: memref<10x20xi16>, %arg1: memref<10x20xi16>, %arg2: memref<10x20xi16>) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 10, 20, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
89+
%0 = gpu.block_id x
90+
%1 = gpu.block_id y
91+
%2 = memref.load %arg0[%0, %1] : memref<10x20xi16>
92+
%3 = memref.load %arg1[%0, %1] : memref<10x20xi16>
93+
%4 = arith.bitcast %2 : i16 to bf16
94+
%5 = arith.extf %4 : bf16 to f32
95+
%6 = arith.bitcast %3 : i16 to bf16
96+
%7 = arith.extf %6 : bf16 to f32
97+
%8 = arith.addf %5, %7 : f32
98+
%9 = arith.truncf %8 : f32 to bf16
99+
%10 = arith.bitcast %9 : bf16 to i16
100+
memref.store %10, %arg2[%0, %1] : memref<10x20xi16>
101+
gpu.return
102+
}
103+
}
104+
```
105+
106+
107+
As shown in the example above, the gpu.allocs of bf16 type are replaced with gpu.allocs of i8 type, over which a bf16 view (used by host-side code) and an i16 view (passed to gpu.launch_func) are created with memref.view.
108+
109+
## Limitations of this pass.
110+
111+
112+
1. This pass only covers static shapes.
113+
2. This pass only supports scalar operations in the kernel body.

0 commit comments

Comments
 (0)