ROCm mx-fp8 Gemm #2066

Open · wants to merge 18 commits into base: main

37 changes: 37 additions & 0 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -10,10 +10,12 @@

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
is_ROCm_mx_supported,
is_sm_at_least_100,
)

@@ -57,6 +59,41 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
return compute_error(out_hp, out).item()


@pytest.mark.skipif(
not is_ROCm_mx_supported(),
reason="AMD mxfloat8 test requires ROCm 7.0 on gfx950 GPU",
)
def test_hipblaslt_fp8():
"""Test HIPBLASLT backend for FP8 operations"""
a = torch.randn(128, 128, device="cuda")
b = torch.randn(128, 128, device="cuda")

a_mx = MXTensor.to_mx(
a, torch.float8_e4m3fn, gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT
)
b_mx = MXTensor.to_mx(
b, torch.float8_e4m3fn, gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT
)

# Compute reference result in high precision
out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-1, -2
)

# Compute result using HIPBLASLT backend with scaled_mm
out = torch._scaled_mm(
a_mx._data,
b_mx._data.transpose(-1, -2),
a_mx._scale_e8m0.view(torch.float8_e8m0fnu),
b_mx._scale_e8m0.view(torch.float8_e8m0fnu),
out_dtype=torch.bfloat16,
)

# Verify results (TODO: tune a ROCm-specific threshold)
sqnr = compute_error(out_hp, out).item()
assert sqnr > 80.0, f"SQNR {sqnr} below threshold 80.0"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
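For context on the new test's final assertion: `compute_error` reports a signal-to-quantized-noise ratio (SQNR) in dB, so `sqnr > 80.0` means the hipBLASLt result tracks the bf16 reference very closely. A minimal sketch of the metric, assuming the usual SQNR definition (illustrative only, not the torchao helper itself):

```python
import torch

def sqnr_db(reference: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 10 * log10(signal power / error power)
    ref = reference.float()
    err = ref - result.float()
    return 10 * torch.log10(torch.mean(ref**2) / torch.mean(err**2))
```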
10 changes: 8 additions & 2 deletions torchao/prototype/mx_formats/README.md
@@ -1,7 +1,8 @@
# MX training and inference with native PyTorch

This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.
This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 and AMD MI355x hardware.


## Overall status

@@ -29,6 +30,9 @@ from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice
gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
# gemm_kernel_choice = MXGemmKernelChoice.CUTLASS

# on AMD MI355x GPUs with gfx950 and ROCm 7.0, you can use HIPBLASLT mxfp8 kernels
# gemm_kernel_choice = MXGemmKernelChoice.HIPBLASLT

# on older NVIDIA gpus, you can run training with emulated MX gemm
# gemm_kernel_choice = MXGemmKernelChoice.EMULATED

@@ -97,6 +101,8 @@ on supported hardware, you can run the following command:
// example output: https://gist.github.com/vkuzo/a1ddb782e6e1c2aef0c726b3df99efbc
```

On AMD MI355x GPUs (gfx950) with ROCm 7.0, we use hipBLASLt for the mxfp8 gemm. We are actively working on optimizing end-to-end performance on AMD hardware.

## to_mx cast across dim0 and dim1

On NVIDIA B200 machines, our to_mx kernels for mxfp8 achieve **up to 5.5 TB/s** for the dim0 cast (with torch.compile),
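Putting the README additions together, a minimal sketch of selecting the hipBLASLt mxfp8 kernel for training on an MI355x, assuming the `quantize_` flow shown earlier in this README (the toy model and sizes are illustrative):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice

# toy model for illustration
m = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).cuda().to(torch.bfloat16)

# on AMD MI355x (gfx950) with a ROCm 7.0 build, route mxfp8 gemms through hipBLASLt
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT,
)
quantize_(m, config)
```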
15 changes: 15 additions & 0 deletions torchao/prototype/mx_formats/config.py
@@ -32,11 +32,15 @@ class MXGemmKernelChoice(Enum):
# note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
CUBLAS = "cublas"

# available only on ROCm with HIPBLASLT support; requires gfx950 and ROCm 7.0
HIPBLASLT = "hipblaslt"


# Pre-made recipes for common configurations
class MXLinearRecipeName(Enum):
MXFP8_EMULATED = "mxfp8_emulated"
MXFP8_CUBLAS = "mxfp8_cublas"
MXFP8_HIPBLASLT = "mxfp8_hipblaslt"
MXFP4_EMULATED = "mxfp4_emulated"
MXFP4_CUTLASS = "mxfp4_cutlass"

@@ -64,6 +68,15 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
assert elem_dtype in valid_dtypes, (
f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
)
elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
assert block_size == 32, (
f"block_size must be 32 to use the HIPBLASLT MX gemm kernels, got {block_size}"
)
valid_dtypes = [torch.float8_e4m3fn]
assert elem_dtype in valid_dtypes, (
f"elem_dtype must be one of {valid_dtypes} to use the HIPBLASLT MX gemm kernels, got {elem_dtype}"
)
assert torch.version.hip is not None, "HIPBLASLT requires ROCm"


@dataclass
@@ -124,6 +137,8 @@ def from_recipe_name(
return MXLinearConfig()
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
elif recipe_name is MXLinearRecipeName.MXFP8_HIPBLASLT:
return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT)
elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2)
elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
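A short sketch of what the config.py changes add, assuming a ROCm build (the new HIPBLASLT branch asserts `torch.version.hip is not None`):

```python
import torch
from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
    _validate_gemm_kernel_choice,
)

# the new recipe name resolves to a config that uses the HIPBLASLT mxfp8 kernel
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_HIPBLASLT)
assert config.gemm_kernel_choice is MXGemmKernelChoice.HIPBLASLT

# the new validation branch accepts only mxfp8 (e4m3) with a block size of 32
_validate_gemm_kernel_choice(
    MXGemmKernelChoice.HIPBLASLT, block_size=32, elem_dtype=torch.float8_e4m3fn
)
try:
    _validate_gemm_kernel_choice(
        MXGemmKernelChoice.HIPBLASLT, block_size=32, elem_dtype=torch.float8_e5m2
    )
except AssertionError as err:
    print(err)  # elem_dtype must be one of [torch.float8_e4m3fn] ...
```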
13 changes: 10 additions & 3 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -88,7 +88,11 @@ def _addmm_mx_dispatch(
"""
gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice)

if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
if gemm_choice in (
MXGemmKernelChoice.CUBLAS,
MXGemmKernelChoice.CUTLASS,
MXGemmKernelChoice.HIPBLASLT,
):
# real MX gemm backed by cuBLAS, hipBLASLt, or torchao's CUTLASS kernels
M, K, N = a.shape[0], a.shape[1], b.shape[1]
assert a._data.is_contiguous()
@@ -103,8 +107,11 @@

if a._elem_dtype == torch.float8_e4m3fn:
assert b._elem_dtype == torch.float8_e4m3fn
assert gemm_choice is MXGemmKernelChoice.CUBLAS, (
"CUBLAS is the only supported kernel choice for MX FP8 operations"
assert gemm_choice in (
MXGemmKernelChoice.CUBLAS,
MXGemmKernelChoice.HIPBLASLT,
), (
"CUBLAS and HIPBLASLT are the only supported kernel choices for MX FP8 operations ATM"
)

res = torch._scaled_mm(
17 changes: 17 additions & 0 deletions torchao/utils.py
@@ -710,3 +710,20 @@ def is_package_at_least(package_name: str, min_version: str):
return False

return version(package_name) >= min_version


def is_ROCm_mx_supported() -> bool:
"""
Check if the current environment supports ROCm MX operations.
This requires:
1. ROCm platform
2. gfx950 GPU (MI350)
3. ROCm 7.0
"""
return all(
[
is_ROCM(),
is_MI350(),
torch.version.hip is not None and torch.version.hip.startswith("7.0"),
]
)
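One way the new helper might be used is to pick a gemm backend with a safe fallback; a small sketch (the emulated fallback here is just an illustration, not part of this PR):

```python
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.utils import is_ROCm_mx_supported

# prefer the hipBLASLt mxfp8 kernel when the ROCm requirements are met,
# otherwise fall back to the emulated MX gemm
gemm_kernel_choice = (
    MXGemmKernelChoice.HIPBLASLT
    if is_ROCm_mx_supported()
    else MXGemmKernelChoice.EMULATED
)
```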