ROCm mx-fp8 Gemm #2066
Conversation
…dation logic. Added MXFP8_HIPBLASLT recipe and adjusted mx_mm function to accommodate new kernel options.
…ASLT kernel choice for mxfp8 gemm. Enhance documentation on end-to-end performance optimization efforts for AMD GPUs.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2066
Note: Links to docs will display an error until the doc builds have been completed.
❌ 1 New Failure, 1 Cancelled Job
As of commit 012f938 with merge base 801af03:
NEW FAILURE - The following job has failed.
CANCELLED JOB - The following job was cancelled. Please retry.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…py to include HIPBLASLT as a valid kernel choice for MX FP8 operations.
related to pytorch/pytorch#151360
Pull Request Overview
Adds support for AMD MI355x GPUs by introducing HIPBLASLT kernels into the MX formats prototype, along with necessary config updates, validation logic, and documentation enhancements.
- Extend MXGemmKernelChoice and MXLinearRecipeName to include HIPBLASLT
- Add validation and dispatch logic for HIPBLASLT in config and mx_ops
- Update README to show how to use HIPBLASLT on AMD MI355x hardware
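The extended configuration surface can be pictured as a standalone sketch. Member names mirror the enums named in the PR, but the non-HIPBLASLT members and their string values are assumptions for illustration, not copied from the torchao source:

```python
from enum import Enum

# Standalone sketch of the enums described above; illustrative only.
class MXGemmKernelChoice(Enum):
    EMULATED = "emulated"      # pure-PyTorch reference path (assumed member)
    CUBLAS = "cublas"          # NVIDIA hardware path (assumed member)
    HIPBLASLT = "hipblaslt"    # new in this PR: AMD gfx950 path via hipBLASLt

class MXLinearRecipeName(Enum):
    MXFP8_EMULATED = "mxfp8_emulated"    # assumed existing recipe
    MXFP8_CUBLAS = "mxfp8_cublas"        # assumed existing recipe
    MXFP8_HIPBLASLT = "mxfp8_hipblaslt"  # new recipe added by this PR

# A user on AMD hardware would select the new path by recipe name:
recipe = MXLinearRecipeName.MXFP8_HIPBLASLT
```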
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
config.py | Added HIPBLASLT enum, recipe, and validation in _validate_gemm_kernel_choice and from_recipe_name |
mx_ops.py | Extended GEMM dispatch and FP8 assertions to support HIPBLASLT |
README.md | Documented AMD MI355x support and HIPBLASLT usage |
Comments suppressed due to low confidence (3)
torchao/prototype/mx_formats/config.py:35
- Comment indicates a ROCm 7.0 requirement, but the PR description specifies ROCm 6.5+. Consider aligning the version requirement in code comments and documentation.
# available only on ROCm with HIPBLASLT support, requires gfx950 and ROCm 7.0
torchao/prototype/mx_formats/mx_ops.py:91
- [nitpick] New HIPBLASLT code path added to GEMM dispatch; consider adding or updating tests to cover this scenario in _addmm_mx_dispatch.
if gemm_choice in (
torchao/prototype/mx_formats/config.py:71
- [nitpick] New HIPBLASLT validation logic in _validate_gemm_kernel_choice should be covered by tests to verify block size, dtype, and HIP availability checks.
elif gemm_kernel_choice == MXGemmKernelChoice.HIPBLASLT:
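In standalone form, the validation branch quoted above might look like the following sketch. The block size, the accepted dtype names, and the `hip_version` parameter are assumptions standing in for the real config fields and `torch.version.hip`:

```python
MX_BLOCK_SIZE = 32  # assumed block size required by the hipBLASLt mx kernels

def validate_hipblaslt_choice(block_size, elem_dtype, hip_version):
    """Raise if the HIPBLASLT kernel cannot be used with this configuration."""
    if hip_version is None:
        # torch.version.hip is None on non-ROCm builds of PyTorch
        raise RuntimeError("HIPBLASLT kernel requires a ROCm build of PyTorch")
    if block_size != MX_BLOCK_SIZE:
        raise ValueError(
            f"HIPBLASLT requires block_size={MX_BLOCK_SIZE}, got {block_size}"
        )
    if elem_dtype not in ("fp8_e4m3", "fp8_e5m2"):
        raise ValueError(
            f"HIPBLASLT supports only FP8 element dtypes, got {elem_dtype}"
        )

validate_hipblaslt_choice(32, "fp8_e4m3", "7.0.0")  # passes silently
```

A test along the lines the reviewer suggests would exercise each of the three failure branches as well as the passing case.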
… HIPBLASLT are supported kernel choices for MX FP8 operations.
…l choices for MX FP8 operations.
This is cool / looks good. scaled_mm is ultimately what's backing this, right? Can you add a test, even if it's not in our CI/CD?
Co-authored-by: Copilot <[email protected]>
- Introduced `is_ROCm_mx_supported` function to verify ROCm environment compatibility for MX operations.
- Added `test_hipblaslt_fp8` to validate FP8 operations using the HIPBLASLT backend, including SQNR verification for output accuracy.
- Updated imports in `test_mx_mm.py` to include necessary utilities for the new test.
- Replaced `compute_sqnr` with `compute_error` for improved accuracy in error measurement.
- Updated assertion to ensure output accuracy meets the specified threshold.
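For reference, the accuracy check above compares outputs by an SQNR-style figure in decibels. A minimal plain-Python analogue (operating on lists instead of tensors; the formula is the standard SQNR definition, not copied from torchao) is:

```python
import math

def compute_error(ref, actual):
    """SQNR-style error in dB: 20*log10(||ref|| / ||ref - actual||).

    Higher is better; identical inputs would give +inf, so FP8 tests
    assert against a finite threshold instead.
    """
    signal = math.sqrt(sum(r * r for r in ref))
    noise = math.sqrt(sum((r - a) ** 2 for r, a in zip(ref, actual)))
    return 20.0 * math.log10(signal / noise)

ref = [1.0, 2.0, 3.0, 4.0]
approx = [1.01, 1.99, 3.02, 3.98]
sqnr_db = compute_error(ref, approx)  # a close match gives a large dB value
```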
Thanks.
- Updated the function to ensure `torch.version.hip` is not None before checking the version, improving robustness against potential NoneType errors.
- Reformatted the return statement to enhance clarity and maintainability of the code.
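The NoneType guard described above can be sketched as a small standalone helper, where `hip_version` plays the role of `torch.version.hip` (which is None on non-ROCm builds of PyTorch); the parsing details are an assumption for illustration:

```python
def is_rocm_version_at_least(hip_version, required=(7, 0)):
    """Safely check a ROCm/HIP version string against a minimum version."""
    if hip_version is None:
        # Non-ROCm builds report torch.version.hip as None; guard before
        # parsing to avoid AttributeError on NoneType.
        return False
    major, minor = (int(part) for part in hip_version.split(".")[:2])
    return (major, minor) >= required
```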
TLDR: This pull request introduces support for AMD MI355x GPUs with HIPBLASLT kernels in the MX formats prototype; note that this feature requires ROCm 7.0+ and gfx950. It also brings several updates to improve compatibility and functionality for these GPUs. Key changes include updates to configuration options, validation logic, and GEMM kernel handling to integrate HIPBLASLT support.
AMD MI355x GPU Support:

torchao/prototype/mx_formats/config.py:
- Added HIPBLASLT as a new MXGemmKernelChoice and included it in the MXLinearRecipeName for configuration presets. [1] [2]
- Updated _validate_gemm_kernel_choice to include validation logic for HIPBLASLT, ensuring proper block size, data type, and ROCm availability.

torchao/prototype/mx_formats/mx_ops.py:
- Updated mx_mm to support HIPBLASLT for scaled matrix multiplication and real GEMM operations. [1] [2]

Documentation Updates:

torchao/prototype/mx_formats/README.md:
- Documented AMD MI355x support and HIPBLASLT usage.

Minor Code Refinements:

torchao/prototype/mx_formats/mx_ops.py:
- Refined mx_view_op by reformatting conditions for FP6 element packing.
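The GEMM dispatch change in mx_mm can be pictured in simplified standalone form. The routing below mirrors the `gemm_choice in (...)` pattern quoted in the review, while the backend names and the non-HIPBLASLT enum members are illustrative stand-ins for the real scaled-mm and emulated paths:

```python
from enum import Enum

# Illustrative stand-in for the kernel-choice enum (not torchao source).
class MXGemmKernelChoice(Enum):
    EMULATED = "emulated"
    CUBLAS = "cublas"
    HIPBLASLT = "hipblaslt"

def pick_gemm_backend(gemm_choice):
    # cuBLAS and hipBLASLt both take the hardware scaled-mm path;
    # anything else falls back to the emulated reference kernel.
    if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.HIPBLASLT):
        return "scaled_mm"
    return "emulated"
```

This matches the reviewer's observation that scaled_mm ultimately backs the new path: the dispatch only decides which backend handles the FP8 operands.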