Add activation sparsity (2:4 + fp8 dynamic quant) subclass #2213
base: main
Conversation
Summary: This PR adds in a kernel to accelerate $$xW^T$$ when $x$ is sparse and we are memory bound. The idea here is that we can avoid loading the columns of $W$ that correspond to the zero elements of $x$. This lets us accelerate activation sparsity for bs=1 decode use cases.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
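For illustration, here is a minimal dense PyTorch sketch of the math the kernel exploits (this is not the CUDA kernel itself, and the helper name is made up): gathering only the columns of $W$ that line up with the nonzero entries of $x$ gives the same result as the full matmul, so a memory-bound kernel never needs to load the rest of $W$.

```python
import torch

def sparse_activation_matmul_reference(x: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # Reference for y = x @ W.T when x (shape [1, k]) is mostly zero:
    # only the columns of W matching nonzero activations are ever read.
    nz = x[0].nonzero(as_tuple=True)[0]   # indices of nonzero activations
    return x[:, nz] @ W[:, nz].T          # only the needed columns of W are touched

# bs=1 decode-style input with mostly-zero activations
x = torch.zeros(1, 8)
x[0, [1, 5]] = torch.randn(2)
W = torch.randn(4, 8)
assert torch.allclose(sparse_activation_matmul_reference(x, W), x @ W.T, atol=1e-6)
```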
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2213
Note: Links to docs will display an error until the docs builds have been completed.
❌ 6 New Failures as of commit 61aedfd with merge base 1017c7e. NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
aten = torch.ops.aten

def _pad_dense_input(dense_input: torch.Tensor) -> torch.Tensor:
I think ideally we don't do padding at this level, since it interferes with aliasing semantics and makes some op implementations impossible, like slice. Is it possible to move this to the kernel itself, or to improve the kernel so it can handle the inputs without padding?
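For context, a padding helper of this kind typically looks like the hypothetical sketch below (alignment values and exact behavior are assumed, not taken from this PR); the fresh allocation it returns is exactly what breaks aliasing with the original storage.

```python
import torch
import torch.nn.functional as F

def _pad_dense_input_sketch(dense_input: torch.Tensor, row_align: int = 8, col_align: int = 8) -> torch.Tensor:
    # Hypothetical pad helper: zero-pad both dims up to the alignment the
    # sparse kernel expects. The result is a new tensor, which is why padding
    # at the subclass level interferes with aliasing semantics (e.g. slice can
    # no longer view into the original storage).
    rows, cols = dense_input.shape
    pad_rows = (-rows) % row_align
    pad_cols = (-cols) % col_align
    return F.pad(dense_input, (0, pad_cols, 0, pad_rows))
```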
@@ -66,11 +97,12 @@ def __new__(
        )
        kwargs["dtype"] = sparse.dtype
        kwargs["requires_grad"] = False
-       shape = (sparse.shape[0], 2 * sparse.shape[-1])
+       # shape = (sparse.shape[0], 2 * sparse.shape[-1])
remove?
@@ -80,6 +112,7 @@ def __init__(
        self.meta = meta
        self.scale = scale
        self._layout = _layout
+       self._shape = shape
Why do we need a separate shape, instead of deriving it from the sparse tensor?
nit: also, the `sparse` naming might be a bit vague; renaming it to `sparse_data` or something similar might be better.
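A toy sketch of both suggestions above (class and attribute names here are made up, not the PR's): keep only the packed 2:4 payload, call it `sparse_data`, and derive the dense shape on demand instead of storing a separate `_shape`.

```python
import torch

class _SemiSparsePackedSketch:
    # Toy stand-in for the subclass under review, not the actual implementation.
    def __init__(self, sparse_data: torch.Tensor):
        self.sparse_data = sparse_data  # packed 2:4 payload: half of the dense columns

    @property
    def shape(self) -> torch.Size:
        # 2:4 packing keeps 2 of every 4 values, so the dense last dim is 2x the packed one
        return torch.Size((self.sparse_data.shape[0], 2 * self.sparse_data.shape[-1]))

packed = torch.randn(16, 64)                   # pretend this is the 2:4 packed payload
print(_SemiSparsePackedSketch(packed).shape)   # torch.Size([16, 128])
```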
from torchao.dtypes.floatx import Float8Layout

res = (
    isinstance(input_tensor, AffineQuantizedTensor)
Just FYI, I feel we could simplify this list of checks if we don't need to support a variety of kernels that all use the same subclassed tensor; it only needs to be specific enough so we can dispatch to the correct kernel. Maybe this can be simplified if we split the AQT in the future.
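For reference, the kind of dispatch check being discussed has roughly the shape below (the exact condition list and helper name are assumed, not copied from this PR).

```python
def _is_fp8_sparse_linear_candidate(input_tensor, weight_tensor) -> bool:
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.dtypes.floatx import Float8Layout

    # Only needs to be specific enough to route fp8 activation x fp8 sparse
    # weight onto the new kernel; anything else falls through to the next
    # registered implementation.
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and isinstance(input_tensor._layout, Float8Layout)
        and isinstance(weight_tensor, AffineQuantizedTensor)
    )
```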
Summary:
This PR adds in the following things:
- A new config, `Float8DynamicSemiSparseActivationFloat8WeightConfig`, that works with `LinearActivationQuantizedTensor`.
- Updates `CutlassSemiSparseLayout` to support both weight and activation sparsity and adds a new impl + check.

Test Plan:
`pytest test/sparsity/test_activation.py`

Reviewers:
Subscribers:
Tasks:
Tags:
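A rough usage sketch (the import path and constructor defaults of the new config are assumed here, not taken from this PR):

```python
import torch
from torchao.quantization import quantize_

# Import path assumed; use wherever this PR ultimately exposes the config.
from torchao.quantization import Float8DynamicSemiSparseActivationFloat8WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)
quantize_(model, Float8DynamicSemiSparseActivationFloat8WeightConfig())

# bs=1 decode-style input; activation sparsity pays off when most entries are zero
x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16).relu()
y = model(x)
```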