@@ -107,14 +107,10 @@ def fuse_qkv_projections(self):
        are fused. For cross-attention modules, key and value projection matrices are fused.

        """
-        self.original_attn_processors = None
-
        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

-        self.original_attn_processors = self.attn_processors
-
        for module in self.modules():
            if isinstance(module, AttentionModuleMixin):
                module.fuse_projections(fuse=True)
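With the processor snapshot gone, fusing is now driven entirely by each module's own `fused_projections` state. A minimal usage sketch of the model-level toggle, assuming a Flux transformer checkpoint (the model class and repo id are illustrative and not part of this diff):

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.fuse_qkv_projections()    # each AttentionModuleMixin child fuses its q/k/v (or k/v) weights
# ... run inference with the fused projections ...
transformer.unfuse_qkv_projections()  # flips the fused_projections flag back on the attention processors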
@@ -129,30 +125,58 @@ def unfuse_qkv_projections(self):
        </Tip>

        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
+        for _, attn_processor in self.attn_processors.items():
+            attn_processor.fused_projections = False


class AttentionModuleMixin:
-    """
-    A mixin class that provides common methods for attention modules.
+    _default_processor_cls = None
+    _available_processors = []
+    fused_projections = False

-    This mixin adds functionality to set different attention processors, handle attention masks, compute attention
-    scores, and manage projections.
-    """
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        """
+        Set the attention processor to use.

-    # Default processor classes to be overridden by subclasses
-    default_processor_cls = None
-    _available_processors = []
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+        """
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")

-    fused_projections = False
-    is_cross_attention = False
+        self.processor = processor
+
+    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+        """
+        Get the attention processor in use.
+
+        Args:
+            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to return the deprecated LoRA attention processor.

-    def _get_compatible_processor(self, backend):
-        for processor_cls in self._available_processors:
-            if backend in processor_cls.compatible_backends:
-                processor = processor_cls()
-                return processor
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        if not return_deprecated_lora:
+            return self.processor
+
+    def set_attention_backend(self, backend: str):
+        from .attention_dispatch import AttentionBackendName
+
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+        backend = AttentionBackendName(backend.lower())
+        self.processor._attention_backend = backend

    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        """
@@ -161,14 +185,12 @@ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        Args:
            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
        """
-        processor = self.default_processor_cls()

        if use_npu_flash_attention:
            if not is_torch_npu_available():
                raise ImportError("torch_npu is not available")
-            processor = self._get_compatible_processor("npu")

-        self.set_processor(processor)
+        self.set_attention_backend("_native_npu")

    def set_use_xla_flash_attention(
        self,
@@ -187,52 +209,85 @@ def set_use_xla_flash_attention(
            is_flux (`bool`, *optional*, defaults to `False`):
                Whether the model is a Flux model.
        """
-        processor = self.default_processor_cls()
        if use_xla_flash_attention:
            if not is_torch_xla_available():
                raise ImportError("torch_xla is not available")
-            processor = self._get_compatible_processor("xla")

-        self.set_processor(processor)
+        self.set_attention_backend("_native_xla")

-    @torch.no_grad()
-    def fuse_projections(self, fuse=True):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ) -> None:
        """
-        Fuse the query, key, and value projections into a single projection for efficiency.
+        Set whether to use memory efficient attention from `xformers` or not.

        Args:
-            fuse (`bool`): Whether to fuse the projections or not.
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
        """
-        # Skip if already in desired state
-        if getattr(self, "fused_projections", False) == fuse:
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+                    " only available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    if xformers is not None:
+                        dtype = None
+                        if attention_op is not None:
+                            op_fw, op_bw = attention_op
+                            dtype, *_ = op_fw.SUPPORTED_DTYPES
+                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                        _ = xformers.ops.memory_efficient_attention(q, q, q)
+                except Exception as e:
+                    raise e
+
+            self.set_attention_backend("xformers")
+
+    @torch.no_grad()
+    def fuse_projections(self):
+        """
+        Fuse the query, key, and value projections into a single projection for efficiency.
+        """
+        # Skip if already fused
+        if getattr(self, "fused_projections", False):
            return

        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

-        if not self.is_cross_attention:
-            # Fuse self-attention projections
-            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
-            in_features = concatenated_weights.shape[1]
-            out_features = concatenated_weights.shape[0]
-
-            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_qkv.weight.copy_(concatenated_weights)
-            if self.use_bias:
-                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
-                self.to_qkv.bias.copy_(concatenated_bias)
-
-        else:
+        if hasattr(self, "is_cross_attention") and self.is_cross_attention:
            # Fuse cross-attention key-value projections
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_kv.weight.copy_(concatenated_weights)
-            if self.use_bias:
+            if hasattr(self, "use_bias") and self.use_bias:
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                self.to_kv.bias.copy_(concatenated_bias)
+        else:
+            # Fuse self-attention projections
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+                self.to_qkv.bias.copy_(concatenated_bias)

        # Handle added projections for models like SD3, Flux, etc.
        if (
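The fusion above stacks the separate projection weights along the output dimension so a single matmul produces q, k, and v at once. A self-contained shape sketch of that idea on toy 8-dimensional linears (not the mixin itself):

import torch
import torch.nn as nn

to_q, to_k, to_v = (nn.Linear(8, 8, bias=True) for _ in range(3))

w = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])  # shape (24, 8): out_features stacked
to_qkv = nn.Linear(w.shape[1], w.shape[0], bias=True)
to_qkv.weight.data.copy_(w)
to_qkv.bias.data.copy_(torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data]))

x = torch.randn(2, 8)
q, k, v = to_qkv(x).chunk(3, dim=-1)  # a fused-aware processor splits the single output back into q/k/v
assert torch.allclose(q, to_q(x), atol=1e-6) and torch.allclose(v, to_v(x), atol=1e-6)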
@@ -256,52 +311,28 @@ def fuse_projections(self, fuse=True):
                )
                self.to_added_qkv.bias.copy_(concatenated_bias)

-        self.fused_projections = fuse
+        self.fused_projections = True

-    def set_use_memory_efficient_attention_xformers(
-        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
-    ) -> None:
+    @torch.no_grad()
+    def unfuse_projections(self):
        """
-        Set whether to use memory efficient attention from `xformers` or not.
-
-        Args:
-            use_memory_efficient_attention_xformers (`bool`):
-                Whether to use memory efficient attention from `xformers` or not.
-            attention_op (`Callable`, *optional*):
-                The attention operation to use. Defaults to `None` which uses the default attention operation from
-                `xformers`.
+        Unfuse the query, key, and value projections back to separate projections.
        """
-        if use_memory_efficient_attention_xformers:
-            if not is_xformers_available():
-                raise ModuleNotFoundError(
-                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
-                    name="xformers",
-                )
-            elif not torch.cuda.is_available():
-                raise ValueError(
-                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
-                    " only available for GPU "
-                )
-            else:
-                try:
-                    # Make sure we can run the memory efficient attention
-                    if xformers is not None:
-                        dtype = None
-                        if attention_op is not None:
-                            op_fw, op_bw = attention_op
-                            dtype, *_ = op_fw.SUPPORTED_DTYPES
-                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
-                        _ = xformers.ops.memory_efficient_attention(q, q, q)
-                except Exception as e:
-                    raise e
+        # Skip if not fused
+        if not getattr(self, "fused_projections", False):
+            return

-                processor = self._get_compatible_processor("xformers")
-        else:
-            # Set default processor
-            processor = self.default_processor_cls()
+        # Remove fused projection layers
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")

-        if processor is not None:
-            self.set_processor(processor)
+        if hasattr(self, "to_added_qkv"):
+            delattr(self, "to_added_qkv")
+
+        self.fused_projections = False

    def set_attention_slice(self, slice_size: int) -> None:
        """
@@ -326,40 +357,6 @@ def set_attention_slice(self, slice_size: int) -> None:

        self.set_processor(processor)

-    def set_processor(self, processor: "AttnProcessor") -> None:
-        """
-        Set the attention processor to use.
-
-        Args:
-            processor (`AttnProcessor`):
-                The attention processor to use.
-        """
-        # if current processor is in `self._modules` and if passed `processor` is not, we need to
-        # pop `processor` from `self._modules`
-        if (
-            hasattr(self, "processor")
-            and isinstance(self.processor, torch.nn.Module)
-            and not isinstance(processor, torch.nn.Module)
-        ):
-            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
-            self._modules.pop("processor")
-
-        self.processor = processor
-
-    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
-        """
-        Get the attention processor in use.
-
-        Args:
-            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
-                Set to `True` to return the deprecated LoRA attention processor.
-
-        Returns:
-            "AttentionProcessor": The attention processor in use.
-        """
-        if not return_deprecated_lora:
-            return self.processor
-
    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.