ROCm mx-fp8 Gemm #2066

Open

wants to merge 18 commits into main
Conversation

@petrex petrex (Collaborator) commented Apr 16, 2025

TLDR: This pull request introduces support for AMD MI355x GPUs with HIPBLASLT kernels in the MX formats prototype, alongside several updates to improve compatibility and functionality for these GPUs. Note that this feature requires ROCm 7.0+ and gfx950.

Key changes include updates to configuration options, validation logic, and GEMM kernel handling to integrate HIPBLASLT support.

AMD MI355x GPU Support:

  • torchao/prototype/mx_formats/config.py:

    • Added HIPBLASLT as a new MXGemmKernelChoice and included it in the MXLinearRecipeName for configuration presets. [1] [2]
    • Updated _validate_gemm_kernel_choice to include validation logic for HIPBLASLT, ensuring proper block size, data type, and ROCm availability.
  • torchao/prototype/mx_formats/mx_ops.py:

    • Extended mx_mm to support HIPBLASLT for scaled matrix multiplication and real GEMM operations. [1] [2]
    • Adjusted error messaging for unsupported kernel choices in FP4 operations.

Documentation Updates:

  • torchao/prototype/mx_formats/README.md:
    • Updated the README to reflect AMD MI355x GPU support, including instructions for using HIPBLASLT kernels and ongoing optimization efforts for AMD hardware. [1] [2] [3]

Minor Code Refinements:

  • torchao/prototype/mx_formats/mx_ops.py:
    • Improved readability in mx_view_op by reformatting conditions for FP6 element packing.
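
For readers trying the new path, here is a minimal usage sketch based on the description above; the exact field names and the recipe spelling (e.g. MXFP8_HIPBLASLT) are assumptions, not verified against this diff:

```python
# Usage sketch only; field names and the recipe value are assumptions, not the actual diff.
from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
)

# Explicitly request the hipBLASLt-backed GEMM added in this PR (ROCm 7.0+, gfx950 only).
config = MXLinearConfig(
    block_size=32,                                    # block size validated for HIPBLASLT
    gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT,  # new enum value
)

# Equivalent preset via the recipe name added alongside the enum (assumed spelling).
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_HIPBLASLT)
```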

petrex added 2 commits April 16, 2025 15:59
…dation logic. Added MXFP8_HIPBLASLT recipe and adjusted mx_mm function to accommodate new kernel options.
…ASLT kernel choice for mxfp8 gemm. Enhance documentation on end-to-end performance optimization efforts for AMD GPUs.
pytorch-bot bot commented Apr 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2066

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job

As of commit 012f938 with merge base 801af03:

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Apr 16, 2025
@petrex petrex added the mx label Apr 17, 2025
@petrex petrex requested a review from vkuzo April 18, 2025 16:59
@petrex petrex added the topic: new feature and ciflow/rocm labels Apr 18, 2025
@petrex petrex (Collaborator, Author) commented Jun 9, 2025

related to pytorch/pytorch#151360

@petrex petrex requested review from drisspg and Copilot June 9, 2025 17:32
@Copilot Copilot AI (Contributor) left a comment

Pull Request Overview

Adds support for AMD MI355x GPUs by introducing HIPBLASLT kernels into the MX formats prototype, along with necessary config updates, validation logic, and documentation enhancements.

  • Extend MXGemmKernelChoice and MXLinearRecipeName to include HIPBLASLT
  • Add validation and dispatch logic for HIPBLASLT in config and mx_ops
  • Update README to show how to use HIPBLASLT on AMD MI355x hardware

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

  • config.py: Added HIPBLASLT enum, recipe, and validation in _validate_gemm_kernel_choice and from_recipe_name
  • mx_ops.py: Extended GEMM dispatch and FP8 assertions to support HIPBLASLT
  • README.md: Documented AMD MI355x support and HIPBLASLT usage
Comments suppressed due to low confidence (3)

torchao/prototype/mx_formats/config.py:35

  • Comment indicates ROCm 7.0 requirement, but PR description specifies ROCm 6.5+. Consider aligning version requirement in code comments and documentation.
# available only on ROCm with HIPBLASLT support, reuqire gfx950 and ROCm 7.0

torchao/prototype/mx_formats/mx_ops.py:91

  • [nitpick] New HIPBLASLT code path added to GEMM dispatch; consider adding or updating tests to cover this scenario in _addmm_mx_dispatch.
if gemm_choice in (

torchao/prototype/mx_formats/config.py:71

  • [nitpick] New HIPBLASLT validation logic in _validate_gemm_kernel_choice should be covered by tests to verify block size, dtype, and HIP availability checks.
elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
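
To illustrate the checks Copilot is referring to, here is a standalone sketch of the kind of validation the HIPBLASLT branch would need; the constants (block size 32, fp8 e4m3) and the helper name are assumptions, not the actual code in _validate_gemm_kernel_choice:

```python
# Standalone illustration of the HIPBLASLT validation branch; thresholds and the
# supported dtype are assumptions, not the actual torchao code.
import torch


def validate_hipblaslt_choice(elem_dtype: torch.dtype, block_size: int) -> None:
    # hipBLASLt MX kernels need a ROCm build of PyTorch (gfx950, ROCm 7.0+).
    assert torch.version.hip is not None, "HIPBLASLT kernel choice requires ROCm"
    # MX formats use 32-element scaling blocks; hipBLASLt is assumed to expect the same.
    assert block_size == 32, f"block_size must be 32 for HIPBLASLT, got {block_size}"
    # mx-fp8 path: only the fp8 e4m3 element dtype is assumed to be supported here.
    assert elem_dtype == torch.float8_e4m3fn, (
        f"elem_dtype {elem_dtype} not supported by the HIPBLASLT mx-fp8 gemm"
    )
```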

petrex added 2 commits June 9, 2025 10:56
… HIPBLASLT are supported kernel choices for MX FP8 operations.
@drisspg drisspg (Contributor) left a comment

This is cool / looks good, and scaled_mm is ultimately what's backing this, right? Can you add a test, even if it's not in our CI/CD?

petrex and others added 4 commits June 9, 2025 12:04
- Introduced `is_ROCm_mx_supported` function to verify ROCm environment compatibility for MX operations.
- Added `test_hipblaslt_fp8` to validate FP8 operations using the HIPBLASLT backend, including SQNR verification for output accuracy.
- Updated imports in `test_mx_mm.py` to include necessary utilities for the new test.
- Replaced `compute_sqnr` with `compute_error` for improved accuracy in error measurement.
- Updated assertion to ensure output accuracy meets the specified threshold.
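
A rough sketch of what a test along these lines could look like; the MXTensor.to_mx arguments, the gemm_kernel_choice plumbing, and the SQNR threshold are assumptions rather than the committed test:

```python
# Sketch of a hipBLASLt FP8 test in the spirit of the commits above; argument names,
# the kernel-choice wiring, and the accuracy bar are assumptions.
import pytest
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_tensor import MXTensor


def sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB (the quantity compute_error reports).
    return 20 * torch.log10(ref.norm() / (ref - actual).norm())


@pytest.mark.skipif(torch.version.hip is None, reason="requires ROCm 7.0+ and gfx950")
def test_hipblaslt_fp8():
    a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
    kwargs = dict(block_size=32, gemm_kernel_choice=MXGemmKernelChoice.HIPBLASLT)  # assumed kwarg
    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, **kwargs)
    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, **kwargs)
    out = a_mx @ b_mx  # dispatched to hipBLASLt via the new kernel choice
    assert sqnr(a @ b, out) >= 20.0  # loose fp8 accuracy bar
```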
@petrex petrex (Collaborator, Author) commented Jun 9, 2025

This is cool / looks good, and scaled_mm is ultimately what's backing this, right? Can you add a test, even if it's not in our CI/CD?

Thanks.
Right. scaled_mm() --> hipblaslt --> gfx950. I'd deploy gfx950s in CI once they are GA.
Added a test that is not currently run in CI.
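
For context, the chain described here bottoms out in a torch._scaled_mm call with e8m0 block scales, roughly as sketched below; shapes, scale layout requirements, and dtype support on gfx950 are assumptions, not a verified recipe:

```python
# Illustrative only: the call shape behind "scaled_mm --> hipblaslt --> gfx950".
# Scale layout requirements for hipBLASLt on gfx950 are an assumption here.
import torch

M, K, N = 128, 256, 64
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)


def e8m0_ones(rows: int, cols: int) -> torch.Tensor:
    # A scale of 1.0 in e8m0 is the biased exponent 127; build it via a uint8 view.
    return torch.full((rows, cols), 127, dtype=torch.uint8, device="cuda").view(
        torch.float8_e8m0fnu
    )


# One shared exponent per 32-element block along K, as in the MX spec.
a_scale = e8m0_ones(M, K // 32)
b_scale = e8m0_ones(N, K // 32)

out = torch._scaled_mm(a, b.t(), scale_a=a_scale, scale_b=b_scale, out_dtype=torch.bfloat16)
```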

- Updated the function to ensure `torch.version.hip` is not None before checking the version, improving robustness against potential NoneType errors.
- Reformatted the return statement to enhance clarity and maintainability of the code.
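
A minimal sketch of the guard described in these commits, assuming the helper name is_ROCm_mx_supported and the ROCm 7.0 / gfx950 requirement from the PR description; the shipped implementation may differ:

```python
# Best-effort sketch of the environment guard described above; exact checks are assumptions.
import torch


def is_ROCm_mx_supported() -> bool:
    """Return True only on a ROCm build that can run the MX FP8 hipBLASLt path."""
    if not torch.cuda.is_available():
        return False
    # torch.version.hip is None on CUDA builds; guard before parsing the version string.
    if torch.version.hip is None:
        return False
    rocm_major = int(torch.version.hip.split(".")[0])
    # This PR targets ROCm 7.0+ on gfx950 (MI355x).
    gfx = torch.cuda.get_device_properties(0).gcnArchName
    return rocm_major >= 7 and "gfx950" in gfx
```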
@petrex petrex self-assigned this Jun 10, 2025
Labels: ciflow/rocm, CLA Signed, module: rocm, mx, topic: new feature