
Add support for fbgemm int4 mm kernel #2255

Merged 12 commits on May 28, 2025
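For orientation, a minimal usage sketch of what this PR adds, mirroring the new tests below (a sketch, not part of the PR; assumes a CUDA device with sm90+ and fbgemm_gpu installed):

import torch
from torchao.quantization import FbgemmConfig, quantize_

# bf16 activations, int4 weights, bf16 outputs, group size 128 along the last weight dim
config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
)
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, config)
# The weight is now an FbgemmInt4Tensor; F.linear dispatches to
# torch.ops.fbgemm.bf16i4bf16_rowwise (see torchao/dtypes/fbgemm_quantized_tensor.py below).
output = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))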
44 changes: 44 additions & 0 deletions test/dtypes/test_fbgemm_quantized.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.quantization import (
FbgemmConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import is_sm_at_least_90


class TestFbgemmInt4Tensor(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
def test_linear(self):
dtype = torch.bfloat16
device = "cuda"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=(1, 128),
)
quantize_(linear, config)
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)


if __name__ == "__main__":
run_tests()
48 changes: 48 additions & 0 deletions test/dtypes/test_fbgemm_quantized_tensor.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao.quantization import (
FbgemmConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_90,
)


class TestFbgemmInt4Tensor(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch >= 2.6")
def test_linear(self):
dtype = torch.bfloat16
device = "cuda"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 128],
)
quantize_(linear, config)
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)


if __name__ == "__main__":
run_tests()
6 changes: 6 additions & 0 deletions test/quantization/test_config_serialization.py
@@ -20,6 +20,7 @@
config_to_dict,
)
from torchao.quantization.quant_api import (
FbgemmConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
FPXWeightOnlyConfig,
@@ -34,11 +35,13 @@
UIntXWeightOnlyConfig,
)
from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

# Define test configurations as fixtures
configs = [
Float8DynamicActivationFloat8WeightConfig(),
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()]),
Float8WeightOnlyConfig(
weight_dtype=torch.float8_e4m3fn,
),
@@ -78,6 +81,9 @@
),
]

if TORCH_VERSION_AT_LEAST_2_6:
configs += [FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])]
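As a sketch of what the new entry is expected to satisfy, an FbgemmConfig should survive a JSON round trip (assuming config_to_dict/config_from_dict live in torchao.core.config as imported at the top of this test file, and that the field names match the keyword arguments used in the tests above):

import torch
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import FbgemmConfig

config = FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])
reconstructed = config_from_dict(config_to_dict(config))
assert isinstance(reconstructed, FbgemmConfig)
assert reconstructed.block_size == [1, 1, 256]  # stays a list across the round trip

Note that block_size is a list rather than a tuple, which is exactly what the new checks in torchao/core/config.py below enforce.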


# Create ids for better test naming
def get_config_ids(configs):
13 changes: 12 additions & 1 deletion torchao/_models/llama/generate.py
@@ -439,6 +439,17 @@ def ffn_or_attn_only(mod, fqn):
f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
)
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
elif "fbgemm" in quantization:
from torchao.quantization import FbgemmConfig

_, precision, group_size = quantization.split("-")
group_size = int(group_size)
if precision == "int4":
quantize_(model, FbgemmConfig("bf16i4bf16", group_size))
else:
raise NotImplementedError(
f"FbegemmConfig({precision=}) not supported yet"
)
elif "int4dq-" in quantization:
from torchao.dtypes import CutlassInt4PackedLayout

@@ -1163,7 +1174,7 @@ def callback(x):
help=(
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
+ "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
+ "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>"
+ "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>, fbgemm-int4-<group_size>"
),
)
parser.add_argument(
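For reference, a small sketch of how the new quantization flag value is parsed by the branch above; the concrete value fbgemm-int4-128 is a hypothetical example following the fbgemm-int4-<group_size> pattern documented in the help text:

# Hypothetical flag value in the "fbgemm-int4-<group_size>" form.
quantization = "fbgemm-int4-128"
_, precision, group_size = quantization.split("-")
assert precision == "int4"
assert int(group_size) == 128  # group_size is then handed to FbgemmConfig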
13 changes: 12 additions & 1 deletion torchao/core/config.py
@@ -132,6 +132,12 @@ def default(self, o):
if isinstance(o, list):
return [self.encode_value(item) for item in o]

elif isinstance(o, tuple):
raise NotImplementedError(
"Tuples will be serialized as List in JSON, so we recommend to use "
f"Lists instead to avoid surprises. got: {o}"
)

if isinstance(o, dict):
return {k: self.encode_value(v) for k, v in o.items()}

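For context on the tuple checks added here and below: Python's json module encodes tuples as JSON arrays, so a tuple-valued field would silently come back as a list after a round trip; the explicit error avoids that surprise. A standalone illustration:

import json

original = (1, 128)  # e.g. a tuple-valued block_size
round_tripped = json.loads(json.dumps(original))
print(type(round_tripped), round_tripped)  # <class 'list'> [1, 128]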
@@ -250,13 +256,18 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
# Recursively handle nested configs
processed_data[key] = config_from_dict(value)
elif isinstance(value, list):
# Handle lists of possible configs
# Handle lists or tuples of possible configs
processed_data[key] = [
config_from_dict(item)
if isinstance(item, dict) and "_type" in item and "_data" in item
else item
for item in value
]
elif isinstance(value, tuple):
raise NotImplementedError(
"Tuples will be serialized as List in JSON, so we recommend to use "
f"Lists instead to avoid surprises. got: {value}"
)
elif isinstance(value, dict):
# Handle dicts of possible configs
processed_data[key] = {
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -8,6 +8,7 @@
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .fbgemm_quantized_tensor import to_fbgemm_quantized
from .floatx import (
CutlassSemiSparseLayout,
Float8Layout,
@@ -61,4 +62,5 @@
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
"Int4XPULayout",
"to_fbgemm_quantized",
]
161 changes: 161 additions & 0 deletions torchao/dtypes/fbgemm_quantized_tensor.py
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


import importlib.util
from typing import List

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.utils import TorchAOBaseTensor

__all__ = [
"to_fbgemm_quantized",
]

aten = torch.ops.aten


if importlib.util.find_spec("fbgemm_gpu") is None:
int4_row_quantize_zp = None
pack_int4 = None
else:
from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4


class FbgemmInt4Tensor(TorchAOBaseTensor):
tensor_data_attrs = ["packed_weight", "scale", "zero_point"]
tensor_attributes = ["group_size"]

def __new__(cls, packed_weight, scale, zero_point, group_size):
shape = packed_weight.shape
kwargs = {}
kwargs["device"] = packed_weight.device
kwargs["dtype"] = scale.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, packed_weight, scale, zero_point, group_size):
self.packed_weight = packed_weight
self.scale = scale
self.zero_point = zero_point
self.group_size = group_size

def __tensor_flatten__(self):
return self.tensor_data_attrs, [
getattr(self, attr) for attr in self.tensor_attributes
]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
return cls(
*[tensor_data_dict[name] for name in cls.tensor_data_attrs],
*tensor_attributes,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
*[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
*[getattr(self, attr) for attr in self.tensor_attributes],
)

def __repr__(self):
return (
f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, "
f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
)

@classmethod
def from_float(
cls,
w: torch.Tensor,
input_dtype: torch.dtype,
weight_dtype: torch.dtype,
output_dtype: torch.dtype,
block_size: List[int],
):
assert len(block_size) == w.ndim, (
f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
)
group_size = block_size[-1]

assert (input_dtype, weight_dtype, output_dtype) == (
torch.bfloat16,
torch.int4,
torch.bfloat16,
)

if w.ndim >= 3:
wq, scale, zero_point = zip(
*[int4_row_quantize_zp(i, group_size) for i in w], strict=False
)
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
scale = torch.stack(scale, dim=0)
zero_point = torch.stack(zero_point, dim=0)
else:
wq, scale, zero_point = int4_row_quantize_zp(w, group_size)
wq = pack_int4(wq)

scale = scale.to(w.dtype)
zero_point = zero_point.to(w.dtype)

del w
return FbgemmInt4Tensor(
packed_weight=wq,
scale=scale,
zero_point=zero_point,
group_size=group_size,
)


implements = FbgemmInt4Tensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
if not input_tensor.is_floating_point():
raise NotImplementedError(
f"{func} is not implemented for non floating point input"
)

orig_act_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]

res = torch.ops.fbgemm.bf16i4bf16_rowwise(
input_tensor,
weight_tensor.packed_weight,
weight_tensor.scale,
weight_tensor.zero_point,
)
if bias is not None:
res = res + bias
return res.reshape(*orig_act_size[:-1], orig_out_features)


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements([aten.clone.default, aten.copy_.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


# We can have `to_fbgemm_tensor` to dispatch to different Fbgemm tensors later
to_fbgemm_quantized = FbgemmInt4Tensor.from_float
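A sketch of calling the conversion helper directly (in practice it is reached via quantize_ with FbgemmConfig; assumes fbgemm_gpu is installed and the tensors live on an sm90+ GPU):

import torch
from torchao.dtypes import to_fbgemm_quantized

w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")  # (out_features, in_features)
qw = to_fbgemm_quantized(
    w,
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],  # last entry is the group size passed to int4_row_quantize_zp
)
x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
y = torch.nn.functional.linear(x, qw)  # dispatches to torch.ops.fbgemm.bf16i4bf16_rowwise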
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -40,6 +40,7 @@
)
from .quant_api import (
CutlassInt4PackedLayout,
FbgemmConfig,
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8MMConfig,
@@ -148,6 +149,7 @@
"FPXWeightOnlyConfig",
"GemliteUIntXWeightOnlyConfig",
"ModuleFqnToConfig",
"FbgemmConfig",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",