Skip to content

[Inductor][float8] Support qlinear for float8 in inductor #2565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 27 commits into
base: main
Choose a base branch
from

Conversation

shiyang-weng
Copy link
Contributor

@shiyang-weng shiyang-weng commented Jul 17, 2025

For float8_e4m3fn, support

register_qlinear_weight_prepack
_register_qlinear_unary_fusion
_register_qlinear_binary_fusion
quant_lift_up

on inductor.

For FP8, there are following issues

  1. q/dq switch to use quantize_affine_float8/dequantize_affine_float8
  2. The q/dq API change. The fp8 q/dq requires type(scale) is tensor.
  3. pt2e not support float8.

Based on these issues,

  1. Need to handle fp8 q/dq pattern separately.
  2. Handle scale separately.
  3. We implement the function(fp8_convert_), which can add q/dq before the linear in the model. We add the function to test/quantization/pt2e/test_x86inductor_fusion.py

Copy link

pytorch-bot bot commented Jul 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2565

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@shiyang-weng shiyang-weng marked this pull request as draft July 17, 2025 02:59
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 17, 2025
return (
len(node.all_input_nodes) == 2
and node.all_input_nodes[1].target == torch.tensor
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add return False

Copy link
Collaborator

@Xia-Weiwen Xia-Weiwen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

Comment on lines -1397 to +1509
for bias in [True, False]:
self._qlinear_test_helper((torch.randn((2, 4)),), bias=bias)
for is_fp8 in [True, False]:
for bias in [True, False]:
self._qlinear_test_helper(
(torch.randn((2, 4)),), bias=bias, is_fp8=is_fp8
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to fp8 stuff in separate tests, i.e., keeping test_qlinear_cpu and adding test_fp8_qlinear_cpu. Same for other tests.

@@ -1804,13 +1940,166 @@ def test_qlinear_add_int8_mixed_bf16(self, use_relu, is_qat, is_dynamic):
is_dynamic=is_dynamic,
)

def _fp8_qlinear_add_test_helper(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between the int8 version and the fp8 version? Can we merge them?

lambda x, y: x.add_(y),
lambda x, y: y.add_(x),
]
is_fp8 = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are defining a dedicated helper for fp8, is this still needed?

@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("input_dim_exceeds_two", [True, False])
@parametrize("check_reuse_input", [True, False])
def test_fp8_qlinear(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between this test case and the test_qlinear_cpu above?

Comment on lines +2517 to +2520
for is_fp8 in [True, False]:
for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
is_bf16 = original_pattern_output_dtype == torch.bfloat16
for x_scale_zp_are_tensors in (False, True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use itertools.product maybe?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants