@@ -171,6 +171,46 @@ def forward(
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


+class AdaLayerNormZeroPruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+        super().__init__()
+        if num_embeddings is not None:
+            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.emb is not None:
+            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
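+        # Unlike `AdaLayerNormZero`, no linear projection is applied here: `emb` is chunked directly
+        # into the six adaLN-Zero modulation tensors, so it is expected to already have the projected size.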
+        scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
class AdaLayerNormZeroSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -203,6 +243,35 @@ def forward(
        return x, gate_msa


+class AdaLayerNormZeroSinglePruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+        super().__init__()
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
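+        # `emb` is expected to be precomputed with size 3 * embedding_dim along dim 1: scale, shift and gate.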
+        scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa
+
+
class LuminaRMSNormZero(nn.Module):
    """
    Norm layer adaptive RMS normalization zero.
@@ -305,6 +374,50 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        return x


+class AdaLayerNormContinuousPruned(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+    ):
+        super().__init__()
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
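+        # `emb` is expected to be precomputed with size 2 * embedding_dim along dim 1: shift and scale.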
+        # convert back to the original dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT)
+        shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
class AdaLayerNormContinuous(nn.Module):
    r"""
    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).