Improved performance of the fp4tofp conversion #4299
Conversation
We may need to remove these lines
Can you explain the logic of the lookup table a little bit?
It does not resolve the [ZE]0x78000011 error.
All the possible values are enumerated in the constant vector. Instead of doing all these bit manipulations and conditionals, we just extract the elements by index: 0b0001 is 0.5 and is stored at index 1, 0b0010 is 1.0 and is stored at index 2, and so on.
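To make the mapping concrete, here is a minimal Python sketch of that table lookup (the values are taken from the bf16 constant vector quoted later in this thread; the function name and layout are my own illustration, not the PR code):

```python
# The 16 e2m1 (fp4) values, ordered so that the 4-bit code is the index.
FP4_TABLE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
             -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def fp4x2_to_floats(byte: int) -> tuple[float, float]:
    """Unpack one i8 holding two fp4 values by indexing into the table."""
    lo = byte & 0x0F         # low nibble  (idx1 = v & 15)
    hi = (byte >> 4) & 0x0F  # high nibble (idx2 = v >> 4)
    return FP4_TABLE[lo], FP4_TABLE[hi]

assert fp4x2_to_floats(0b0010_0001) == (0.5, 1.0)  # 0b0001 -> 0.5, 0b0010 -> 1.0
```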
eabc278 to a5901fc (force-push)
We still need the performance results to see how much perf gain we can get from this change.
a5901fc to cc16289 (force-push)
```c++
Value idx1 = b.and_(v, i8_15);
Value idx2 = b.lshr(v, i8_4);
results.push_back(b.extract_element(table, idx1));
results.push_back(b.extract_element(table, idx2));
```
I wonder what these extractions are translated into in the machine code. Can we keep the constant table in a register, one value per lane, and use a shuffle idx instruction instead of element extraction to get it? Would it be more efficient?
I think we need to make sure of two things before approving this change:
- On Xe2-3, Register-Indirect Register Addressing is used for indexing into the lookup table, which resides in registers, and IGC codegen does not spill to memory when extracting an element with a variable index.
- On Xe4 (or compared to CUDA), does the new architecture support the same Register-Indirect Register Addressing feature?
> I wonder what these extractions are translated into in the machine code. Can we keep the constant table in a register, one value per lane, and use a shuffle idx instruction instead of element extraction to get it? Would it be more efficient?
Here is the assembly for the element extraction:

```asm
//.declare V4328 (4339) rf=r size=32 type=w align=32 words (r1.0)
...
//.declare (8887) rf=r size=8 type=uq alias=V4328+0 align=32 words (r1.0)
//.declare (8888) rf=r size=8 type=uq alias=V4328+8 align=32 words (r1.1)
//.declare (8889) rf=r size=8 type=uq alias=V4328+16 align=32 words (r1.2)
//.declare (8890) rf=r size=8 type=uq alias=V4328+24 align=32 words (r1.3)
...
// Line 18
(W) and (1|M0) r1.16<1>:w r37.3<0;1,0>:w 15:w // ALU pipe: int; $176
// Line 26
(W) mov (1|M0) r1.32<2>:b r1.16<0;1,0>:w {I@1} // ALU pipe: int; $8149
(W) mov (1|M0) r4.3<1>:w r[a0.0]<0;1,0>:w // ALU pipe: int; $8148
(W) mul (1|M0) r1.16<1>:uw r1.32<0;1,0>:b 0x2:uw {I@2} // ALU pipe: int; $8150
(W) add (1|M0) a0.0<1>:uw r1.16<0;1,0>:uw 0x40:uw {A@1} // ALU pipe: int; src1 is addr of V4328(r1.0:w); $8151
// Line 18
(W) shr (1|M0) r1.16<1>:uw r37.4<0;1,0>:uw 4:w // ALU pipe: int; $175
// Line 26
(W) mov (1|M0) r1.32<2>:b r1.16<0;1,0>:w {I@1} // ALU pipe: int; $8153
(W) mov (1|M0) r4.4<1>:w r[a0.0]<0;1,0>:w // ALU pipe: int; $8152
(W) mul (1|M0) r1.16<1>:uw r1.32<0;1,0>:b 0x2:uw {I@2} // ALU pipe: int; $8154
(W) add (1|M0) a0.0<1>:uw r1.16<0;1,0>:uw 0x40:uw {A@1} // ALU pipe: int; src1 is addr of V4328(r1.0:w); $8155
```

cc16289 to 671b879 (force-push)
Here is a benchmark with a simple kernel that converts a tensor of 512 elements:

```python
import os
import timeit
import torch
import triton
import pathlib


def bench_fp4_to_bf16(device=triton.runtime.driver.active.get_active_torch_device()):
    ir = """
#blocked1 = #ttg.blocked<{sizePerThread = [512], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1024], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
module attributes {triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 16384 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
  tt.func public @fp4_to_bf16_kernel(%in_ptr: !tt.ptr<i8>, %out_ptr: !tt.ptr<bf16>) {
    %c512 = arith.constant 512 : i32
    %c1024 = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %in_off = arith.muli %pid, %c512 : i32
    %in_offsets = tt.splat %in_off : i32 -> tensor<512xi32, #blocked1>
    %in_range = tt.make_range {start = 0 : i32, end = 512 : i32} : tensor<512xi32, #blocked1>
    %in_ranges = arith.addi %in_range, %in_offsets : tensor<512xi32, #blocked1>
    %in_base = tt.splat %in_ptr : !tt.ptr<i8> -> tensor<512x!tt.ptr<i8>, #blocked1>
    %in_ptrs = tt.addptr %in_base, %in_ranges : tensor<512x!tt.ptr<i8>, #blocked1>, tensor<512xi32, #blocked1>
    %in_tensor = tt.load %in_ptrs : tensor<512x!tt.ptr<i8>, #blocked1>
    %bf16 = ttg.fp4_to_fp %in_tensor {axis = 0 : i32} : tensor<512xi8, #blocked1> -> tensor<1024xbf16, #blocked2>
    %out_off = arith.muli %pid, %c1024 : i32
    %out_offsets = tt.splat %out_off : i32 -> tensor<1024xi32, #blocked2>
    %out_range = tt.make_range {start = 0 : i32, end = 1024 : i32} : tensor<1024xi32, #blocked2>
    %out_ranges = arith.addi %out_range, %out_offsets : tensor<1024xi32, #blocked2>
    %out_base = tt.splat %out_ptr : !tt.ptr<bf16> -> tensor<1024x!tt.ptr<bf16>, #blocked2>
    %out_ptrs = tt.addptr %out_base, %out_ranges : tensor<1024x!tt.ptr<bf16>, #blocked2>, tensor<1024xi32, #blocked2>
    tt.store %out_ptrs, %bf16 : tensor<1024x!tt.ptr<bf16>, #blocked2>
    tt.return
  }
}
"""
    tmp_path: pathlib.Path = pathlib.Path(".")
    temp_file = tmp_path / "fp4_to_bf16_kernel.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))
    os.remove(temp_file)

    x = torch.randint(0, 127, (8192,), dtype=torch.int8, device=device)
    y = torch.zeros((16384,), dtype=torch.bfloat16, device=device)

    def run_kernel():
        kernel[(16, 1, 1)](x, y)
        torch.xpu.synchronize(device)

    run_kernel()
    print(x)
    print(y)

    time = timeit.timeit(run_kernel, number=100000)
    print(f"Time: {time}")


if __name__ == "__main__":
    bench_fp4_to_bf16()
```

The results on the main branch:

The new implementation:
671b879 to 41291a9 (force-push)

41291a9 to 799c3c3 (force-push)
Use a simple lookup table instead of explicit conversion. Fixes #4298
799c3c3 to 108a611 (force-push)
Use a simple lookup table instead of explicit conversion. Use bitwise operations on vectors to build the indices.

Fixes #4298

------------

This is another (and perhaps overcomplicated) version of #4299 that uses bitwise operations on vectors (not per element) to build the indices. The resulting llir is the following:

```MLIR
%idx_vec0 = lshr <16 x i8> %i8vec, splat (i8 4)
%idx_vec1 = and <16 x i8> %i8vec, splat (i8 15)
%idx0 = extractelement <16 x i8> %idx_vec0, i64 0
%bf0 = extractelement <16 x bfloat> <bfloat 0xR0000, bfloat 0xR3F00, bfloat 0xR3F80, bfloat 0xR3FC0, bfloat 0xR4000, bfloat 0xR4040, bfloat 0xR4080, bfloat 0xR40C0, bfloat 0xR8000, bfloat 0xRBF00, bfloat 0xRBF80, bfloat 0xRBFC0, bfloat 0xRC000, bfloat 0xRC040, bfloat 0xRC080, bfloat 0xRC0C0>, i8 %idx0
%idx1 = extractelement <16 x i8> %idx_vec1, i64 0
%bf1 = extractelement <16 x bfloat> <bfloat 0xR0000, bfloat 0xR3F00, bfloat 0xR3F80, bfloat 0xR3FC0, bfloat 0xR4000, bfloat 0xR4040, bfloat 0xR4080, bfloat 0xR40C0, bfloat 0xR8000, bfloat 0xRBF00, bfloat 0xRBF80, bfloat 0xRBFC0, bfloat 0xRC000, bfloat 0xRC040, bfloat 0xRC080, bfloat 0xRC0C0>, i8 %idx1
...
```

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Tiotto, Ettore <[email protected]>
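As a sanity check on those constants (my own sketch, not part of the PR): bf16 is the upper 16 bits of an f32, so the hex patterns above can be reproduced directly from the e2m1 values:

```python
import struct

def bf16_bits(x: float) -> int:
    """Return the bf16 bit pattern of x (top 16 bits of the f32 encoding)."""
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

e2m1_values = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
               -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
print([f"0x{bf16_bits(v):04X}" for v in e2m1_values])
# ['0x0000', '0x3F00', '0x3F80', '0x3FC0', '0x4000', '0x4040', '0x4080', '0x40C0',
#  '0x8000', '0xBF00', '0xBF80', '0xBFC0', '0xC000', '0xC040', '0xC080', '0xC0C0']
```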
Use a simple lookup table instead of explicit conversion.

Fixes #4298

This implementation creates 3 constants (the 16-element bf16 lookup table, the i8 mask `15`, and the i8 shift amount `4`) and 4 operations per pair of values (an `and`, an `lshr`, and two `extract_element` calls).

I've not compared the performance, but it seems more efficient than #4298.
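For illustration, here is a vectorized NumPy sketch of the same scheme (names and shapes are my own assumptions, not the PR's code): build both index vectors with bitwise operations, then gather from the table:

```python
import numpy as np

# 16-entry table of e2m1 values, in the same order as the constant vector above.
TABLE = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                  -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
                 dtype=np.float32)

def fp4_to_fp(packed: np.ndarray) -> np.ndarray:
    """packed: uint8 array with two fp4 values per byte -> float32 array, 2x length."""
    idx1 = packed & 0x0F  # low nibbles
    idx2 = packed >> 4    # high nibbles
    out = np.empty(packed.size * 2, dtype=np.float32)
    out[0::2] = TABLE[idx1]  # one gather per index vector instead of bit math
    out[1::2] = TABLE[idx2]
    return out

print(fp4_to_fp(np.array([0x21], dtype=np.uint8)))  # [0.5 1. ]
```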