diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 98005cfbc810..8e2a8515c662 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -873,11 +873,11 @@ def test_fp4_double_safe(self):

 @require_torch_version_greater("2.7.1")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb4BitCompileTests(QuantCompileTests):
+class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
-            quant_backend="bitsandbytes_8bit",
+            quant_backend="bitsandbytes_4bit",
             quant_kwargs={
                 "load_in_4bit": True,
                 "bnb_4bit_quant_type": "nf4",
@@ -888,12 +888,7 @@ def quantization_config(self):

     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile()

     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=True)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index f3bbc34e8b2c..64f56b02b0dd 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -838,7 +838,7 @@ def test_serialization_sharded(self):

 @require_torch_version_greater_equal("2.6.0")
 @require_bitsandbytes_version_greater("0.45.5")
-class Bnb8BitCompileTests(QuantCompileTests):
+class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -849,15 +849,11 @@ def quantization_config(self):

     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+        super()._test_torch_compile(torch_dtype=torch.float16)

     def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
-        )
+        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)

     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
     def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
-        )
+        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index fe56f890ee8c..ba41678eaa64 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -654,7 +654,7 @@ def get_dummy_inputs(self):


 @require_torch_version_greater("2.7.1")
-class GGUFCompileTests(QuantCompileTests):
+class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
     torch_dtype = torch.bfloat16
     gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

@@ -662,15 +662,6 @@ class GGUFCompileTests(QuantCompileTests):
     @property
     def quantization_config(self):
         return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)

-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_cpu_offload(self):
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
-
-    def test_torch_compile_with_group_offload_leaf(self):
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
-
     def _init_pipeline(self, *args, **kwargs):
         transformer = FluxTransformer2DModel.from_single_file(
             self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
index 99bb8980ef9f..cfe2339e2b56 100644
--- a/tests/quantization/test_torch_compile_utils.py
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
-import unittest
+import inspect

 import torch

@@ -23,7 +23,7 @@

 @require_torch_gpu
 @slow
-class QuantCompileTests(unittest.TestCase):
+class QuantCompileTests:
     @property
     def quantization_config(self):
         raise NotImplementedError(
@@ -50,30 +50,26 @@ def _init_pipeline(self, quantization_config, torch_dtype):
         )
         return pipe

-    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
-        # import to ensure fullgraph True
+    def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
+        # `fullgraph=True` ensures no graph breaks
         pipe.transformer.compile(fullgraph=True)

-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+    def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
         pipe.transformer.compile()

-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

-    def _test_torch_compile_with_group_offload_leaf(
-        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
-    ):
-        torch._dynamo.config.cache_size_limit = 10000
+    def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+        torch._dynamo.config.cache_size_limit = 1000

-        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         group_offload_kwargs = {
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
@@ -87,6 +83,17 @@ def _test_torch_compile_with_group_offload_leaf(
                 if torch.device(component.device).type == "cpu":
                     component.to("cuda")

-        for _ in range(2):
-            # small resolutions to ensure speedy execution.
-            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+        # small resolutions to ensure speedy execution.
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+    def test_torch_compile(self):
+        self._test_torch_compile()
+
+    def test_torch_compile_with_cpu_offload(self):
+        self._test_torch_compile_with_cpu_offload()
+
+    def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+        for cls in inspect.getmro(self.__class__):
+            if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+                return
+        self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index c4cfc8eb87fb..9d09fd2f1bab 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -630,7 +630,7 @@ def test_int_a16w8_cpu(self):


 @require_torchao_version_greater_or_equal("0.7.0")
-class TorchAoCompileTest(QuantCompileTests):
+class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
@@ -639,17 +639,15 @@ def quantization_config(self):
             },
         )

-    def test_torch_compile(self):
-        super()._test_torch_compile(quantization_config=self.quantization_config)
-
     @unittest.skip(
         "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
         "when compiling."
     )
     def test_torch_compile_with_cpu_offload(self):
         # RuntimeError: _apply(): Couldn't swap Linear.weight
-        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+        super().test_torch_compile_with_cpu_offload()

+    @parameterized.expand([False, True])
     @unittest.skip(
         """
         For `use_stream=False`:
@@ -659,8 +657,7 @@ def test_torch_compile_with_cpu_offload(self):
         Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
         """
     )
-    @parameterized.expand([False, True])
-    def test_torch_compile_with_group_offload_leaf(self):
+    def test_torch_compile_with_group_offload_leaf(self, use_stream):
         # For use_stream=False:
         # If we run group offloading without compilation, we will see:
         # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
@@ -673,7 +670,7 @@ def test_torch_compile_with_group_offload_leaf(self):
         # For use_stream=True:
         # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=, types=(,), arg_types=(,), kwarg_types={}
-        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+        super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)


 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners