
[refactor] Flux single file implementation #11881


Closed
wants to merge 1 commit into from
2 changes: 2 additions & 0 deletions src/diffusers/models/attention.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Dict, Optional, Tuple, Union

import torch
32 changes: 11 additions & 21 deletions src/diffusers/models/embeddings.py
@@ -1244,31 +1244,21 @@ class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()

Collaborator comment on this hunk:

I think here you can patch `__new__` with the correct import and return the new object, similar to how `FluxSingleAttnProcessor2_0` is handled.
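
A minimal sketch of what that `__new__`-based shim could look like, assuming the same `deprecate` helper and relative import already used in `embeddings.py`; the diff as submitted keeps the `__init__`-based wrapper shown just below:

class FluxPosEmbed:
    # Hypothetical sketch of the reviewer's suggestion, not part of this PR:
    # patch __new__ so the deprecated name simply hands back the relocated
    # class, mirroring the FluxSingleAttnProcessor2_0 deprecation pattern.
    def __new__(cls, *args, **kwargs):
        deprecate(
            "FluxPosEmbed",
            "1.0.0",
            "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. "
            "Please use `FluxPosEmbed` from `diffusers.models.transformers.transformer_flux` instead.",
        )
        from .transformers.transformer_flux import FluxPosEmbed as FluxPosEmbed_

        # The returned object is not an instance of this class, so __init__
        # is never called and the wrapper keeps no state of its own.
        return FluxPosEmbed_(*args, **kwargs)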

from .transformers.transformer_flux import FluxPosEmbed as FluxPosEmbed_

deprecate(
"FluxPosEmbed",
"1.0.0",
"Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please use `FluxPosEmbed` from `diffusers.models.transformers.transformer_flux` instead.",
)

self.theta = theta
self.axes_dim = axes_dim
self._rope = FluxPosEmbed_(theta, axes_dim)

def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
return self._rope(ids)


class TimestepEmbedding(nn.Module):
44 changes: 44 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -599,6 +599,50 @@ def enable_group_offload(
low_cpu_mem_usage=low_cpu_mem_usage,
)

def set_attention_backend(self, backend: str) -> None:
"""
Set the attention backend for the model.

Args:
backend (`str`):
The name of the backend to set. Must be one of the available backends defined in
`AttentionBackendName`. Available backends can be found in
`diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName

backend = backend.lower()
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)

for module in self.modules():
if not isinstance(module, AttentionModuleMixin):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = backend

def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Following calls to `forward` will use the environment default or
the torch native scaled dot product attention.
"""
from .attention_processor import Attention, MochiAttention

attention_classes = (Attention, MochiAttention)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = None

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
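
Outside the diff, a short usage sketch of the new attention-backend switches; the backend string "flash" is illustrative, and the valid names are the values of `AttentionBackendName` in `diffusers.models.attention_dispatch`:

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Route every AttentionModuleMixin-based processor in the model through one backend.
transformer.set_attention_backend("flash")  # illustrative name; must be a member of AttentionBackendName

# ... run inference ...

# Revert to the environment default / native torch scaled dot product attention.
transformer.reset_attention_backend()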
160 changes: 119 additions & 41 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -24,10 +24,15 @@
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
apply_rotary_emb,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -73,7 +78,6 @@ def get_qkv_projections(self, attn, hidden_states, encoder_hidden_states=None):
"""Public method to get projections based on whether we're using fused mode or not."""
if attn.is_fused and hasattr(attn, "to_qkv"):
return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)

return self._get_projections(attn, hidden_states, encoder_hidden_states)

def __call__(
@@ -117,17 +121,10 @@ def __call__(
value = torch.cat([encoder_value, value], dim=2)

if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb

query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
)
hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -242,12 +239,10 @@ def __call__(
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb

query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -292,13 +287,76 @@ def __call__(


@maybe_allow_in_graph
class FluxAttention(Attention):
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = FluxAttnProcessor
_available_processors = [
FluxAttnProcessor,
FluxIPAdapterAttnProcessor,
]

def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
assert qk_norm == "rms_norm", "Flux uses RMSNorm"

self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_proj_bias = added_proj_bias

self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))

if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)

def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
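
For orientation, a hedged construction sketch for the refactored `FluxAttention`; the dimensions (24 heads of 128, inner dim 3072) match Flux-dev but are only illustrative, and the (image, text) tuple return for the joint path is assumed from the existing Flux processors:

import torch

attn = FluxAttention(
    query_dim=3072,
    heads=24,
    dim_head=128,
    bias=True,
    qk_norm="rms_norm",       # the constructor asserts rms_norm
    added_kv_proj_dim=3072,   # enables the encoder (text-stream) projections
    out_dim=3072,
    context_pre_only=False,
)

hidden_states = torch.randn(1, 4096, 3072)         # packed image tokens (illustrative shape)
encoder_hidden_states = torch.randn(1, 512, 3072)  # projected text tokens (illustrative shape)

img_out, txt_out = attn(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    image_rotary_emb=None,  # or the (cos, sin) pair produced by FluxPosEmbed below
)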


@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
@@ -330,20 +388,19 @@ def forward(
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
joint_attention_kwargs = joint_attention_kwargs or {}

norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
attn_mlp_hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
proj_out = self.proj_out(attn_mlp_hidden_states)
hidden_states = hidden_states + gate.unsqueeze(1) * proj_out

hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)

@@ -386,12 +443,13 @@ def forward(
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
joint_attention_kwargs = joint_attention_kwargs or {}

norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}

# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
@@ -410,7 +468,7 @@ def forward(
hidden_states = hidden_states + attn_output

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)

ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -420,21 +478,54 @@
hidden_states = hidden_states + ip_attn_output

# Process attention outputs for the `encoder_hidden_states`.

context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output

norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
norm_encoder_hidden_states = norm_encoder_hidden_states * (
1 + c_scale_mlp.unsqueeze(1)
) + c_shift_mlp.unsqueeze(1)

context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

return encoder_hidden_states, hidden_states


class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim

def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
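
And a short sketch of driving the relocated `FluxPosEmbed`; `theta=10000` and `axes_dim=[16, 56, 56]` are the Flux defaults (their sum equals the 128-dim attention head), while the 64x64 grid of ids is illustrative:

import torch

rope = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# One row of position ids per token, one column per axis (text/batch, height, width).
img_ids = torch.zeros(4096, 3)
img_ids[:, 1] = torch.arange(64).repeat_interleave(64)  # height index of each packed latent
img_ids[:, 2] = torch.arange(64).repeat(64)             # width index of each packed latent

freqs_cos, freqs_sin = rope(img_ids)
print(freqs_cos.shape, freqs_sin.shape)  # torch.Size([4096, 128]) each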


class FluxTransformer2DModel(
ModelMixin,
ConfigMixin,
@@ -537,10 +628,6 @@ def __init__(

self.gradient_checkpointing = False

# Using inherited methods from AttentionMixin

# Using inherited methods from AttentionMixin

def forward(
self,
hidden_states: torch.Tensor,
@@ -634,11 +721,7 @@ def forward(
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
)

else:
@@ -665,12 +748,7 @@

for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
)
hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb, image_rotary_emb)

else:
hidden_states = block(