Commit 8ceed7d

Initial commit: Chroma as a FLUX.1 variant.
1 parent 9836f0e commit 8ceed7d

File tree

3 files changed: +273 -25 lines

src/diffusers/models/embeddings.py

Lines changed: 65 additions & 2 deletions
@@ -31,7 +31,7 @@ def get_timestep_embedding(
     downscale_freq_shift: float = 1,
     scale: float = 1,
     max_period: int = 10000,
-):
+) -> torch.Tensor:
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -1327,7 +1327,7 @@ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale
 
-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
@@ -1637,6 +1637,50 @@ def forward(self, timestep, guidance, pooled_projection):
         return conditioning
 
 
+class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
+    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.embedder = ChromaApproximator(
+            in_dim=factor * 4,
+            out_dim=out_dim,
+            hidden_dim=hidden_dim,
+            n_layers=n_layers,
+        )
+        self.embedding_dim = embedding_dim
+
+        self.register_buffer(
+            "mod_proj",
+            get_timestep_embedding(torch.arange(out_dim), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            persistent=False,
+        )
+
+    def forward(
+        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
+    ) -> torch.Tensor:
+        mod_index_length = self.mod_proj.shape[0]
+        timesteps_proj = self.time_proj(timestep) + self.time_proj(pooled_projections)
+        if guidance is not None:
+            guidance_proj = self.guidance_proj(guidance)
+        else:
+            guidance_proj = torch.zeros(
+                (self.embedding_dim, self.guidance_proj.num_channels),
+                dtype=timesteps_proj.dtype,
+                device=timesteps_proj.device,
+            )
+
+        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
+        timestep_guidance = (
+            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+        )
+        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
+        conditioning = self.embedder(input_vec)
+
+        return conditioning
+
+
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
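
Editor's note: the `mod_proj` buffer registered above is a table of sinusoidal index embeddings, one row per output modulation slot, which the forward pass concatenates with the (timestep, guidance) projection before feeding the approximator. A quick sketch of its shape, assuming `diffusers` is installed; the `factor` and `out_dim` values are illustrative, not taken from this commit:

import torch
from diffusers.models.embeddings import get_timestep_embedding

factor, out_dim = 16, 344  # hypothetical sizes, for illustration only
mod_proj = get_timestep_embedding(
    torch.arange(out_dim), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0
)
print(mod_proj.shape)  # torch.Size([344, 32]) -- one 2*factor-dim embedding per modulation slot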
@@ -2230,6 +2274,25 @@ def forward(self, caption):
         return hidden_states
 
 
+class ChromaApproximator(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
+        super().__init__()
+        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.layers = nn.ModuleList(
+            [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
+        )
+        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
+        self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+    def forward(self, x):
+        x = self.in_proj(x)
+
+        for layer, norms in zip(self.layers, self.norms):
+            x = x + layer(norms(x))
+
+        return self.out_proj(x)
+
+
 class IPAdapterPlusImageProjectionBlock(nn.Module):
     def __init__(
         self,
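
Editor's note: `ChromaApproximator` is a small MLP head: an input projection, a stack of pre-normalized residual blocks (each an inner MLP, here `PixArtAlphaTextProjection` with SiLU), and an output projection. Below is a minimal standalone sketch of the same pattern with made-up names and dimensions, swapping the inner MLP for a plain `nn.Sequential`; it is not the class from this diff:

import torch
from torch import nn

class TinyApproximator(nn.Module):
    # Illustrative stand-in only; the real block uses PixArtAlphaTextProjection as the inner MLP.
    def __init__(self, in_dim=64, hidden_dim=128, out_dim=3072, n_layers=5):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim)
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim))
             for _ in range(n_layers)]
        )
        # nn.RMSNorm requires PyTorch >= 2.4
        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.in_proj(x)
        for block, norm in zip(self.blocks, self.norms):
            x = x + block(norm(x))  # pre-norm residual update
        return self.out_proj(x)

x = torch.randn(2, 344, 64)          # (batch, modulation slots, in_dim)
print(TinyApproximator()(x).shape)   # torch.Size([2, 344, 3072])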

src/diffusers/models/normalization.py

Lines changed: 113 additions & 0 deletions
@@ -171,6 +171,46 @@ def forward(
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
 
 
+class AdaLayerNormZeroPruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+        super().__init__()
+        if num_embeddings is not None:
+            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.emb is not None:
+            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+        scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
 class AdaLayerNormZeroSingle(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
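
Editor's note: the "Pruned" variants differ from the existing adaLN-Zero layers in that the modulation vector arrives precomputed (e.g. sliced from the approximator output) instead of being produced per block by a SiLU + Linear projection. A self-contained sketch of that contract with made-up sizes, not the layer itself:

import torch
from torch import nn

dim, batch, tokens = 8, 2, 10
x = torch.randn(batch, tokens, dim)
emb = torch.randn(batch, 6 * dim)  # externally supplied modulation vector
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
out = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
print(out.shape)  # torch.Size([2, 10, 8]); the gates and MLP shift/scale are returned for later use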
@@ -203,6 +243,35 @@ def forward(
         return x, gate_msa
 
 
+class AdaLayerNormZeroSinglePruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+        super().__init__()
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa
+
+
 class LuminaRMSNormZero(nn.Module):
     """
     Norm layer adaptive RMS normalization zero.
@@ -305,6 +374,50 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class AdaLayerNormContinuousPruned(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+    ):
+        super().__init__()
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
+        shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class AdaLayerNormContinuous(nn.Module):
     r"""
     Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
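
Editor's note: `AdaLayerNormContinuousPruned` applies the same idea to the final norm: the conditioning embedding already carries the shift and scale halves, so the layer only chunks and applies them. A short sketch with made-up sizes, again not the layer itself:

import torch
from torch import nn

dim, batch, tokens = 8, 2, 10
x = torch.randn(batch, tokens, dim)
emb = torch.randn(batch, 2 * dim)  # precomputed [shift, scale] conditioning
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
out = norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
print(out.shape)  # torch.Size([2, 10, 8])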
