Parallelize bf16->f32 conversion for gemm(bf16:bf16->bf16) #147864


Closed
wants to merge 2 commits

Conversation

aditew01
Collaborator

@aditew01 aditew01 commented Feb 25, 2025

Improves performance of at::addmm / linear kernels when executed with dtype=bfloat16 and SBGEMM is available.
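
For context, a minimal sketch of the affected call (hypothetical shapes, not part of this PR): on SBGEMM-enabled builds the bf16 GEMM accumulates in fp32, and the result is then cast back to bf16, which is the loop this PR parallelizes.

    // Sketch only: bf16 addmm on CPU, i.e. gemm(bf16:bf16->bf16).
    #include <ATen/ATen.h>

    int main() {
      auto a    = at::randn({1024, 1024}, at::kBFloat16);
      auto b    = at::randn({1024, 1024}, at::kBFloat16);
      auto bias = at::randn({1024}, at::kBFloat16);
      // On SBGEMM builds, the fp32 result buffer is converted back to bf16 here.
      auto out  = at::addmm(bias, a, b);
      return 0;
    }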

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @malfet @snadampal @milpuz01

@aditew01 aditew01 added module: cpu (CPU specific problem, e.g. perf, algorithm), module: arm (Related to ARM architectures builds of PyTorch; includes Apple M1), and topic: not user facing labels Feb 25, 2025

pytorch-bot bot commented Feb 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147864

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 257b0df with merge base d0f08dc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@aditew01 aditew01 requested review from jgong5, malfet and peterbell10 and removed request for jgong5 and malfet February 25, 2025 17:04
    }
    at::parallel_for(0, c_size, 1, [&](int64_t begin, int64_t end) {
      for (const auto i : c10::irange(begin, end)) {
        *(c++) = c10::convert<at::BFloat16>(float_v[i]);
Collaborator

Would it even be faster if we do a vectorized type cast here?

Collaborator Author

I'd say it'd be faster. I'm looking at how to plug aten::vec into this.
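
For reference, a rough sketch of what plugging in at::vec could look like, assuming at::vec::convert_float_bfloat16 is usable here and reusing the float_v / c / c_size names from the snippet above (illustrative only, not this PR's implementation):

    // Sketch only: vectorized f32 -> bf16 cast with a scalar tail loop.
    using fVec = at::vec::Vectorized<float>;
    at::parallel_for(0, c_size, 1, [&](int64_t begin, int64_t end) {
      int64_t i = begin;
      for (; i + 2 * fVec::size() <= end; i += 2 * fVec::size()) {
        fVec f0 = fVec::loadu(&float_v[i]);
        fVec f1 = fVec::loadu(&float_v[i + fVec::size()]);
        // Pack two float vectors into one bf16 vector and store it.
        auto bf = at::vec::convert_float_bfloat16(f0, f1);
        bf.store(&c[i]);
      }
      for (; i < end; ++i) {  // scalar tail for the remainder
        c[i] = c10::convert<at::BFloat16>(float_v[i]);
      }
    });

Indexing with c[i] instead of advancing a shared *(c++) pointer would also keep each parallel chunk independent.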

Collaborator Author

@jgong5 please take a look.

@aditew01 aditew01 requested a review from jgong5 February 28, 2025 17:10
    *(c++) = c10::convert<at::BFloat16>(float_v[i]);
    int64_t i = begin;
    // Vectorized loop
    for (; i + c_size <= end; i += c_size) {
Collaborator

This doesn't make sense; it will only ever take at most one trip, since c_size is the upper bound of the loop.

Suggested change:
-   for (; i + c_size <= end; i += c_size) {
+   for (; i + c_size <= end; i += Vectorized<float>::size()) {

    for (auto cv : float_v) {
      *(c++) = c10::convert<at::BFloat16>(cv);
    }
    at::parallel_for(0, c_size, 1, [&](int64_t begin, int64_t end) {
Collaborator

Usually the grain size would be at::internal::GRAIN_SIZE, which avoids introducing threading overhead for very small tensors.
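
For illustration, a sketch of the same call with the default grain size (names as in the snippet above):

    // Sketch: at::internal::GRAIN_SIZE keeps small conversions single-threaded.
    at::parallel_for(0, c_size, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
      for (const auto i : c10::irange(begin, end)) {
        c[i] = c10::convert<at::BFloat16>(float_v[i]);
      }
    });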

    int64_t i = begin;
    // Vectorized loop
    for (; i + c_size <= end; i += c_size) {
      auto a_vec = at::vec::Vectorized<float>::loadu(&float_v[i]); // load vec_size floats
Collaborator

Using Vectorized outside of the ATen/native/cpu/ directory will only use SSE. You would need to have a CPU kernel behind a DispatchStub to get AVX2 or AVX512 support.
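
A rough sketch of that DispatchStub pattern, with hypothetical names (the actual header locations and signature would need to follow how other stubs are wired up in ATen/native/DispatchStub.h):

    // In a shared header, inside namespace at::native (hypothetical names):
    using bf16_cast_fn = void (*)(const float* src, at::BFloat16* dst, int64_t n);
    DECLARE_DISPATCH(bf16_cast_fn, bf16_cast_stub);

    // At the call site, outside ATen/native/cpu/:
    DEFINE_DISPATCH(bf16_cast_stub);
    // bf16_cast_stub(kCPU, float_v, c, c_size);

    // In ATen/native/cpu/*.cpp, compiled once per ISA, so AVX2/AVX512 variants exist:
    static void bf16_cast_kernel(const float* src, at::BFloat16* dst, int64_t n) {
      // vectorized conversion using at::vec::Vectorized<float>, as sketched above
    }
    REGISTER_DISPATCH(bf16_cast_stub, &bf16_cast_kernel);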

@janeyx99 janeyx99 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Mar 3, 2025
@aditew01
Collaborator Author

Closing in favour of this: OpenMathLib/OpenBLAS#5155

@aditew01 aditew01 closed this Mar 17, 2025