diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py
index 990c90512e39..03284bc2a624 100644
--- a/src/diffusers/models/transformers/latte_transformer_3d.py
+++ b/src/diffusers/models/transformers/latte_transformer_3d.py
@@ -23,11 +23,12 @@
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle
 
 
 class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["norm_out"]
 
     """
     A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
@@ -149,8 +150,13 @@ def __init__(
 
         # 4. Define output layers
         self.out_channels = in_channels if out_channels is None else out_channels
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=inner_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=1e-6,
+            chunk_dim=1,
+        )
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
 
         # 5. Latte other blocks.
@@ -165,6 +171,17 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        if prefix + "scale_shift_table" in state_dict:
+            scale_shift_table = state_dict.pop(prefix + "scale_shift_table")
+            state_dict[prefix + "norm_out.linear.weight"] = torch.eye(scale_shift_table.shape[-1]).repeat(2, 1)
+            state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table.flatten()
+        return super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -305,10 +322,7 @@ def forward(
         embedded_timestep = embedded_timestep.repeat_interleave(
             num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
         ).view(-1, embedded_timestep.shape[-1])
-        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
-        # Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
         hidden_states = self.proj_out(hidden_states)
 
         # unpatchify
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 40a14bfd9b27..03fb1f8c30b8 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -23,7 +23,7 @@
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -78,7 +78,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed", "norm_out"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
["pos_embed", "norm", "adaln_single"] @register_to_config @@ -171,8 +171,13 @@ def __init__( ) # 3. Output blocks. - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.norm_out = AdaLayerNorm( + embedding_dim=self.inner_dim, + output_dim=2 * self.inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) self.adaln_single = AdaLayerNormSingle( @@ -184,6 +189,17 @@ def __init__( in_features=self.config.caption_channels, hidden_size=self.inner_dim ) + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if "scale_shift_table" in state_dict: + scale_shift_table = state_dict.pop("scale_shift_table") + state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] + state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0] + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -406,12 +422,7 @@ def forward( ) # 3. Output - shift, scale = ( - self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) - ).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 5fa59a71d977..fae075985935 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -28,7 +28,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormSingle logger = logging.get_logger(__name__) @@ -175,6 +175,7 @@ def forward( class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True + _no_split_modules = ["norm_out"] """ A 3D Transformer model for video-like data. @@ -292,8 +293,13 @@ def __init__( ) # 3. Output projection & norm - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.norm_out = AdaLayerNorm( + embedding_dim=self.inner_dim, + output_dim=2 * self.inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) # 4. 
@@ -304,6 +310,17 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        if prefix + "scale_shift_table" in state_dict:
+            scale_shift_table = state_dict.pop(prefix + "scale_shift_table")
+            state_dict[prefix + "norm_out.linear.weight"] = torch.eye(scale_shift_table.shape[-1]).repeat(2, 1)
+            state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table.flatten()
+        return super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -393,11 +410,7 @@ def forward(
             )
 
         # 4. Output normalization & projection
-        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
-
-        # Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states.squeeze(1)
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 2d06124282d1..44c17f9fa8ee 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -30,7 +30,7 @@
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["norm"]
+    _no_split_modules = ["norm_out"]
     _repeated_blocks = ["LTXVideoTransformerBlock"]
 
     @register_to_config
@@ -356,7 +357,6 @@ def __init__(
 
         self.proj_in = nn.Linear(in_channels, inner_dim)
 
-        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
         self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
 
         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
@@ -389,11 +389,40 @@ def __init__(
             ]
         )
 
-        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=inner_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=1e-6,
+            chunk_dim=1,
+        )
         self.proj_out = nn.Linear(inner_dim, out_channels)
 
         self.gradient_checkpointing = False
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        key = "scale_shift_table"
+        if prefix + key in state_dict:
+            scale_shift_table = state_dict.pop(prefix + key)
+            inner_dim = scale_shift_table.shape[-1]
+
+            weight = torch.eye(inner_dim).repeat(2, 1)
+            bias = scale_shift_table.reshape(2, inner_dim).flatten()
+
+            state_dict[prefix + "norm_out.linear.weight"] = weight
+            state_dict[prefix + "norm_out.linear.bias"] = bias
+
+            if prefix + "norm_out.weight" in state_dict:
+                state_dict.pop(prefix + "norm_out.weight")
+            if prefix + "norm_out.bias" in state_dict:
+                state_dict.pop(prefix + "norm_out.bias")
+
+        return super(LTXVideoTransformer3DModel, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -464,11 +493,7 @@ def forward(
                     encoder_attention_mask=encoder_attention_mask,
                 )
 
-        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
-        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
-
-        hidden_states = self.norm_out(hidden_states)
-        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.norm_out(hidden_states, temb=embedded_timestep.squeeze(1))
         output = self.proj_out(hidden_states)
 
         if USE_PEFT_BACKEND:
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index bdb9201e62cf..aea524491557 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -29,7 +29,7 @@
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import FP32LayerNorm
+from ..normalization import AdaLayerNorm, FP32LayerNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -372,7 +372,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
-    _no_split_modules = ["WanTransformerBlock"]
+    _no_split_modules = ["WanTransformerBlock", "norm_out"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
     _repeated_blocks = ["WanTransformerBlock"]
@@ -428,12 +428,40 @@ def __init__(
         )
 
         # 4. Output norm & projection
-        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=inner_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=eps,
+            chunk_dim=1,
+        )
         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
-        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
 
         self.gradient_checkpointing = False
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        key = "scale_shift_table"
+        if prefix + key in state_dict:
+            scale_shift_table = state_dict.pop(prefix + key)
+            inner_dim = scale_shift_table.shape[-1]
+
+            weight = torch.eye(inner_dim).repeat(2, 1)
+            bias = scale_shift_table.reshape(2, inner_dim).flatten()
+
+            state_dict[prefix + "norm_out.linear.weight"] = weight
+            state_dict[prefix + "norm_out.linear.bias"] = bias
+
+            if prefix + "norm_out.weight" in state_dict:
+                state_dict.pop(prefix + "norm_out.weight")
+            if prefix + "norm_out.bias" in state_dict:
+                state_dict.pop(prefix + "norm_out.bias")
+
+        return super(WanTransformer3DModel, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -488,16 +516,7 @@ def forward(
                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
         # 5. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
-
-        # Move the shift and scale tensors to the same device as hidden_states.
-        # When using multi-GPU inference via accelerate these will be on the
-        # first device rather than the last device, which hidden_states ends up
-        # on.
-        shift = shift.to(hidden_states.device)
-        scale = scale.to(hidden_states.device)
-
-        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+        hidden_states = self.norm_out(hidden_states, temb=temb)
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states.reshape(
diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
index 1a6f2af59a87..e25e75590829 100644
--- a/src/diffusers/models/transformers/transformer_wan_vace.py
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -26,7 +26,7 @@
 from ..cache_utils import CacheMixin
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import FP32LayerNorm
+from ..normalization import AdaLayerNorm, FP32LayerNorm
 from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
 
 
@@ -179,7 +179,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
-    _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
+    _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
 
@@ -259,12 +259,40 @@ def __init__(
         )
 
         # 4. Output norm & projection
-        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=inner_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=eps,
+            chunk_dim=1,
+        )
         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
-        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
 
         self.gradient_checkpointing = False
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        key = "scale_shift_table"
+        if prefix + key in state_dict:
+            scale_shift_table = state_dict.pop(prefix + key)
+            inner_dim = scale_shift_table.shape[-1]
+
+            weight = torch.eye(inner_dim).repeat(2, 1)
+            bias = scale_shift_table.reshape(2, inner_dim).flatten()
+
+            state_dict[prefix + "norm_out.linear.weight"] = weight
+            state_dict[prefix + "norm_out.linear.bias"] = bias
+
+            if prefix + "norm_out.weight" in state_dict:
+                state_dict.pop(prefix + "norm_out.weight")
+            if prefix + "norm_out.bias" in state_dict:
+                state_dict.pop(prefix + "norm_out.bias")
+
+        return super(WanVACETransformer3DModel, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -365,16 +393,7 @@ def forward(
                 hidden_states = hidden_states + control_hint * scale
 
         # 6. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
-
-        # Move the shift and scale tensors to the same device as hidden_states.
-        # When using multi-GPU inference via accelerate these will be on the
-        # first device rather than the last device, which hidden_states ends up
-        # on.
-        shift = shift.to(hidden_states.device)
-        scale = scale.to(hidden_states.device)
-
-        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+        hidden_states = self.norm_out(hidden_states, temb=temb)
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states.reshape(
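Reviewer note on the remapping above (a sketch, not part of the patch): with `weight = torch.eye(inner_dim).repeat(2, 1)` and `bias = scale_shift_table.reshape(2, inner_dim).flatten()`, the remapped `norm_out.linear` maps any input `x` to `[x + table[0], x + table[1]]`, which is exactly the old `scale_shift_table + temb` broadcast followed by `chunk(2)`. The snippet below checks only that linear-algebra identity in isolation; it says nothing about any activation `AdaLayerNorm` may apply to `temb` before the projection, and `inner_dim` and the batch size are arbitrary.

```python
# Standalone check of the scale_shift_table -> Linear folding used above.
import torch

inner_dim = 8
temb = torch.randn(3, inner_dim)  # one embedded timestep per batch element
scale_shift_table = torch.randn(2, inner_dim) / inner_dim**0.5

# Original formulation: broadcast-add the table to the timestep embedding.
shift_ref, scale_ref = (scale_shift_table[None] + temb[:, None]).chunk(2, dim=1)

# Remapped formulation: the weight stacks two identity matrices, so the
# linear computes [temb + table[0], temb + table[1]] in a single matmul.
weight = torch.eye(inner_dim).repeat(2, 1)  # (2 * inner_dim, inner_dim)
bias = scale_shift_table.reshape(2, inner_dim).flatten()
out = torch.nn.functional.linear(temb, weight, bias)
shift_new, scale_new = out.chunk(2, dim=1)

assert torch.allclose(shift_ref.squeeze(1), shift_new)
assert torch.allclose(scale_ref.squeeze(1), scale_new)
```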
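The per-model `_load_from_state_dict` overrides rely on `nn.Module.load_state_dict` visiting a parent module before its children while they read from a shared mapping, so keys popped and inserted by the parent are exactly what `norm_out.linear` later consumes. A minimal, self-contained sketch of that hook pattern; `TinyModel` and `TinyOutNorm` are hypothetical stand-ins, not names from the patch:

```python
import torch
from torch import nn


class TinyOutNorm(nn.Module):
    # Stand-in for the refactored norm_out: only the linear matters here.
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, 2 * dim)


class TinyModel(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm_out = TinyOutNorm(dim)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        # Remap the legacy parameter before the child module is visited;
        # mutating the shared state_dict here is what makes the hook work.
        key = prefix + "scale_shift_table"
        if key in state_dict:
            table = state_dict.pop(key)
            dim = table.shape[-1]
            state_dict[prefix + "norm_out.linear.weight"] = torch.eye(dim).repeat(2, 1)
            state_dict[prefix + "norm_out.linear.bias"] = table.reshape(2, dim).flatten()
        return super()._load_from_state_dict(state_dict, prefix, *args)


model = TinyModel(4)
legacy = {"scale_shift_table": torch.randn(2, 4)}
model.load_state_dict(legacy, strict=True)  # loads cleanly: no missing or unexpected keys
```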