[CPU] Enable DA8W4 on CPU #2128
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2128
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2c5a799 with merge base 35ffb26.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@leslie-fang-intel This PR is updated to use a new layout. Please review again. Thanks.
Hi @jerryzh168 Could you please review this PR? Thanks.
Hi @leslie-fang-intel Please review this PR again. I have also added the kernel code in this PR. It showed reasonable performance in internal benchmarks. Thanks.
Please also describe how we choose different implementations based on the CPU Info.
torchao/csrc/cpu/da8w4_linear.cpp
Outdated
if (use_cpublas_checked) {
  return use_cpublas;
}
use_cpublas = at::native::cpublas::could_pack(at::kByte);
It requires AMX to make this check true, but setup.py only requires VNNI. I think we should align these.
Thanks for the comments. We don't need the AMX flag here because the code in Torchao does not depend on AMX; brgemm lives in torch core, not here. We do need AVX512_VNNI because we use AVX512_VNNI intrinsics explicitly in Torchao.
Another example is the INT8 SDPA implementation, where only the AVX512 flag is used because it only has AVX512 code in Torchao and also uses brgemm from torch core.
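For illustration, a minimal sketch of the kind of code that forces the AVX512_VNNI flag in Torchao's own sources; the helper name below is hypothetical and is not the PR's actual kernel:

#include <immintrin.h>

#if defined(CPU_CAPABILITY_AVX512_VNNI)
// Accumulate u8 x s8 dot products (4 bytes per int32 lane) into acc.
// _mm512_dpbusd_epi32 is an AVX512_VNNI instruction, so this translation
// unit must be compiled with the VNNI flag, while brgemm itself comes
// from at::native::cpublas in PyTorch core and needs no extra flag here.
static inline __m512i dot_u8s8_accum(__m512i acc, __m512i a_u8, __m512i b_s8) {
  return _mm512_dpbusd_epi32(acc, a_u8, b_s8);
}
#endif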
Maybe we can change the name of this function to da8w4_use_cpublas, if it means that packing the weight into the format for cpublas is needed?
Thanks for the comment. I have changed it to cpublas_can_pack.
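For reference, a minimal sketch of what the renamed cached check could look like, assuming at::native::cpublas::could_pack is declared in ATen/native/CPUBlas.h as used in the snippet above (the PR's actual implementation may cache the flag differently):

#include <ATen/native/CPUBlas.h>

// Query cpublas::could_pack once and reuse the cached result on later calls.
static bool cpublas_can_pack() {
  static bool can_pack = at::native::cpublas::could_pack(at::kByte);
  return can_pack;
}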
I have added more details in the description. Thanks.
}
#endif

#if defined(CPU_CAPABILITY_AVX512_VNNI)
Will it be an issue if a user builds the package on a platform with VNNI support but runs it on a legacy platform?
Thanks for the comment. I have added a runtime check.
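As an illustration of such a guard, here is a hedged sketch combining the compile-time flag with a runtime CPU check; the use of the cpuinfo library and the helper name are assumptions for illustration, not necessarily what the PR does:

#include <cpuinfo.h>

// Return true only if this binary was built with the VNNI path and the host
// CPU actually supports AVX512_VNNI, so a wheel built on a newer machine
// still falls back safely on a legacy platform.
static bool da8w4_has_avx512_vnni_at_runtime() {
#if defined(CPU_CAPABILITY_AVX512_VNNI)
  static bool ok = cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni();
  return ok;
#else
  return false;
#endif
}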
Hi @jerryzh168 Could you please review this PR? Thanks. It's changed a lot since your last review.
Summary
This PR enables DA8W4 on CPU. It defines Int8DynamicActInt4WeightCPULayout and its implementation, and adds da8w4_linear_prepack_cpu for weight packing and da8w4_linear_cpu for the A8W4 GEMM. The feature supports symmetric and asymmetric quantization of activation.
The ops and kernels won't be available unless USE_CPP_KERNELS=1 is set when building on Linux with an x86 CPU with AVX512. To get the best performance, one needs a CPU with AMX support.
Implementation details
Uses the at::cpublas brgemm utilities from PyTorch core if available.
Usage
Test plan