Improved performance of the fp4tofp conversion #4299
Conversation
We may need to remove these lines
Can you explain the logic of the lookup table a little bit?
It does not resolve the [ZE]0x78000011 error.
All the possible values are enumerated in the constant vector. Instead of doing all these bit manipulations and conditionals, we just extract the elements by index: 0b0001 is 0.5 and stored at index 1, 0b0010 is 1.0 and stored at index 2, and so on.
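The idea described above can be sketched in plain Python. This is an illustrative model, not the PR's actual C++ lowering; the table holds the standard fp4 (e2m1) values, and the index layout (low nibble decoded first) matches the snippet discussed later in this thread.

```python
# All 16 possible fp4 (e2m1) values, stored so that the 4-bit encoding is
# the index: 0b0001 -> 0.5 at index 1, 0b0010 -> 1.0 at index 2, etc.
# (Illustrative sketch, not the actual C++ lowering from the PR.)
FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def fp4x2_to_float(byte):
    """Decode one packed byte (two fp4 values) by table lookup."""
    idx1 = byte & 0x0F   # low nibble
    idx2 = byte >> 4     # high nibble
    return FP4_E2M1[idx1], FP4_E2M1[idx2]
```

For example, `fp4x2_to_float(0x21)` returns `(0.5, 1.0)`: the low nibble `0b0001` selects index 1 and the high nibble `0b0010` selects index 2, with no per-value bit manipulation or branching.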
We still need the performance result to see how much perf gain we can get from this change.
```cpp
Value idx1 = b.and_(v, i8_15);
Value idx2 = b.lshr(v, i8_4);
results.push_back(b.extract_element(table, idx1));
results.push_back(b.extract_element(table, idx2));
```
I wonder what these extractions are translated into in the machine code. Can we keep the constant table in a register, one value per lane, and use a shuffle idx instruction instead of element extraction to get it? Would it be more efficient?
I think we need to make sure of two things before approving this change:
- On Xe2-3, Register-Indirect Register Addressing is used for indexing values in the lookup table resident in registers, and no spilling to memory occurs in IGC codegen when extracting a value with a variable index.
- On Xe4 (or compared to CUDA), does the new arch support the same Register-Indirect Register Addressing feature?
> I wonder what these extractions are translated into in the machine code. Can we keep the constant table in a register, one value per lane, and use a shuffle idx instruction instead of element extraction to get it? Would it be more efficient?
Here is the assembly for the element extraction:
```
//.declare V4328 (4339) rf=r size=32 type=w align=32 words (r1.0)
...
//.declare (8887) rf=r size=8 type=uq alias=V4328+0 align=32 words (r1.0)
//.declare (8888) rf=r size=8 type=uq alias=V4328+8 align=32 words (r1.1)
//.declare (8889) rf=r size=8 type=uq alias=V4328+16 align=32 words (r1.2)
//.declare (8890) rf=r size=8 type=uq alias=V4328+24 align=32 words (r1.3)
...
// Line 18
(W) and (1|M0) r1.16<1>:w r37.3<0;1,0>:w 15:w // ALU pipe: int; $176
// Line 26
(W) mov (1|M0) r1.32<2>:b r1.16<0;1,0>:w {I@1} // ALU pipe: int; $8149
(W) mov (1|M0) r4.3<1>:w r[a0.0]<0;1,0>:w // ALU pipe: int; $8148
(W) mul (1|M0) r1.16<1>:uw r1.32<0;1,0>:b 0x2:uw {I@2} // ALU pipe: int; $8150
(W) add (1|M0) a0.0<1>:uw r1.16<0;1,0>:uw 0x40:uw {A@1} // ALU pipe: int; src1 is addr of V4328(r1.0:w); $8151
// Line 18
(W) shr (1|M0) r1.16<1>:uw r37.4<0;1,0>:uw 4:w // ALU pipe: int; $175
// Line 26
(W) mov (1|M0) r1.32<2>:b r1.16<0;1,0>:w {I@1} // ALU pipe: int; $8153
(W) mov (1|M0) r4.4<1>:w r[a0.0]<0;1,0>:w // ALU pipe: int; $8152
(W) mul (1|M0) r1.16<1>:uw r1.32<0;1,0>:b 0x2:uw {I@2} // ALU pipe: int; $8154
(W) add (1|M0) a0.0<1>:uw r1.16<0;1,0>:uw 0x40:uw {A@1} // ALU pipe: int; src1 is addr of V4328(r1.0:w); $8155
```
Use a simple lookup table instead of explicit conversion. Fixes #4298
Here is a bench with a simple kernel that converts a tensor with 512 elements:

```python
import os
import timeit
import torch
import triton
import pathlib


def bench_fp4_to_bf16(device=triton.runtime.driver.active.get_active_torch_device()):
    ir = """
#blocked1 = #ttg.blocked<{sizePerThread = [512], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1024], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
module attributes {triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 16384 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
  tt.func public @fp4_to_bf16_kernel(%in_ptr: !tt.ptr<i8>, %out_ptr: !tt.ptr<bf16>) {
    %c512 = arith.constant 512 : i32
    %c1024 = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %in_off = arith.muli %pid, %c512 : i32
    %in_offsets = tt.splat %in_off : i32 -> tensor<512xi32, #blocked1>
    %in_range = tt.make_range {start = 0 : i32, end = 512 : i32} : tensor<512xi32, #blocked1>
    %in_ranges = arith.addi %in_range, %in_offsets : tensor<512xi32, #blocked1>
    %in_base = tt.splat %in_ptr : !tt.ptr<i8> -> tensor<512x!tt.ptr<i8>, #blocked1>
    %in_ptrs = tt.addptr %in_base, %in_ranges : tensor<512x!tt.ptr<i8>, #blocked1>, tensor<512xi32, #blocked1>
    %in_tensor = tt.load %in_ptrs : tensor<512x!tt.ptr<i8>, #blocked1>
    %bf16 = ttg.fp4_to_fp %in_tensor {axis = 0 : i32} : tensor<512xi8, #blocked1> -> tensor<1024xbf16, #blocked2>
    %out_off = arith.muli %pid, %c1024 : i32
    %out_offsets = tt.splat %out_off : i32 -> tensor<1024xi32, #blocked2>
    %out_range = tt.make_range {start = 0 : i32, end = 1024 : i32} : tensor<1024xi32, #blocked2>
    %out_ranges = arith.addi %out_range, %out_offsets : tensor<1024xi32, #blocked2>
    %out_base = tt.splat %out_ptr : !tt.ptr<bf16> -> tensor<1024x!tt.ptr<bf16>, #blocked2>
    %out_ptrs = tt.addptr %out_base, %out_ranges : tensor<1024x!tt.ptr<bf16>, #blocked2>, tensor<1024xi32, #blocked2>
    tt.store %out_ptrs, %bf16 : tensor<1024x!tt.ptr<bf16>, #blocked2>
    tt.return
  }
}
"""
    tmp_path: pathlib.Path = pathlib.Path(".")
    temp_file = tmp_path / "fp4_to_bf16_kernel.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))
    os.remove(temp_file)

    x = torch.randint(0, 127, (8192,), dtype=torch.int8, device=device)
    y = torch.zeros((16384,), dtype=torch.bfloat16, device=device)

    def run_kernel():
        kernel[(16, 1, 1)](x, y)
        torch.xpu.synchronize(device)

    run_kernel()
    print(x)
    print(y)

    time = timeit.timeit(run_kernel, number=100000)
    print(f"Time: {time}")


if __name__ == "__main__":
    bench_fp4_to_bf16()
```

The results on the main branch:

The new implementation:
Use a simple lookup table instead of explicit conversion.
Fixes #4298
This implementation creates 3 constants:
and 4 operations for each pair of values:
I've not compared the performance, but it seems more efficient than the approach in #4298.
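The "3 constants, 4 operations per pair" structure can be modeled in Python. This is a hedged sketch, not the PR's actual C++ lowering: the three constants correspond to the value table, the `i8_15` mask, and the `i8_4` shift from the reviewed snippet, and the four ops per packed byte are the `and`, the `lshr`, and the two `extract_element` lookups. The names below are illustrative.

```python
# The three constants (mirroring table, i8_15, i8_4 in the lowering):
TABLE = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0)
MASK = 0x0F   # i8_15: selects the low nibble
SHIFT = 4     # i8_4: shifts the high nibble down

def decode_bytes(packed):
    """Decode packed fp4 bytes into twice as many floats.

    Four operations per byte: and, shift, and two table lookups
    (standing in for the two extract_element calls).
    """
    out = []
    for v in packed:
        out.append(TABLE[v & MASK])    # extract_element(table, idx1)
        out.append(TABLE[v >> SHIFT])  # extract_element(table, idx2)
    return out
```

A byte stream like `bytes([0x21, 0x43])` decodes to `[0.5, 1.0, 1.5, 2.0]`, i.e. each nibble is resolved by a single indexed load with no conditionals.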