
Commit 7a935a0

[tests] Unify compilation + offloading tests in quantization (#11910)
* unify the quant compile + offloading tests.
* fix
* update
1 parent 941b7fc commit 7a935a0
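
In short, this commit moves the shared test_torch_compile* methods into the QuantCompileTests base class and stops that base class from inheriting unittest.TestCase, so the shared tests are only collected through the backend-specific subclasses that mix unittest.TestCase back in. A minimal, self-contained sketch of that pattern (illustrative names only, not the actual diffusers test code):

# Minimal sketch of the pattern this commit adopts (illustrative only; the real
# shared tests live in tests/quantization/test_torch_compile_utils.py).
import unittest


class QuantCompileTestsSketch:
    # Plain class, not a unittest.TestCase, so the runner never collects it directly.
    @property
    def quantization_config(self):
        raise NotImplementedError("Backend subclasses must provide a quantization config.")

    def test_torch_compile(self):
        # The real version builds a pipeline from self.quantization_config,
        # compiles the transformer with fullgraph=True, and runs a tiny inference.
        assert self.quantization_config is not None


class DummyBackendCompileTests(QuantCompileTestsSketch, unittest.TestCase):
    # Mixing unittest.TestCase back in is what makes the shared tests runnable here.
    @property
    def quantization_config(self):
        return {"quant_backend": "dummy"}  # hypothetical stand-in config


if __name__ == "__main__":
    unittest.main()

The per-backend files below then only override what differs (dtype, dynamo flags, skips/xfails) and delegate to the shared helpers via super().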

File tree

5 files changed (+42 -56 lines):

tests/quantization/bnb/test_4bit.py
tests/quantization/bnb/test_mixed_int8.py
tests/quantization/gguf/test_gguf.py
tests/quantization/test_torch_compile_utils.py
tests/quantization/torchao/test_torchao.py


tests/quantization/bnb/test_4bit.py

Lines changed: 4 additions & 9 deletions
@@ -873,11 +873,11 @@ def test_fp4_double_safe(self):
 
 @require_torch_version_greater("2.7.1")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb4BitCompileTests(QuantCompileTests):
+class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
+            quant_backend="bitsandbytes_4bit",
             quant_kwargs={
                 "load_in_4bit": True,
                 "bnb_4bit_quant_type": "nf4",
@@ -888,12 +888,7 @@ def quantization_config(self):
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile()
 
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=True)

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 4 additions & 8 deletions
@@ -838,7 +838,7 @@ def test_serialization_sharded(self):
 
 @require_torch_version_greater_equal("2.6.0")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb8BitCompileTests(QuantCompileTests):
+class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -849,15 +849,11 @@ def quantization_config(self):
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+        super()._test_torch_compile(torch_dtype=torch.float16)
 
     def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
-        )
+        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
 
     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)

tests/quantization/gguf/test_gguf.py

Lines changed: 1 addition & 10 deletions
@@ -654,23 +654,14 @@ def get_dummy_inputs(self):
 
 
 @require_torch_version_greater("2.7.1")
-class GGUFCompileTests(QuantCompileTests):
+class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
     torch_dtype = torch.bfloat16
     gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
 
     @property
     def quantization_config(self):
         return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
 
-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
-
     def _init_pipeline(self, *args, **kwargs):
         transformer = FluxTransformer2DModel.from_single_file(
             self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype

tests/quantization/test_torch_compile_utils.py

Lines changed: 28 additions & 21 deletions
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
-import unittest
+import inspect
 
 import torch
 
@@ -23,7 +23,7 @@
 
 @require_torch_gpu
 @slow
-class QuantCompileTests(unittest.TestCase):
+class QuantCompileTests:
     @property
     def quantization_config(self):
         raise NotImplementedError(
@@ -50,30 +50,26 @@ def _init_pipeline(self, quantization_config, torch_dtype):
         )
         return pipe
 
-    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
-        # import to ensure fullgraph True
+    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)
 
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
         pipe.transformer.compile()
 
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_group_offload_leaf(
-        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
-    ):
-        torch._dynamo.config.cache_size_limit = 10000
+    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+        torch._dynamo.config.cache_size_limit = 1000
 
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
@@ -87,6 +83,17 @@ def _test_torch_compile_with_group_offload_leaf(
             if torch.device(component.device).type == "cpu":
                 component.to("cuda")
 
-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+    def test_torch_compile(self):
+        self._test_torch_compile()
+
+    def test_torch_compile_with_cpu_offload(self):
+        self._test_torch_compile_with_cpu_offload()
+
+    def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+        for cls in inspect.getmro(self.__class__):
+            if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+                return
+        self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)

tests/quantization/torchao/test_torchao.py

Lines changed: 5 additions & 8 deletions
@@ -630,7 +630,7 @@ def test_int_a16w8_cpu(self):
 
 
 @require_torchao_version_greater_or_equal("0.7.0")
-class TorchAoCompileTest(QuantCompileTests):
+class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -639,17 +639,15 @@ def quantization_config(self):
             },
         )
 
-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
     @unittest.skip(
         "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
         "when compiling."
     )
     def test_torch_compile_with_cpu_offload(self):
         # RuntimeError: _apply(): Couldn't swap Linear.weight
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile_with_cpu_offload()
 
+    @parameterized.expand([False, True])
     @unittest.skip(
         """
         For `use_stream=False`:
@@ -659,8 +657,7 @@ def test_torch_compile_with_cpu_offload(self):
         Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
         """
     )
-    @parameterized.expand([False, True])
-    def test_torch_compile_with_group_offload_leaf(self):
+    def test_torch_compile_with_group_offload_leaf(self, use_stream):
         # For use_stream=False:
         # If we run group offloading without compilation, we will see:
         # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
@@ -673,7 +670,7 @@ def test_torch_compile_with_group_offload_leaf(self):
 
         # For use_stream=True:
         # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), arg_types=(<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>,), kwarg_types={}
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
 
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
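
For the TorchAo case, parameterized.expand([False, True]) now sits on the overriding method and feeds each value in as use_stream, while the unittest.skip below it means both generated variants end up skipped. A minimal illustration of that decorator stacking (assumed usage; names are illustrative, not diffusers code):

# Minimal illustration (assumed usage) of stacking parameterized.expand over
# unittest.skip on a test method: one test is generated per parameter value,
# and each generated variant still reports as skipped.
import unittest

from parameterized import parameterized


class OffloadDecoratorSketch(unittest.TestCase):
    @parameterized.expand([False, True])
    @unittest.skip("group offloading of quantized tensors is not supported here")
    def test_group_offload_leaf(self, use_stream):
        # never reached: unittest.skip replaces the body with a skip wrapper
        self.fail(f"should not run (use_stream={use_stream})")


if __name__ == "__main__":
    unittest.main(verbosity=2)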
