
Commit 3cb66e8

update
1 parent 2891f14 commit 3cb66e8

2 files changed (+121, -122 lines)

src/diffusers/models/attention.py

Lines changed: 119 additions & 122 deletions
@@ -107,14 +107,10 @@ def fuse_qkv_projections(self):
         are fused. For cross-attention modules, key and value projection matrices are fused.

         """
-        self.original_attn_processors = None
-
         for _, attn_processor in self.attn_processors.items():
             if "Added" in str(attn_processor.__class__.__name__):
                 raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

-        self.original_attn_processors = self.attn_processors
-
         for module in self.modules():
             if isinstance(module, AttentionModuleMixin):
                 module.fuse_projections(fuse=True)
@@ -129,30 +125,58 @@ def unfuse_qkv_projections(self):
         </Tip>

         """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
+        for _, attn_processor in self.attn_processors.items():
+            attn_processor.fused_projections = False


 class AttentionModuleMixin:
-    """
-    A mixin class that provides common methods for attention modules.
+    _default_processor_cls = None
+    _available_processors = []
+    fused_projections = False

-    This mixin adds functionality to set different attention processors, handle attention masks, compute attention
-    scores, and manage projections.
-    """
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        """
+        Set the attention processor to use.

-    # Default processor classes to be overridden by subclasses
-    default_processor_cls = None
-    _available_processors = []
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+        """
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")

-    fused_projections = False
-    is_cross_attention = False
+        self.processor = processor
+
+    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+        """
+        Get the attention processor in use.
+
+        Args:
+            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to return the deprecated LoRA attention processor.

-    def _get_compatible_processor(self, backend):
-        for processor_cls in self._available_processors:
-            if backend in processor_cls.compatible_backends:
-                processor = processor_cls()
-                return processor
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        if not return_deprecated_lora:
+            return self.processor
+
+    def set_attention_backend(self, backend: str):
+        from .attention_dispatch import AttentionBackendName
+
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+        backend = AttentionBackendName(backend.lower())
+        self.processor._attention_backend = backend

     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
         """
@@ -161,14 +185,12 @@ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
         Args:
             use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
         """
-        processor = self.default_processor_cls()

         if use_npu_flash_attention:
             if not is_torch_npu_available():
                 raise ImportError("torch_npu is not available")
-            processor = self._get_compatible_processor("npu")

-        self.set_processor(processor)
+        self.set_attention_backend("_native_npu")

     def set_use_xla_flash_attention(
         self,
@@ -187,52 +209,85 @@ def set_use_xla_flash_attention(
             is_flux (`bool`, *optional*, defaults to `False`):
                 Whether the model is a Flux model.
         """
-        processor = self.default_processor_cls()
         if use_xla_flash_attention:
             if not is_torch_xla_available():
                 raise ImportError("torch_xla is not available")
-            processor = self._get_compatible_processor("xla")

-        self.set_processor(processor)
+        self.set_attention_backend("_native_xla")

-    @torch.no_grad()
-    def fuse_projections(self, fuse=True):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         """
-        Fuse the query, key, and value projections into a single projection for efficiency.
+        Set whether to use memory efficient attention from `xformers` or not.

         Args:
-            fuse (`bool`): Whether to fuse the projections or not.
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
         """
-        # Skip if already in desired state
-        if getattr(self, "fused_projections", False) == fuse:
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+                    " only available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    if xformers is not None:
+                        dtype = None
+                        if attention_op is not None:
+                            op_fw, op_bw = attention_op
+                            dtype, *_ = op_fw.SUPPORTED_DTYPES
+                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                        _ = xformers.ops.memory_efficient_attention(q, q, q)
+                except Exception as e:
+                    raise e
+
+        self.set_attention_backend("xformers")
+
+    @torch.no_grad()
+    def fuse_projections(self):
+        """
+        Fuse the query, key, and value projections into a single projection for efficiency.
+        """
+        # Skip if already fused
+        if getattr(self, "fused_projections", False):
             return

         device = self.to_q.weight.data.device
         dtype = self.to_q.weight.data.dtype

-        if not self.is_cross_attention:
-            # Fuse self-attention projections
-            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
-            in_features = concatenated_weights.shape[1]
-            out_features = concatenated_weights.shape[0]
-
-            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_qkv.weight.copy_(concatenated_weights)
-            if self.use_bias:
-                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
-                self.to_qkv.bias.copy_(concatenated_bias)
-
-        else:
+        if hasattr(self, "is_cross_attention") and self.is_cross_attention:
             # Fuse cross-attention key-value projections
             concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]

             self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
-            if self.use_bias:
+            if hasattr(self, "use_bias") and self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                 self.to_kv.bias.copy_(concatenated_bias)
+        else:
+            # Fuse self-attention projections
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+                self.to_qkv.bias.copy_(concatenated_bias)

         # Handle added projections for models like SD3, Flux, etc.
         if (
@@ -256,52 +311,28 @@ def fuse_projections(self, fuse=True):
             )
             self.to_added_qkv.bias.copy_(concatenated_bias)

-        self.fused_projections = fuse
+        self.fused_projections = True

-    def set_use_memory_efficient_attention_xformers(
-        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
-    ) -> None:
+    @torch.no_grad()
+    def unfuse_projections(self):
         """
-        Set whether to use memory efficient attention from `xformers` or not.
-
-        Args:
-            use_memory_efficient_attention_xformers (`bool`):
-                Whether to use memory efficient attention from `xformers` or not.
-            attention_op (`Callable`, *optional*):
-                The attention operation to use. Defaults to `None` which uses the default attention operation from
-                `xformers`.
+        Unfuse the query, key, and value projections back to separate projections.
         """
-        if use_memory_efficient_attention_xformers:
-            if not is_xformers_available():
-                raise ModuleNotFoundError(
-                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
-                    name="xformers",
-                )
-            elif not torch.cuda.is_available():
-                raise ValueError(
-                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
-                    " only available for GPU "
-                )
-            else:
-                try:
-                    # Make sure we can run the memory efficient attention
-                    if xformers is not None:
-                        dtype = None
-                        if attention_op is not None:
-                            op_fw, op_bw = attention_op
-                            dtype, *_ = op_fw.SUPPORTED_DTYPES
-                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
-                        _ = xformers.ops.memory_efficient_attention(q, q, q)
-                except Exception as e:
-                    raise e
+        # Skip if not fused
+        if not getattr(self, "fused_projections", False):
+            return

-            processor = self._get_compatible_processor("xformers")
-        else:
-            # Set default processor
-            processor = self.default_processor_cls()
+        # Remove fused projection layers
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")

-        if processor is not None:
-            self.set_processor(processor)
+        if hasattr(self, "to_added_qkv"):
+            delattr(self, "to_added_qkv")
+
+        self.fused_projections = False

     def set_attention_slice(self, slice_size: int) -> None:
         """
@@ -326,40 +357,6 @@ def set_attention_slice(self, slice_size: int) -> None:

         self.set_processor(processor)

-    def set_processor(self, processor: "AttnProcessor") -> None:
-        """
-        Set the attention processor to use.
-
-        Args:
-            processor (`AttnProcessor`):
-                The attention processor to use.
-        """
-        # if current processor is in `self._modules` and if passed `processor` is not, we need to
-        # pop `processor` from `self._modules`
-        if (
-            hasattr(self, "processor")
-            and isinstance(self.processor, torch.nn.Module)
-            and not isinstance(processor, torch.nn.Module)
-        ):
-            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
-            self._modules.pop("processor")
-
-        self.processor = processor
-
-    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
-        """
-        Get the attention processor in use.
-
-        Args:
-            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
-                Set to `True` to return the deprecated LoRA attention processor.
-
-        Returns:
-            "AttentionProcessor": The attention processor in use.
-        """
-        if not return_deprecated_lora:
-            return self.processor
-
     def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
         """
         Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
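
Usage sketch (not part of the commit): a minimal illustration of the reworked `AttentionModuleMixin` API shown above. The pipeline class, checkpoint name, and the choice of the `"xformers"` backend are placeholder assumptions; the per-module loop simply mirrors how `set_attention_backend`, `fuse_projections`, and `unfuse_projections` are exposed in this diff.

import torch
from diffusers import FluxPipeline  # placeholder pipeline; any model whose attention modules use AttentionModuleMixin works

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

for module in pipe.transformer.modules():
    if not hasattr(module, "set_attention_backend"):
        continue
    # Pick an attention backend by name instead of swapping processor classes.
    # Valid names come from AttentionBackendName; "_native_npu", "_native_xla",
    # and "xformers" are the ones wired up in this commit.
    module.set_attention_backend("xformers")
    # Fuse the Q/K/V projections for inference, then undo it. unfuse_projections
    # now deletes the fused layers and clears the `fused_projections` flag
    # instead of restoring a cached processor.
    module.fuse_projections()
    module.unfuse_projections()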

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 2 additions & 0 deletions
@@ -366,6 +366,8 @@ def __init__(
             cross_attention_dim=None,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
+            qk_norm=qk_norm,
+            eps=eps,
             dropout=0.0,
             bias=True,
             added_kv_proj_dim=dim,
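
The two added lines mean the Flux block now forwards its own `qk_norm` and `eps` arguments into the attention module instead of leaving them at the attention class defaults. Below is a standalone sketch of an equivalently configured module; the import path, dimensions, and the concrete values (`"rms_norm"`, `1e-6`) are assumptions for illustration, not taken from the commit.

from diffusers.models.attention_processor import Attention  # assumed import path

# Roughly what the Flux block constructs after this change: QK normalization and
# the norm epsilon are set explicitly from the block's own configuration.
attn = Attention(
    query_dim=3072,
    cross_attention_dim=None,
    dim_head=128,
    heads=24,
    qk_norm="rms_norm",  # assumed value of the forwarded `qk_norm` argument
    eps=1e-6,            # assumed value of the forwarded `eps` argument
    dropout=0.0,
    bias=True,
    added_kv_proj_dim=3072,
)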
