Unable to quantize HF's RTDETRv2 #3363

Description

@Rafael-Cast

I'm trying to quantize RTDETRv2's weights to fp8 to try out torchao with the following script:

import torch
from transformers import RTDetrV2ForObjectDetection
from torchao.quantization import quantize_, Float8WeightOnlyConfig
from torchao.float8.config import e4m3_dtype

resnet_backbone_variant = "r50vd"

# Load the pretrained RT-DETRv2 checkpoint in eval mode on the GPU
pretrained_model = RTDetrV2ForObjectDetection.from_pretrained(
    f"PekingU/rtdetr_v2_{resnet_backbone_variant}",
    num_labels=80,
    ignore_mismatched_sizes=False
).eval().to('cuda')

# Quantize every nn.Linear module in the model
def _is_linear(m, fqn):
    return isinstance(m, torch.nn.Linear)

# Weight-only fp8 quantization using the e4m3 format
cfg = Float8WeightOnlyConfig(weight_dtype=e4m3_dtype)
quantize_(pretrained_model, cfg, _is_linear)

When running the script I get the following error inside torchao:

Traceback (most recent call last):
  File "project_root/object-detect/src/runnables/debug.py", line 18, in <module>
    quantize_(pretrained_model, cfg, _is_linear)
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 544, in quantize_
    _replace_with_custom_fn_if_matches_filter(
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 220, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 220, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 215, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model, *extra_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 1690, in _float8_weight_only_transform
    new_weight = _float8_weight_only_quant_tensor(module.weight, config)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_api.py", line 1673, in _float8_weight_only_quant_tensor
    new_weight = Float8Tensor.from_hp(
                 ^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quantize_/workflows/float8/float8_tensor.py", line 226, in from_hp
    scale = _choose_scale_float8(
            ^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/utils.py", line 658, in _dispatch__torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quant_primitives.py", line 2313, in _choose_scale_float8
    tensor_reshaped = tensor.view(shape_for_reduction)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/utils.py", line 674, in _dispatch__torch_dispatch__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/utils.py", line 491, in wrapper
    return func(f, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_root/lib/python3.12/site-packages/torchao/quantization/quantize_/workflows/float8/float8_tensor.py", line 580, in _
    scale_shape.append(qdata.shape[i] // self.block_size[i])
                       ~~~~~~~~~~~^^^
IndexError: tuple index out of range

The error happened on the class_embed submodule. Interestingly, by the time from_hp is called, the weights of the model's linear layers are already Float8Tensors, so the error may come from the recursive traversal of the module structure somehow visiting the same module twice. The submodule visited before it, which contains the backbone and the encoder-decoder, was quantized successfully.
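
If that hypothesis is right, a check along these lines should surface it (a sketch continuing from the script above, not part of my original repro; note that named_modules() deduplicates shared modules by default, so remove_duplicate=False is needed to see the second occurrence):

from collections import defaultdict

# List every Linear that is reachable under more than one fully-qualified
# name; any hit is a module the recursive replacement would visit twice.
fqns_by_linear = defaultdict(list)
for fqn, m in pretrained_model.named_modules(remove_duplicate=False):
    if isinstance(m, torch.nn.Linear):
        fqns_by_linear[id(m)].append(fqn)

for fqns in fqns_by_linear.values():
    if len(fqns) > 1:
        print("Linear shared across:", fqns)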

My torchao version is 0.14.1.
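
In case it helps anyone hitting the same traceback, the sketch below is the stopgap I would try (my own workaround, not a fix for quantize_ itself): make the filter match each Linear object only once, so a shared submodule is never handed to from_hp after it has already been quantized. Since both parents reference the same object, both end up with the quantized weight.

# Workaround sketch: remember which Linear objects were already matched so
# the recursive traversal quantizes each underlying module at most once.
# Run this instead of, not after, the original quantize_ call on a fresh model.
seen_ids = set()

def _is_linear_once(m, fqn):
    if not isinstance(m, torch.nn.Linear) or id(m) in seen_ids:
        return False
    seen_ids.add(id(m))
    return True

quantize_(pretrained_model, cfg, _is_linear_once)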
