I'm trying to quantize RTDETRv2's weights to fp8 to try out torchao with the following script:
import torch
from transformers import RTDetrV2ForObjectDetection
from torchao.quantization import quantize_, Float8WeightOnlyConfig
from torchao.float8.config import e4m3_dtype

resnet_backbone_variant = "r50vd"
pretrained_model = RTDetrV2ForObjectDetection.from_pretrained(
    f"PekingU/rtdetr_v2_{resnet_backbone_variant}",
    num_labels=80,
    ignore_mismatched_sizes=False
).eval().to('cuda')

def _is_linear(m, fqn):
    return isinstance(m, torch.nn.Linear)

cfg = Float8WeightOnlyConfig(weight_dtype=e4m3_dtype)
quantize_(pretrained_model, cfg, _is_linear)

When running the script I get the following error inside torchao:
Traceback (most recent call last):
File "project_root/object-detect/src/runnables/debug.py", line 18, in <module>
quantize_(pretrained_model, cfg, _is_linear)
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 544, in quantize_
_replace_with_custom_fn_if_matches_filter(
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 220, in _replace_with_custom_fn_if_matches_filter
new_child = _replace_with_custom_fn_if_matches_filter(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 220, in _replace_with_custom_fn_if_matches_filter
new_child = _replace_with_custom_fn_if_matches_filter(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 215, in _replace_with_custom_fn_if_matches_filter
model = replacement_fn(model, *extra_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 1690, in _float8_weight_only_transform
new_weight = _float8_weight_only_quant_tensor(module.weight, config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 1673, in _float8_weight_only_quant_tensor
new_weight = Float8Tensor.from_hp(
^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quantize_/workflows/float8/float8_tensor.py", line 226, in from_hp
scale = _choose_scale_float8(
^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/utils.py", line 658, in _dispatch__torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_primitives.py", line 2313, in _choose_scale_float8
tensor_reshaped = tensor.view(shape_for_reduction)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/utils.py", line 674, in _dispatch__torch_dispatch__
return cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/utils.py", line 491, in wrapper
return func(f, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_root/lib/python3.12/site-packages/torchao/quantization/quantize_/workflows/float8/float8_tensor.py", line 580, in _
scale_shape.append(qdata.shape[i] // self.block_size[i])
~~~~~~~~~~~^^^
IndexError: tuple index out of range
The error happened on the class_embed submodule. Interestingly, the weights of the model's linear layers are already Float8Tensor instances by the time they are passed to the from_hp method, so the error may occur because the recursive traversal of the module structure somehow visits the same module twice.
The previous root module, which contains the backbone and the encoder-decoder, was quantized successfully.
My torchao version is 0.14.1.
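One way to test the "visited twice" hypothesis is to wrap the filter so that each module object matches at most once. The sketch below is only an illustration under that assumption, not a confirmed fix: `make_visit_once_filter` is a hypothetical helper, and the module names in the demo are made up. The demo uses plain Python stand-ins instead of real `torch.nn.Linear` modules so that it runs without a GPU.

```python
from typing import Callable, Set


def make_visit_once_filter(
    base_filter: Callable[[object, str], bool],
) -> Callable[[object, str], bool]:
    """Wrap a (module, fqn) filter so each module object matches at most once."""
    # ids of modules that already matched; the modules stay alive inside the
    # model during traversal, so their ids are not reused in the meantime.
    seen_ids: Set[int] = set()

    def filter_fn(module: object, fqn: str) -> bool:
        if not base_filter(module, fqn):
            return False
        if id(module) in seen_ids:
            # Same object reached again under a different name: skip it.
            return False
        seen_ids.add(id(module))
        return True

    return filter_fn


# Stand-in demo: one object reachable under two (hypothetical) names.
class FakeLinear:
    pass


shared = FakeLinear()
visits = [
    (shared, "model.decoder.class_embed.0"),  # first visit
    (shared, "class_embed.0"),                # same object, second path
    (FakeLinear(), "some.other.layer"),       # a different object
]

once = make_visit_once_filter(lambda m, fqn: isinstance(m, FakeLinear))
results = [once(m, fqn) for m, fqn in visits]
print(results)
```

With the script above, the equivalent call would be `quantize_(pretrained_model, cfg, make_visit_once_filter(_is_linear))`. If the error disappears, that would support the idea that class_embed is shared between two parents and gets quantized a second time.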