|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -import inspect |
16 |
| -from typing import Dict, List, Optional, Union |
17 | 15 |
|
18 |
| -from ..utils import is_transformers_available, logging |
19 | 16 | from .auto import DiffusersAutoQuantizer
|
20 | 17 | from .base import DiffusersQuantizer
|
21 |
| -from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin |
22 |
| - |
23 |
| - |
24 |
| -try: |
25 |
| - from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin |
26 |
| -except ImportError: |
27 |
| - |
28 |
| - class TransformersQuantConfigMixin: |
29 |
| - pass |
30 |
| - |
31 |
| - |
32 |
| -logger = logging.get_logger(__name__) |
33 |
| - |
34 |
| - |
35 |
| -class PipelineQuantizationConfig: |
36 |
| - """ |
37 |
| - Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`]. |
38 |
| -
|
39 |
| - Args: |
40 |
| - quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend |
41 |
| - is available to both `diffusers` and `transformers`. |
42 |
| - quant_kwargs (`dict`): Params to initialize the quantization backend class. |
43 |
| - components_to_quantize (`list`): Components of a pipeline to be quantized. |
44 |
| - quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline |
45 |
| - components. When using this argument, users are not expected to provide `quant_backend`, `quant_kwargs`, |
46 |
| - and `components_to_quantize`. |
47 |
| - """ |
48 |
| - |
49 |
| - def __init__( |
50 |
| - self, |
51 |
| - quant_backend: str = None, |
52 |
| - quant_kwargs: Dict[str, Union[str, float, int, dict]] = None, |
53 |
| - components_to_quantize: Optional[List[str]] = None, |
54 |
| - quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, |
55 |
| - ): |
56 |
| - self.quant_backend = quant_backend |
57 |
| - # Initialize kwargs to be {} to set to the defaults. |
58 |
| - self.quant_kwargs = quant_kwargs or {} |
59 |
| - self.components_to_quantize = components_to_quantize |
60 |
| - self.quant_mapping = quant_mapping |
61 |
| - |
62 |
| - self.post_init() |
63 |
| - |
64 |
| - def post_init(self): |
65 |
| - quant_mapping = self.quant_mapping |
66 |
| - self.is_granular = True if quant_mapping is not None else False |
67 |
| - |
68 |
| - self._validate_init_args() |
69 |
| - |
70 |
| - def _validate_init_args(self): |
71 |
| - if self.quant_backend and self.quant_mapping: |
72 |
| - raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.") |
73 |
| - |
74 |
| - if not self.quant_mapping and not self.quant_backend: |
75 |
| - raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") |
76 |
| - |
77 |
| - if not self.quant_kwargs and not self.quant_mapping: |
78 |
| - raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") |
79 |
| - |
80 |
| - if self.quant_backend is not None: |
81 |
| - self._validate_init_kwargs_in_backends() |
82 |
| - |
83 |
| - if self.quant_mapping is not None: |
84 |
| - self._validate_quant_mapping_args() |
85 |
| - |
86 |
| - def _validate_init_kwargs_in_backends(self): |
87 |
| - quant_backend = self.quant_backend |
88 |
| - |
89 |
| - self._check_backend_availability(quant_backend) |
90 |
| - |
91 |
| - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() |
92 |
| - |
93 |
| - if quant_config_mapping_transformers is not None: |
94 |
| - init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) |
95 |
| - init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} |
96 |
| - else: |
97 |
| - init_kwargs_transformers = None |
98 |
| - |
99 |
| - init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) |
100 |
| - init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} |
101 |
| - |
102 |
| - if init_kwargs_transformers != init_kwargs_diffusers: |
103 |
| - raise ValueError( |
104 |
| - "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " |
105 |
| - f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how " |
106 |
| - "this mapping would look like." |
107 |
| - ) |
108 |
| - |
109 |
| - def _validate_quant_mapping_args(self): |
110 |
| - quant_mapping = self.quant_mapping |
111 |
| - transformers_map, diffusers_map = self._get_quant_config_list() |
112 |
| - |
113 |
| - available_transformers = list(transformers_map.values()) if transformers_map else None |
114 |
| - available_diffusers = list(diffusers_map.values()) |
115 |
| - |
116 |
| - for module_name, config in quant_mapping.items(): |
117 |
| - if any(isinstance(config, cfg) for cfg in available_diffusers): |
118 |
| - continue |
119 |
| - |
120 |
| - if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers): |
121 |
| - continue |
122 |
| - |
123 |
| - if available_transformers: |
124 |
| - raise ValueError( |
125 |
| - f"Provided config for module_name={module_name} could not be found. " |
126 |
| - f"Available diffusers configs: {available_diffusers}; " |
127 |
| - f"Available transformers configs: {available_transformers}." |
128 |
| - ) |
129 |
| - else: |
130 |
| - raise ValueError( |
131 |
| - f"Provided config for module_name={module_name} could not be found. " |
132 |
| - f"Available diffusers configs: {available_diffusers}." |
133 |
| - ) |
134 |
| - |
135 |
| - def _check_backend_availability(self, quant_backend: str): |
136 |
| - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() |
137 |
| - |
138 |
| - available_backends_transformers = ( |
139 |
| - list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None |
140 |
| - ) |
141 |
| - available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) |
142 |
| - |
143 |
| - if ( |
144 |
| - available_backends_transformers and quant_backend not in available_backends_transformers |
145 |
| - ) or quant_backend not in quant_config_mapping_diffusers: |
146 |
| - error_message = f"Provided quant_backend={quant_backend} was not found." |
147 |
| - if available_backends_transformers: |
148 |
| - error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." |
149 |
| - error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." |
150 |
| - raise ValueError(error_message) |
151 |
| - |
152 |
| - def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): |
153 |
| - quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() |
154 |
| - |
155 |
| - quant_mapping = self.quant_mapping |
156 |
| - components_to_quantize = self.components_to_quantize |
157 |
| - |
158 |
| - # Granular case |
159 |
| - if self.is_granular and module_name in quant_mapping: |
160 |
| - logger.debug(f"Initializing quantization config class for {module_name}.") |
161 |
| - config = quant_mapping[module_name] |
162 |
| - return config |
163 |
| - |
164 |
| - # Global config case |
165 |
| - else: |
166 |
| - should_quantize = False |
167 |
| - # Only quantize the modules requested for. |
168 |
| - if components_to_quantize and module_name in components_to_quantize: |
169 |
| - should_quantize = True |
170 |
| - # No specification for `components_to_quantize` means all modules should be quantized. |
171 |
| - elif not self.is_granular and not components_to_quantize: |
172 |
| - should_quantize = True |
173 |
| - |
174 |
| - if should_quantize: |
175 |
| - logger.debug(f"Initializing quantization config class for {module_name}.") |
176 |
| - mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers |
177 |
| - quant_config_cls = mapping_to_use[self.quant_backend] |
178 |
| - quant_kwargs = self.quant_kwargs |
179 |
| - return quant_config_cls(**quant_kwargs) |
180 |
| - |
181 |
| - # Fallback: no applicable configuration found. |
182 |
| - return None |
183 |
| - |
184 |
| - def _get_quant_config_list(self): |
185 |
| - if is_transformers_available(): |
186 |
| - from transformers.quantizers.auto import ( |
187 |
| - AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, |
188 |
| - ) |
189 |
| - else: |
190 |
| - quant_config_mapping_transformers = None |
191 |
| - |
192 |
| - from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers |
193 |
| - |
194 |
| - return quant_config_mapping_transformers, quant_config_mapping_diffusers |
| 18 | +from .pipe_quant_config import PipelineQuantizationConfig |
0 commit comments