
Add activation sparsity (24 + fp8 dynamic quant) subclass #2213

Open

jcaip wants to merge 27 commits into main

Conversation

@jcaip (Contributor) commented May 15, 2025

Summary:

This PR adds the following:

  • A new subclass + config for activation sparsity, Float8DynamicSemiSparseActivationFloat8WeightConfig, that works with LinearActivationQuantizedTensor (see the usage sketch after this list).
  • Modifies CutlassSemiSparseLayout to support both weight and activation sparsity, and adds a new impl + check.
  • A kernel to accelerate $$xW^T$$ when $x$ is sparse and we are memory bound, adapted from https://github.com/FasterDecoding/TEAL. The idea is that we can avoid loading the columns of $W$ that correspond to the zero elements of $x$, which lets us accelerate activation sparsity for bs=1 decode use cases (see the numerical sketch after this list).
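
As a usage sketch only: this assumes the new config is exposed from torchao.quantization and applied through the existing quantize_ API with default constructor arguments; neither the import path nor the defaults are taken from this PR.

```python
# Sketch only: the import path and default constructor arguments for the new
# config are assumptions, not taken from this PR.
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicSemiSparseActivationFloat8WeightConfig

model = nn.Sequential(nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# Dynamically quantize activations to fp8 with 2:4 activation sparsity and
# quantize weights to fp8 for every nn.Linear in the model.
quantize_(model, Float8DynamicSemiSparseActivationFloat8WeightConfig())

with torch.no_grad():
    # bs=1 decode-style input, which is where the sparse decode kernel is targeted
    out = model(torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda"))
```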

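To make the selective-loading idea concrete, here is a small plain-PyTorch sketch (float64, no fp8 or 2:4 packing, illustrative shapes and sparsity only) of the identity the kernel exploits: the zero entries of $x$ contribute nothing to $xW^T$, so only the columns of $W$ selected by the nonzero positions of $x$ need to be read.

```python
# Illustrative only: demonstrates the algebraic identity behind the decode kernel,
# not the fp8 / 2:4-packed implementation in this PR.
import torch

torch.manual_seed(0)
x = torch.randn(1, 4096, dtype=torch.float64)      # bs=1 activation
W = torch.randn(11008, 4096, dtype=torch.float64)  # weight, (out_features, in_features)

# Zero out ~90% of the activation entries by magnitude (TEAL-style thresholding).
x = x.masked_fill(x.abs() < x.abs().quantile(0.9), 0.0)

dense_out = x @ W.T                     # reference: reads every column of W

nz = x[0].nonzero(as_tuple=True)[0]     # indices of the nonzero activation elements
sparse_out = x[:, nz] @ W[:, nz].T      # reads only the needed columns of W

torch.testing.assert_close(dense_out, sparse_out)
```

Presumably the fused fp8 kernel performs this column selection while loading $W$, rather than materializing the gathered tensors, which is what makes it a memory-bandwidth win at bs=1.
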
Test Plan:
pytest test/sparsity/test_activation.py


pytorch-bot bot commented May 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2213

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures

As of commit 61aedfd with merge base 1017c7e:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label May 15, 2025
@jcaip added the topic: new feature label May 22, 2025
@jcaip changed the title from "Add selective weight loading decode kernel for activation sparsity" to "Add activation sparsity (24 + dynamic quant) subclass" May 30, 2025
@jcaip changed the title from "Add activation sparsity (24 + dynamic quant) subclass" to "Add activation sparsity (24 + fp8 dynamic quant) subclass" May 30, 2025
)

aten = torch.ops.aten


def _pad_dense_input(dense_input: torch.Tensor) -> torch.Tensor:
@jerryzh168 (Contributor) commented May 30, 2025


I think ideally we don't do padding at this level, since it interferes with aliasing semantics and makes some op implementations impossible, like slice. Is it possible to move this into the kernel itself, or to improve the kernel so it can handle the inputs without padding?
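
For context, a rough sketch of what a padding helper like this typically does; the multiples-of-8 requirement below is an assumption for illustration, not taken from this PR. Because it allocates and copies into a new tensor, the result no longer aliases the caller's storage, which is the aliasing concern raised above.

```python
# Hypothetical sketch of a padding helper; the row/column multiples are assumed,
# not taken from this PR.
import torch
import torch.nn.functional as F


def _pad_dense_input_sketch(dense_input: torch.Tensor,
                            row_multiple: int = 8,
                            col_multiple: int = 8) -> torch.Tensor:
    """Zero-pad a 2D activation so both dims are multiples the sparse kernel expects."""
    rows, cols = dense_input.shape
    pad_rows = (-rows) % row_multiple
    pad_cols = (-cols) % col_multiple
    # F.pad pads the last dim first: (left, right, top, bottom) for a 2D tensor.
    return F.pad(dense_input, (0, pad_cols, 0, pad_rows))


# e.g. _pad_dense_input_sketch(torch.randn(10, 100)).shape == (16, 104)
```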

@@ -66,11 +97,12 @@ def __new__(
)
kwargs["dtype"] = sparse.dtype
kwargs["requires_grad"] = False
shape = (sparse.shape[0], 2 * sparse.shape[-1])
# shape = (sparse.shape[0], 2 * sparse.shape[-1])

remove?

@@ -80,6 +112,7 @@ def __init__(
self.meta = meta
self.scale = scale
self._layout = _layout
self._shape = shape

Why do we need a separate shape, instead of deriving it from the sparse tensor?


nit: also, the sparse naming might be a bit vague; renaming it to sparse_data or something similar might be better.

from torchao.dtypes.floatx import Float8Layout

res = (
isinstance(input_tensor, AffineQuantizedTensor)

Just FYI, I feel we could simplify this list of checks if we don't need to support a variety of kernels that all use the same subclassed tensor; it only needs to be specific enough for us to dispatch to the correct kernel. Maybe this can be simplified if we split the AQT in the future.
