Skip to content

[tests] Unify compilation + offloading tests in quantization #11910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions tests/quantization/bnb/test_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,11 +873,11 @@ def test_fp4_double_safe(self):

@require_torch_version_greater("2.7.1")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb4BitCompileTests(QuantCompileTests):
class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_backend="bitsandbytes_8bit",
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
Expand All @@ -888,12 +888,7 @@ def quantization_config(self):

def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(quantization_config=self.quantization_config)

def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
super().test_torch_compile()

def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, use_stream=True
)
super()._test_torch_compile_with_group_offload_leaf(use_stream=True)
12 changes: 4 additions & 8 deletions tests/quantization/bnb/test_mixed_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ def test_serialization_sharded(self):

@require_torch_version_greater_equal("2.6.0")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb8BitCompileTests(QuantCompileTests):
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
Expand All @@ -849,15 +849,11 @@ def quantization_config(self):

def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
super()._test_torch_compile(torch_dtype=torch.float16)

def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(
quantization_config=self.quantization_config, torch_dtype=torch.float16
)
super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)

@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(
quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
)
super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
11 changes: 1 addition & 10 deletions tests/quantization/gguf/test_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,23 +654,14 @@ def get_dummy_inputs(self):


@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests):
class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
torch_dtype = torch.bfloat16
gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

@property
def quantization_config(self):
return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)

def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)

def test_torch_compile_with_cpu_offload(self):
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

def test_torch_compile_with_group_offload_leaf(self):
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)

def _init_pipeline(self, *args, **kwargs):
transformer = FluxTransformer2DModel.from_single_file(
self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
Expand Down
49 changes: 28 additions & 21 deletions tests/quantization/test_torch_compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import inspect

import torch

Expand All @@ -23,7 +23,7 @@

@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
class QuantCompileTests:
@property
def quantization_config(self):
raise NotImplementedError(
Expand All @@ -50,30 +50,26 @@ def _init_pipeline(self, quantization_config, torch_dtype):
)
return pipe

def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
# import to ensure fullgraph True
def _test_torch_compile(self, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
# `fullgraph=True` ensures no graph breaks
pipe.transformer.compile(fullgraph=True)

for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(quantization_config, torch_dtype)
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
pipe.enable_model_cpu_offload()
pipe.transformer.compile()

for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

def _test_torch_compile_with_group_offload_leaf(
self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
):
torch._dynamo.config.cache_size_limit = 10000
def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
torch._dynamo.config.cache_size_limit = 1000

pipe = self._init_pipeline(quantization_config, torch_dtype)
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
group_offload_kwargs = {
"onload_device": torch.device("cuda"),
"offload_device": torch.device("cpu"),
Expand All @@ -87,6 +83,17 @@ def _test_torch_compile_with_group_offload_leaf(
if torch.device(component.device).type == "cpu":
component.to("cuda")

for _ in range(2):
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

def test_torch_compile(self):
self._test_torch_compile()

def test_torch_compile_with_cpu_offload(self):
self._test_torch_compile_with_cpu_offload()

def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
for cls in inspect.getmro(self.__class__):
if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
return
self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
13 changes: 5 additions & 8 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def test_int_a16w8_cpu(self):


@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests):
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
Expand All @@ -639,17 +639,15 @@ def quantization_config(self):
},
)

def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)

@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
# RuntimeError: _apply(): Couldn't swap Linear.weight
super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
super().test_torch_compile_with_cpu_offload()

@parameterized.expand([False, True])
@unittest.skip(
"""
For `use_stream=False`:
Expand All @@ -659,8 +657,7 @@ def test_torch_compile_with_cpu_offload(self):
Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
"""
)
@parameterized.expand([False, True])
def test_torch_compile_with_group_offload_leaf(self):
def test_torch_compile_with_group_offload_leaf(self, use_stream):
# For use_stream=False:
# If we run group offloading without compilation, we will see:
# RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
Expand All @@ -673,7 +670,7 @@ def test_torch_compile_with_group_offload_leaf(self):

# For use_stream=True:
# NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
Expand Down