diff --git a/test/dtypes/test_fbgemm_quantized.py b/test/dtypes/test_fbgemm_quantized.py
new file mode 100644
index 0000000000..fe2573530c
--- /dev/null
+++ b/test/dtypes/test_fbgemm_quantized.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+from torchao.quantization import (
+    FbgemmConfig,
+    quantize_,
+)
+from torchao.quantization.utils import compute_error
+from torchao.utils import is_sm_at_least_90
+
+
+class TestFbgemmInt4Tensor(TestCase):
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    def test_linear(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        original = linear(input)
+        config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=(1, 128),
+        )
+        quantize_(linear, config)
+        quantized = linear(input)
+        self.assertTrue(compute_error(original, quantized) > 20)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/dtypes/test_fbgemm_quantized_tensor.py b/test/dtypes/test_fbgemm_quantized_tensor.py
new file mode 100644
index 0000000000..51b68dd977
--- /dev/null
+++ b/test/dtypes/test_fbgemm_quantized_tensor.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+from torchao.quantization import (
+    FbgemmConfig,
+    quantize_,
+)
+from torchao.quantization.utils import compute_error
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_90,
+)
+
+
+class TestFbgemmInt4Tensor(TestCase):
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch >= 2.6")
+    def test_linear(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        original = linear(input)
+        config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, 128],
+        )
+        quantize_(linear, config)
+        quantized = linear(input)
+        self.assertTrue(compute_error(original, quantized) > 20)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/quantization/test_config_serialization.py b/test/quantization/test_config_serialization.py
index 3b0a10e915..71cf8e144d 100644
--- a/test/quantization/test_config_serialization.py
+++ b/test/quantization/test_config_serialization.py
@@ -20,6 +20,7 @@
     config_to_dict,
 )
 from torchao.quantization.quant_api import (
+    FbgemmConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
     FPXWeightOnlyConfig,
@@ -34,11 +35,13 @@
     UIntXWeightOnlyConfig,
 )
 from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 # Define test configurations as fixtures
 configs = [
     Float8DynamicActivationFloat8WeightConfig(),
     Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
+    Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()]),
     Float8WeightOnlyConfig(
         weight_dtype=torch.float8_e4m3fn,
     ),
@@ -78,6 +81,9 @@
     ),
 ]

+if TORCH_VERSION_AT_LEAST_2_6:
+    configs += [FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 1, 256])]
+

 # Create ids for better test naming
 def get_config_ids(configs):
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index dc03204b46..c17de52028 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -439,6 +439,17 @@ def ffn_or_attn_only(mod, fqn):
                     f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
                 )
             quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+        elif "fbgemm" in quantization:
+            from torchao.quantization import FbgemmConfig
+
+            _, precision, group_size = quantization.split("-")
+            group_size = int(group_size)
+            if precision == "int4":
+                quantize_(model, FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, group_size]))
+            else:
+                raise NotImplementedError(
+                    f"FbgemmConfig({precision=}) not supported yet"
+                )
         elif "int4dq-" in quantization:
             from torchao.dtypes import CutlassInt4PackedLayout

@@ -1163,7 +1174,7 @@ def callback(x):
         help=(
             "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
             + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
-            + "embed-int8wo, marlin_qqq, gemlite-<packing_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>"
+            + "embed-int8wo, marlin_qqq, gemlite-<packing_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>, fbgemm-int4-<group_size>"
         ),
     )
     parser.add_argument(
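For context, the new `fbgemm-int4-<group_size>` option above boils down to a single `quantize_` call. A minimal eager-mode sketch (assumptions, not part of the patch: a CUDA sm90+ device with fbgemm-gpu-genai >= 1.2.0 and torch >= 2.6 installed; the `nn.Linear` stands in for a full model and group size 128 is only an illustrative choice):

```python
import torch

from torchao.quantization import FbgemmConfig, quantize_

# bf16 activations, int4 weights, bf16 outputs; the last block_size entry is the
# per-group quantization granularity along the weight's innermost dimension.
model = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(model, FbgemmConfig(torch.bfloat16, torch.int4, torch.bfloat16, [1, 128]))
out = model(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))
```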
diff --git a/torchao/core/config.py b/torchao/core/config.py
index d2d49981c9..3451b90c59 100644
--- a/torchao/core/config.py
+++ b/torchao/core/config.py
@@ -132,6 +132,12 @@ def default(self, o):
         if isinstance(o, list):
             return [self.encode_value(item) for item in o]

+        elif isinstance(o, tuple):
+            raise NotImplementedError(
+                "Tuples will be serialized as lists in JSON, so we recommend using "
+                f"lists instead to avoid surprises. Got: {o}"
+            )
+
         if isinstance(o, dict):
             return {k: self.encode_value(v) for k, v in o.items()}

@@ -250,13 +256,18 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
             # Recursively handle nested configs
             processed_data[key] = config_from_dict(value)
         elif isinstance(value, list):
-            # Handle lists of possible configs
+            # Handle lists of possible configs (tuples are rejected below)
             processed_data[key] = [
                 config_from_dict(item)
                 if isinstance(item, dict) and "_type" in item and "_data" in item
                 else item
                 for item in value
             ]
+        elif isinstance(value, tuple):
+            raise NotImplementedError(
+                "Tuples will be serialized as lists in JSON, so we recommend using "
+                f"lists instead to avoid surprises. Got: {value}"
+            )
         elif isinstance(value, dict):
             # Handle dicts of possible configs
             processed_data[key] = {
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index eb253c11bc..1003491828 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -8,6 +8,7 @@
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
+from .fbgemm_quantized_tensor import to_fbgemm_quantized
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -61,4 +62,5 @@
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
+    "to_fbgemm_quantized",
 ]
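The tuple rejection added to torchao/core/config.py above exists because JSON has no tuple type, so a tuple-valued config field would silently come back as a list after a round trip. A minimal standard-library illustration (not part of the patch):

```python
import json

# A tuple survives encoding only as a JSON array, so it reloads as a list;
# type- or equality-based checks on the reloaded config would then misfire.
original = {"block_size": (1, 128)}
reloaded = json.loads(json.dumps(original))
print(reloaded["block_size"])        # [1, 128]
print(type(reloaded["block_size"]))  # <class 'list'>
```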
diff --git a/torchao/dtypes/fbgemm_quantized_tensor.py b/torchao/dtypes/fbgemm_quantized_tensor.py
new file mode 100644
index 0000000000..fd788a73a3
--- /dev/null
+++ b/torchao/dtypes/fbgemm_quantized_tensor.py
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import importlib.util
+from typing import List
+
+import torch
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+from torchao.utils import TorchAOBaseTensor
+
+__all__ = [
+    "to_fbgemm_quantized",
+]
+
+aten = torch.ops.aten
+
+
+if importlib.util.find_spec("fbgemm_gpu") is None:
+    int4_row_quantize_zp = None
+    pack_int4 = None
+else:
+    from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4
+
+
+class FbgemmInt4Tensor(TorchAOBaseTensor):
+    tensor_data_attrs = ["packed_weight", "scale", "zero_point"]
+    tensor_attributes = ["group_size"]
+
+    def __new__(cls, packed_weight, scale, zero_point, group_size):
+        shape = packed_weight.shape
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["dtype"] = scale.dtype
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, packed_weight, scale, zero_point, group_size):
+        self.packed_weight = packed_weight
+        self.scale = scale
+        self.zero_point = zero_point
+        self.group_size = group_size
+
+    def __tensor_flatten__(self):
+        return self.tensor_data_attrs, [
+            getattr(self, attr) for attr in self.tensor_attributes
+        ]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        return cls(
+            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
+            *tensor_attributes,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
+            *[getattr(self, attr) for attr in self.tensor_attributes],
+        )
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, "
+            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+        )
+
+    @classmethod
+    def from_float(
+        cls,
+        w: torch.Tensor,
+        input_dtype: torch.dtype,
+        weight_dtype: torch.dtype,
+        output_dtype: torch.dtype,
+        block_size: List[int],
+    ):
+        assert len(block_size) == w.ndim, (
+            f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
+        )
+        group_size = block_size[-1]
+
+        assert (input_dtype, weight_dtype, output_dtype) == (
+            torch.bfloat16,
+            torch.int4,
+            torch.bfloat16,
+        )
+
+        if w.ndim >= 3:
+            wq, scale, zero_point = zip(
+                *[int4_row_quantize_zp(i, group_size) for i in w], strict=False
+            )
+            wq = torch.stack([pack_int4(i) for i in wq], dim=0)
+            scale = torch.stack(scale, dim=0)
+            zero_point = torch.stack(zero_point, dim=0)
+        else:
+            wq, scale, zero_point = int4_row_quantize_zp(w, group_size)
+            wq = pack_int4(wq)
+
+        scale = scale.to(w.dtype)
+        zero_point = zero_point.to(w.dtype)
+
+        del w
+        return FbgemmInt4Tensor(
+            packed_weight=wq,
+            scale=scale,
+            zero_point=zero_point,
+            group_size=group_size,
+        )
+
+
+implements = FbgemmInt4Tensor.implements
+
+
+@implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    if not input_tensor.is_floating_point():
+        raise NotImplementedError(
+            f"{func} is not implemented for non floating point input"
+        )
+
+    orig_act_size = input_tensor.size()
+    orig_out_features = weight_tensor.shape[-2]
+
+    res = torch.ops.fbgemm.bf16i4bf16_rowwise(
+        input_tensor,
+        weight_tensor.packed_weight,
+        weight_tensor.scale,
+        weight_tensor.zero_point,
+    )
+    if bias is not None:
+        res = res + bias
+    return res.reshape(*orig_act_size[:-1], orig_out_features)
+
+
+@implements([aten.detach.default, aten.alias.default])
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@implements([aten.clone.default, aten.copy_.default])
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+# We can have `to_fbgemm_tensor` to dispatch to different Fbgemm tensors later
+to_fbgemm_quantized = FbgemmInt4Tensor.from_float
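A minimal sketch of using the tensor-level entry point directly (assumptions, not part of the patch: a CUDA device with fbgemm-gpu-genai >= 1.2.0 installed; shapes and group size are illustrative). Because `FbgemmInt4Tensor` registers a handler for `torch.nn.functional.linear`, calling `F.linear` on the quantized weight routes to `torch.ops.fbgemm.bf16i4bf16_rowwise`:

```python
import torch
import torch.nn.functional as F

from torchao.dtypes import to_fbgemm_quantized

# Quantize a bf16 weight to packed int4; block_size's last entry (128) is the group size.
w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
wq = to_fbgemm_quantized(w, torch.bfloat16, torch.int4, torch.bfloat16, [1, 128])

x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
y = F.linear(x, wq)  # dispatched to the bf16i4bf16_rowwise kernel
print(y.shape)  # torch.Size([4, 256])
```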
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index b4d46d8263..73ccd2e0ff 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -40,6 +40,7 @@
 )
 from .quant_api import (
     CutlassInt4PackedLayout,
+    FbgemmConfig,
     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Float8DynamicActivationFloat8WeightConfig,
     Float8MMConfig,
@@ -148,6 +149,7 @@
     "FPXWeightOnlyConfig",
     "GemliteUIntXWeightOnlyConfig",
     "ModuleFqnToConfig",
+    "FbgemmConfig",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index f2aca97782..ada19859bc 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -15,11 +15,12 @@
 and mixed GEMM kernels
 """

+import importlib.util
 import logging
 import types
 import warnings
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -45,6 +46,7 @@
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
+    to_fbgemm_quantized,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -142,6 +144,7 @@
     "Int8DynActInt4WeightGPTQQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
+    "FbgemmConfig",
 ]

 LAYOUT_TO_ZERO_POINT_DOMAIN = {
@@ -1525,9 +1528,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):

     activation_dtype: torch.dtype = e4m3_dtype
     weight_dtype: torch.dtype = e4m3_dtype
-    granularity: Optional[
-        Union[FP8Granularity, Tuple[FP8Granularity, FP8Granularity]]
-    ] = None
+    granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True

@@ -1538,7 +1539,7 @@ def __post_init__(self):
         activation_granularity, weight_granularity = _normalize_granularity(
             self.granularity
         )
-        self.granularity = (activation_granularity, weight_granularity)
+        self.granularity = [activation_granularity, weight_granularity]


 # for bc
@@ -1967,6 +1968,59 @@ def _fpx_weight_only_transform(
     return module


+@dataclass
+class FbgemmConfig(AOBaseConfig):
+    """Quantization Config for fbgemm-genai kernels
+    Args:
+        input_dtype (torch.dtype): input dtype of the kernel
+        weight_dtype (torch.dtype): weight dtype of the kernel
+        output_dtype (torch.dtype): output dtype of the kernel
+        block_size (List[int]): block size for weight quantization, e.g. [1, group_size] for a 2D weight; the last element is the group size
+    """
+
+    input_dtype: torch.dtype
+    weight_dtype: torch.dtype
+    output_dtype: torch.dtype
+    block_size: List[int]
+
+
+@register_quantize_module_handler(FbgemmConfig)
+def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
+    # TODO: use is_package_at_least("fbgemm_gpu", "1.2.0") when
+    # https://github.com/pytorch/FBGEMM/issues/4198 is fixed
+    if importlib.util.find_spec("fbgemm_gpu") is None:
+        raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
+
+    import fbgemm_gpu.experimental.gen_ai  # noqa: F401
+
+    if fbgemm_gpu.__version__ < "1.2.0":
+        raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
+
+    _SUPPORTED_DTYPES = {
+        (torch.bfloat16, torch.int4, torch.bfloat16),
+    }
+
+    if (
+        config.input_dtype,
+        config.weight_dtype,
+        config.output_dtype,
+    ) in _SUPPORTED_DTYPES:
+        weight = to_fbgemm_quantized(
+            module.weight,
+            config.input_dtype,
+            config.weight_dtype,
+            config.output_dtype,
+            config.block_size,
+        )
+        module.weight = torch.nn.Parameter(weight, requires_grad=False)
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    else:
+        raise NotImplementedError(
+            f"{config} is not supported. Supported (input, weight, output) kernel dtypes are: {_SUPPORTED_DTYPES}"
+        )
+    return module
+
+
 @dataclass
 class ModuleFqnToConfig(AOBaseConfig):
     """Per module configurations for torchao quantize_ API
diff --git a/torchao/utils.py b/torchao/utils.py
index 280da4e632..1fa395cb8a 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import functools
+import importlib
 import itertools
 import re
 import time
@@ -40,6 +41,7 @@
     "is_MI300",
     "is_sm_at_least_89",
     "is_sm_at_least_90",
+    "is_package_at_least",
 ]


@@ -694,3 +696,11 @@ def check_xpu_version(device, version="2.8.0"):
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
 TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev")
+
+
+def is_package_at_least(package_name: str, min_version: str):
+    package_exists = importlib.util.find_spec(package_name) is not None
+    if not package_exists:
+        return False
+
+    return version(package_name) >= min_version
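One caveat on `is_package_at_least` above: it compares raw version strings, which is lexicographic, so for example "1.10.0" compares as smaller than "1.2.0". If that ever matters for the fbgemm-gpu-genai check, a PEP 440-aware variant is possible; a hedged sketch using the third-party `packaging` library (an alternative, not what the patch does):

```python
import importlib.util
from importlib.metadata import version

from packaging.version import Version


def is_package_at_least_parsed(package_name: str, min_version: str) -> bool:
    # Same structure as is_package_at_least above, but parse the versions so that
    # "1.10.0" correctly compares as newer than "1.2.0".
    if importlib.util.find_spec(package_name) is None:
        return False
    return Version(version(package_name)) >= Version(min_version)
```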