Chroma as a FLUX.1 variant #11566

Draft: wants to merge 2 commits into main
67 changes: 65 additions & 2 deletions src/diffusers/models/embeddings.py
@@ -31,7 +31,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
) -> torch.Tensor:
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

@@ -1327,7 +1327,7 @@ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shif
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale

def forward(self, timesteps):
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
@@ -1637,6 +1637,50 @@ def forward(self, timestep, guidance, pooled_projection):
return conditioning


class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
super().__init__()

self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
self.embedder = ChromaApproximator(
in_dim=factor * 4,
out_dim=out_dim,
hidden_dim=hidden_dim,
n_layers=n_layers,
)
self.embedding_dim = embedding_dim

self.register_buffer(
"mod_proj",
get_timestep_embedding(torch.arange(out_dim), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
persistent=False,
)

def forward(
self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
) -> torch.Tensor:
mod_index_length = self.mod_proj.shape[0]
timesteps_proj = self.time_proj(timestep) + self.time_proj(pooled_projections)
if guidance is not None:
guidance_proj = self.guidance_proj(guidance)
else:
guidance_proj = torch.zeros(
(self.embedding_dim, self.guidance_proj.num_channels),
dtype=timesteps_proj.dtype,
device=timesteps_proj.device,
)

mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
timestep_guidance = (
torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
)
input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
conditioning = self.embedder(input_vec)

return conditioning


class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
@@ -2230,6 +2274,25 @@ def forward(self, caption):
return hidden_states


class ChromaApproximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
super().__init__()
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
self.layers = nn.ModuleList(
[PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
)
self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
self.out_proj = nn.Linear(hidden_dim, out_dim)

def forward(self, x):
x = self.in_proj(x)

for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))

return self.out_proj(x)


class IPAdapterPlusImageProjectionBlock(nn.Module):
def __init__(
self,
113 changes: 113 additions & 0 deletions src/diffusers/models/normalization.py
@@ -171,6 +171,46 @@ def forward(
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroPruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""

def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__()
if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
else:
self.emb = None

if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)

def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroSingle(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -203,6 +243,35 @@ def forward(
return x, gate_msa


class AdaLayerNormZeroSinglePruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).

Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""

def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()

if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)

def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa


class LuminaRMSNormZero(nn.Module):
"""
Norm layer adaptive RMS normalization zero.
@@ -305,6 +374,50 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
return x


class AdaLayerNormContinuousPruned(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""

def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")

def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x


class AdaLayerNormContinuous(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
122 changes: 99 additions & 23 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -33,22 +33,49 @@
FusedFluxAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjChromaEmbeddings,
CombinedTimestepTextProjEmbeddings,
FluxPosEmbed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ..normalization import (
AdaLayerNormContinuous,
AdaLayerNormContinuousPruned,
AdaLayerNormZero,
AdaLayerNormZeroPruned,
AdaLayerNormZeroSingle,
AdaLayerNormZeroSinglePruned,
)


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

INVALID_VARIANT_ERRMSG = "`variant` must be `'flux'` or `'chroma'`."


@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
variant: str = "flux",
):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)

self.norm = AdaLayerNormZeroSingle(dim)
if variant == "flux":
self.norm = AdaLayerNormZeroSingle(dim)
elif variant == "chroma":
self.norm = AdaLayerNormZeroSinglePruned(dim)
else:
raise ValueError(INVALID_VARIANT_ERRMSG)

self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
@@ -106,12 +133,24 @@ def forward(
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
variant: str = "flux",
):
super().__init__()

self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
if variant == "flux":
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
elif variant == "chroma":
self.norm1 = AdaLayerNormZeroPruned(dim)
self.norm1_context = AdaLayerNormZeroPruned(dim)
else:
raise ValueError(INVALID_VARIANT_ERRMSG)

self.attn = Attention(
query_dim=dim,
@@ -141,10 +180,11 @@ def forward(
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)

norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
encoder_hidden_states, emb=temb_txt
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
@@ -241,20 +281,35 @@ def __init__(
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int] = (16, 56, 56),
axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
variant: str = "flux",
approximator_in_factor: int = 16,
approximator_hidden_dim: int = 5120,
approximator_layers: int = 5,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
if variant == "flux":
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
elif variant == "chroma":
self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
factor=approximator_in_factor,
hidden_dim=approximator_hidden_dim,
out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
Ednaordinary (Contributor) commented on May 20, 2025:

Should this not be 3072? Am I missing something? (It currently computes to 344, but that doesn't fit distilled_guidance_layer.out_proj, and ComfyUI sets it to 3072.) mod_index_length / mod_proj.shape[0] should be 344, though.
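For reference, the expression above evaluates as follows under the Flux-style defaults (num_layers=19 double-stream blocks, num_single_layers=38 single-stream blocks; these values are assumed here, not stated in the diff):

```python
# Assumed Flux-style defaults; actual Chroma checkpoints may configure these differently.
num_layers = 19          # double-stream transformer blocks
num_single_layers = 38   # single-stream transformer blocks

out_dim = 3 * num_single_layers + 2 * 6 * num_layers + 2
print(out_dim)  # 344 = 114 + 228 + 2, i.e. the mod_index_length / mod_proj.shape[0] noted above
```

Under the same assumed defaults, 3072 corresponds to the transformer's inner_dim (24 heads × 128 head dim), i.e. the width of each modulation vector, whereas 344 is the number of modulation vectors.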

embedding_dim=self.inner_dim,
n_layers=approximator_layers,
)
else:
raise ValueError(INVALID_VARIANT_ERRMSG)

self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
@@ -265,6 +320,7 @@ def __init__(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
variant=variant,
)
for _ in range(num_layers)
]
@@ -276,16 +332,22 @@ def __init__(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
variant=variant,
)
for _ in range(num_single_layers)
]
)

self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
norm_out_cls = AdaLayerNormContinuous if variant != "chroma" else AdaLayerNormContinuousPruned
self.norm_out = norm_out_cls(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

self.gradient_checkpointing = False

@property
def is_chroma(self) -> bool:
return isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)

@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -442,19 +504,22 @@ def forward(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

is_chroma = self.is_chroma
hidden_states = self.x_embedder(hidden_states)

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000

if not is_chroma:
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
else:
guidance = None
pooled_temb = self.time_text_embed(timestep, guidance, pooled_projections)

temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

if txt_ids.ndim == 3:
@@ -479,6 +544,12 @@ def forward(
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

for index_block, block in enumerate(self.transformer_blocks):
if is_chroma:
start_idx1 = 3 * len(self.single_transformer_blocks) + 6 * index_block
start_idx2 = start_idx1 + 6 * len(self.transformer_blocks)
temb = torch.cat(
(pooled_temb[:, start_idx1 : start_idx1 + 6], pooled_temb[:, start_idx2 : start_idx2 + 6]), dim=1
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
@@ -511,6 +582,9 @@ def forward(
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

for index_block, block in enumerate(self.single_transformer_blocks):
if is_chroma:
start_idx = 3 * index_block
temb = pooled_temb[:, start_idx : start_idx + 3]
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
@@ -538,6 +612,8 @@ def forward(

hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

if is_chroma:
temb = pooled_temb[:, -2:]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

24 changes: 22 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -191,6 +191,7 @@ def __init__(
transformer: FluxTransformer2DModel,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
variant: str = "flux",
):
super().__init__()

@@ -213,6 +214,17 @@ def __init__(
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 128
if variant not in {"flux", "chroma"}:
raise ValueError("`variant` must be `'flux'` or `'chroma'`.")

self.variant = variant

def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
for i, n_tokens in enumerate(length):
n_tokens = torch.max(n_tokens + 1, max_sequence_length)
DN6 (Collaborator) commented on May 20, 2025:

Does Chroma support tokens beyond the max length for T5? Wouldn't this operation result in a mask that is 512 tokens in length with all True/1 for n_tokens < max_sequence_length?

Also is it not possible to use the attention mask returned by the tokenizer? text_input_ids.attention_mask?

hameerabbasi (Contributor, PR author) replied:

IIUC, Chroma needs an attention mask that's equivalent to torch.cat([torch.ones(n_tokens + 1, dtype=torch.bool), torch.zeros(max_tokens - n_tokens - 1, dtype=torch.bool)]), because it needs one unmasked <pad> token after the actual prompt. IIUC text_input_ids.attention_mask will mask all the <pad> tokens.

The torch.max is to handle the corner case where n_tokens == max_sequence_length.
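For clarity, a minimal sketch (an illustration of this reply, not the PR's code) of the per-prompt mask described above, with the cap handling the n_tokens == max_sequence_length corner case:

```python
import torch

def chroma_attn_mask_for_prompt(n_tokens: int, max_sequence_length: int) -> torch.Tensor:
    # Unmask the prompt tokens plus one trailing <pad> token,
    # capped so the mask never exceeds the sequence length.
    n_unmasked = min(n_tokens + 1, max_sequence_length)
    return torch.cat(
        [
            torch.ones(n_unmasked, dtype=torch.bool),
            torch.zeros(max_sequence_length - n_unmasked, dtype=torch.bool),
        ]
    )
```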

Ednaordinary (Contributor) commented on May 20, 2025:

torch.max doesn't work correctly here because it treats max_sequence_length as the dimension argument (there are not 512 dimensions), but the built-in max() might work.

Though, the discussion here suggests the extra pad token is a mistake in the ComfyUI implementation, so text_input_ids.attention_mask should be fine.
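To make the overload point concrete, a small illustration with made-up values (assuming that iterating text_inputs.length yields 0-dim tensors):

```python
import torch

n_tokens = torch.tensor(20)  # 0-dim tensor, as yielded when iterating `text_inputs.length`
max_sequence_length = 512

# A Python int in the second positional slot selects the reduction overload
# torch.max(input, dim), so 512 is read as a dimension index and fails on a
# 0-dim tensor instead of computing a maximum:
# torch.max(n_tokens + 1, max_sequence_length)  # raises IndexError

# The element-wise form needs a tensor on both sides:
torch.max(n_tokens + 1, torch.tensor(max_sequence_length))  # tensor(512)

# The built-in max() compares the values directly:
max(n_tokens + 1, max_sequence_length)  # 512
```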

attention_mask[i, :n_tokens] = True
return attention_mask

def _get_t5_prompt_embeds(
self,
@@ -236,7 +248,7 @@ def _get_t5_prompt_embeds(
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_length=(self.variant == "chroma"),
return_overflowing_tokens=False,
return_tensors="pt",
)
@@ -250,7 +262,15 @@ def _get_t5_prompt_embeds(
f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
prompt_embeds = self.text_encoder_2(
text_input_ids.to(device),
output_hidden_states=False,
attention_mask=(
self._get_chroma_attn_mask(text_inputs.length, max_sequence_length).to(device)
if self.variant == "chroma"
else None
),
)[0]

dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)