
Improved performance of the fp4tofp conversion #4299


Open

wants to merge 1 commit into main from AndreyPavlenko/fp4tofp

Conversation

AndreyPavlenko
Contributor

@AndreyPavlenko commented May 24, 2025

Use a simple lookup table instead of explicit conversion.

Fixes #4298

This implementation creates 3 constants:

    %32 = llvm.mlir.constant(dense<[0.000000e+00, 5.000000e-01, 1.000000e+00, 1.500000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 6.000000e+00, -0.000000e+00, -5.000000e-01, -1.000000e+00, -1.500000e+00, -2.000000e+00, -3.000000e+00, -4.000000e+00, -6.000000e+00]> : vector<16xbf16>) : vector<16xbf16>
    %33 = llvm.mlir.constant(4 : i8) : i8
    %34 = llvm.mlir.constant(15 : i8) : i8

and 4 operations per pair of values:

    %35 = llvm.and %0, %34 : i8
    %36 = llvm.lshr %0, %33 : i8
    %37 = llvm.extractelement %32[%35 : i8] : vector<16xbf16>
    %38 = llvm.extractelement %32[%36 : i8] : vector<16xbf16>

I have not compared the performance, but it seems more efficient than #4298.
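
For illustration, here is a minimal Python sketch of the same idea (not part of the PR; it only mirrors the e2m1 value table from the constant vector above and decodes both nibbles of each packed byte by plain indexing):

    import torch

    # fp4 (e2m1) values, indexed by the 4-bit code, in the same order as the
    # vector constant above
    FP4_TABLE = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        dtype=torch.bfloat16)

    def fp4x2_to_bf16(packed: torch.Tensor) -> torch.Tensor:
        # Each byte holds two fp4 values; the low and high nibbles are used
        # directly as indices into the table, so no exponent/mantissa
        # manipulation is needed.
        lo = (packed & 0xF).long()
        hi = ((packed >> 4) & 0xF).long()
        # Low nibble first, matching the order of the two extractelement ops above.
        return torch.stack((FP4_TABLE[lo], FP4_TABLE[hi]), dim=-1).reshape(-1)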

@AndreyPavlenko marked this pull request as ready for review May 24, 2025 11:46
@LiyangLingIntel
Contributor

We may need to remove these lines

    if A_DATA_TYPE == "float4" and B_DATA_TYPE == "float4":

to test this case in CI.

Can you explain the logic of the lookup table a little bit?

@AndreyPavlenko
Contributor Author

We may need to remove these lines

It does not resolve the [ZE]0x78000011 error.

@AndreyPavlenko
Contributor Author

AndreyPavlenko commented May 26, 2025

Can you explain the logic of the lookup table a little bit?

All the possible values are enumerated in the constant vector. Instead of doing all the bit manipulations and conditionals, we just extract the elements by index: 0b0001 is 0.5 and is stored at index 1, 0b0010 is 1.0 and is stored at index 2, and so on.

@AndreyPavlenko force-pushed the AndreyPavlenko/fp4tofp branch from eabc278 to a5901fc on May 26, 2025 10:50
@LiyangLingIntel
Contributor

We may need to remove these lines

It does not resolve the [ZE]0x78000011 error.

We still need the performance result to see how much perf gain we can get from this change.

@AndreyPavlenko force-pushed the AndreyPavlenko/fp4tofp branch from a5901fc to cc16289 on May 27, 2025 20:05
@etiotto requested review from chengjunlu and ienkovich May 28, 2025 13:15
Value idx1 = b.and_(v, i8_15);  // low nibble of the packed byte
Value idx2 = b.lshr(v, i8_4);   // high nibble of the packed byte
results.push_back(b.extract_element(table, idx1));  // first bf16 value from the lookup table
results.push_back(b.extract_element(table, idx2));  // second bf16 value from the lookup table

Contributor

I wonder what these extractions are translated into in the machine code. Can we keep the constant table in a register, one value per lane, and use a shuffle idx instruction instead of element extraction to get it? Would it be more efficient?
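
To make the suggestion concrete, here is a rough Python mock-up (purely illustrative, not a real intrinsic or the PR code) of keeping the 16-entry table as one value per lane of a 16-wide sub-group and reading it with a shuffle-by-index, emulated here with a gather:

    import torch

    SUB_GROUP_SIZE = 16  # threads-per-warp on this target

    # One table entry per lane: lane l keeps the fp4 value with code l in a register.
    lane_table = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        dtype=torch.bfloat16)

    def shuffle_idx(per_lane_value: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # Emulates a sub-group shuffle: every lane reads the value held by lane idx[lane].
        return per_lane_value[idx.long()]

    # Each lane holds one packed byte and decodes both nibbles with two shuffles.
    packed = torch.randint(0, 256, (SUB_GROUP_SIZE,), dtype=torch.uint8)
    first = shuffle_idx(lane_table, packed & 0xF)
    second = shuffle_idx(lane_table, (packed >> 4) & 0xF)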

Contributor

@chengjunlu May 29, 2025

I think we need to make sure of two things before approving this change.

  1. On Xe2-3, Register-Indirect Register Addressing is used to index the value in the lookup table, which resides in a register. No spilling to memory occurs in the IGC codegen when extracting a value with a variable index.
  2. On Xe4 (or compared to CUDA), does the new architecture support the same feature as Register-Indirect Register Addressing?

Contributor Author

I wonder what these extractions are translated into in the machine code. Can we keep the constant table in a register, one value per lane, and use a shuffle idx instruction instead of element extraction to get it? Would it be more efficient?

Here is the assembly for the element extraction:

//.declare V4328 (4339)  rf=r size=32 type=w align=32 words (r1.0)
...
//.declare  (8887)  rf=r size=8 type=uq alias=V4328+0 align=32 words (r1.0)
//.declare  (8888)  rf=r size=8 type=uq alias=V4328+8 align=32 words (r1.1)
//.declare  (8889)  rf=r size=8 type=uq alias=V4328+16 align=32 words (r1.2)
//.declare  (8890)  rf=r size=8 type=uq alias=V4328+24 align=32 words (r1.3)

...

// Line 18
(W)     and (1|M0)               r1.16<1>:w    r37.3<0;1,0>:w    15:w                                //  ALU pipe: int; $176

// Line 26
(W)     mov (1|M0)               r1.32<2>:b    r1.16<0;1,0>:w                   {I@1}                //  ALU pipe: int; $8149
(W)     mov (1|M0)               r4.3<1>:w     r[a0.0]<0;1,0>:w                                      //  ALU pipe: int; $8148
(W)     mul (1|M0)               r1.16<1>:uw   r1.32<0;1,0>:b    0x2:uw              {I@2}           //  ALU pipe: int; $8150
(W)     add (1|M0)               a0.0<1>:uw    r1.16<0;1,0>:uw   0x40:uw              {A@1}          //  ALU pipe: int; src1 is addr of V4328(r1.0:w); $8151

// Line 18
(W)     shr (1|M0)               r1.16<1>:uw   r37.4<0;1,0>:uw   4:w                                 //  ALU pipe: int; $175

// Line 26
(W)     mov (1|M0)               r1.32<2>:b    r1.16<0;1,0>:w                   {I@1}                //  ALU pipe: int; $8153
(W)     mov (1|M0)               r4.4<1>:w     r[a0.0]<0;1,0>:w                                      //  ALU pipe: int; $8152
(W)     mul (1|M0)               r1.16<1>:uw   r1.32<0;1,0>:b    0x2:uw              {I@2}           //  ALU pipe: int; $8154
(W)     add (1|M0)               a0.0<1>:uw    r1.16<0;1,0>:uw   0x40:uw              {A@1}          //  ALU pipe: int; src1 is addr of V4328(r1.0:w); $8155

@AndreyPavlenko force-pushed the AndreyPavlenko/fp4tofp branch from cc16289 to 671b879 on May 28, 2025 15:53
@AndreyPavlenko
Contributor Author

We still need the performance result to see how much perf gain we can get from this change.

Here is a benchmark with a simple kernel that converts a tensor of 512 elements:

import os
import timeit
import torch
import triton
import pathlib

def bench_fp4_to_bf16(device=triton.runtime.driver.active.get_active_torch_device()):
    ir = """
#blocked1 = #ttg.blocked<{sizePerThread = [512], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1024], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
module attributes {triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 16384 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {

  tt.func public @fp4_to_bf16_kernel(%in_ptr: !tt.ptr<i8>, %out_ptr: !tt.ptr<bf16>) {
    %c512 = arith.constant 512 : i32
    %c1024 = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %in_off = arith.muli %pid, %c512 : i32
    %in_offsets = tt.splat %in_off : i32 -> tensor<512xi32, #blocked1>
    %in_range = tt.make_range {start = 0 : i32, end = 512 : i32} : tensor<512xi32, #blocked1>
    %in_ranges = arith.addi %in_range, %in_offsets : tensor<512xi32, #blocked1>
    %in_base = tt.splat %in_ptr : !tt.ptr<i8> -> tensor<512x!tt.ptr<i8>, #blocked1>
    %in_ptrs = tt.addptr %in_base, %in_ranges : tensor<512x!tt.ptr<i8>, #blocked1>, tensor<512xi32, #blocked1>
    %in_tensor = tt.load %in_ptrs : tensor<512x!tt.ptr<i8>, #blocked1>

    %bf16 = ttg.fp4_to_fp %in_tensor {axis = 0 : i32} : tensor<512xi8, #blocked1> -> tensor<1024xbf16, #blocked2>

    %out_off = arith.muli %pid, %c1024 : i32
    %out_offsets = tt.splat %out_off : i32 -> tensor<1024xi32, #blocked2>
    %out_range = tt.make_range {start = 0 : i32, end = 1024 : i32} : tensor<1024xi32, #blocked2>
    %out_ranges = arith.addi %out_range, %out_offsets : tensor<1024xi32, #blocked2>
    %out_base = tt.splat %out_ptr : !tt.ptr<bf16> -> tensor<1024x!tt.ptr<bf16>, #blocked2>
    %out_ptrs = tt.addptr %out_base, %out_ranges : tensor<1024x!tt.ptr<bf16>, #blocked2>, tensor<1024xi32, #blocked2>
    tt.store %out_ptrs, %bf16 : tensor<1024x!tt.ptr<bf16>, #blocked2>
    tt.return
  }
}
"""
    tmp_path: pathlib.Path = pathlib.Path(".")
    temp_file = tmp_path / "fp4_to_bf16_kernel.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))
    os.remove(temp_file)

    x = torch.randint(0, 127, (8192,), dtype=torch.int8, device=device)
    y = torch.zeros((16384,), dtype=torch.bfloat16, device=device)

    def run_kernel():
        kernel[(16, 1, 1)](x, y)
        torch.xpu.synchronize(device)
    
    run_kernel()
    print(x)
    print(y)

    time = timeit.timeit(run_kernel, number=100000)
    print(f"Time: {time}")


if __name__ == "__main__":
    bench_fp4_to_bf16()

The results on the main branch:

Time: 8.7180597379338
968K	/home/jovyan/.triton/cache/UMZ6C4F22VQA2NUICIKGOALVJEZUI35S33O7UOQNVCETBUM2KAKQ/@fp4_to_bf16_kernel.llir
548K	/home/jovyan/.triton/cache/UMZ6C4F22VQA2NUICIKGOALVJEZUI35S33O7UOQNVCETBUM2KAKQ/@fp4_to_bf16_kernel.spv

The new implementation:

Time: 4.980347302975133
544K	/home/jovyan/.triton/cache/SA43HXRVSFCAKZ7AQBXTC4JZ67JR7PKNLR7PZOYC6RTOMZPMEYWQ/@fp4_to_bf16_kernel.llir
212K	/home/jovyan/.triton/cache/SA43HXRVSFCAKZ7AQBXTC4JZ67JR7PKNLR7PZOYC6RTOMZPMEYWQ/@fp4_to_bf16_kernel.spv

Successfully merging this pull request may close these issues:

The performance of fp4_to_fp could be improved