diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
index e2e27048cc61..77e2687dc895 100644
--- a/src/diffusers/hooks/first_block_cache.py
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Tuple, Union
 
 import torch
 
@@ -52,13 +51,19 @@ class FBCSharedBlockState(BaseState):
     def __init__(self) -> None:
         super().__init__()
 
-        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
-        self.head_block_residual: torch.Tensor = None
-        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+        self.head_block_output_hidden_states: torch.Tensor = None
+        self.head_block_output_encoder_hidden_states: torch.Tensor = None
+        self.head_block_residual_hidden_states: torch.Tensor = None
+        self.tail_block_residual_hidden_states: torch.Tensor = None
+        self.tail_block_residual_encoder_hidden_states: torch.Tensor = None
         self.should_compute: bool = True
 
     def reset(self):
-        self.tail_block_residuals = None
+        self.head_block_output_hidden_states = None
+        self.head_block_output_encoder_hidden_states = None
+        self.head_block_residual_hidden_states = None
+        self.tail_block_residual_hidden_states = None
+        self.tail_block_residual_encoder_hidden_states = None
         self.should_compute = True
 
@@ -84,12 +89,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
 
         output = self.fn_ref.original_forward(*args, **kwargs)
-        is_output_tuple = isinstance(output, tuple)
-
-        if is_output_tuple:
-            hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
-        else:
-            hidden_states_residual = output - original_hidden_states
+        hidden_states_residual = output.hidden_states - original_hidden_states
 
         shared_state: FBCSharedBlockState = self.state_manager.get_state()
         hidden_states = encoder_hidden_states = None
@@ -98,38 +98,22 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
 
         if not should_compute:
             # Apply caching
-            if is_output_tuple:
-                hidden_states = (
-                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
-                )
-            else:
-                hidden_states = shared_state.tail_block_residuals[0] + output
-
-            if self._metadata.return_encoder_hidden_states_index is not None:
-                assert is_output_tuple
+            return_output = output.__class__()
+            hidden_states = shared_state.tail_block_residual_hidden_states + output.hidden_states
+            return_output = return_output._replace(hidden_states=hidden_states)
+            if hasattr(output, "encoder_hidden_states"):
                 encoder_hidden_states = (
-                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
+                    shared_state.tail_block_residual_encoder_hidden_states + output.encoder_hidden_states
                 )
-
-            if is_output_tuple:
-                return_output = [None] * len(output)
-                return_output[self._metadata.return_hidden_states_index] = hidden_states
-                return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
-                return_output = tuple(return_output)
-            else:
-                return_output = hidden_states
-            output = return_output
+                return_output = return_output._replace(encoder_hidden_states=encoder_hidden_states)
         else:
-            if is_output_tuple:
-                head_block_output = [None] * len(output)
-                head_block_output[0] = output[self._metadata.return_hidden_states_index]
-                head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
-            else:
-                head_block_output = output
-            shared_state.head_block_output = head_block_output
-            shared_state.head_block_residual = hidden_states_residual
+            return_output = output
+            shared_state.head_block_output_hidden_states = output.hidden_states
+            if hasattr(output, "encoder_hidden_states"):
+                shared_state.head_block_output_encoder_hidden_states = output.encoder_hidden_states
+            shared_state.head_block_residual_hidden_states = hidden_states_residual
 
-        return output
+        return return_output
 
     def reset_state(self, module):
         self.state_manager.reset()
@@ -138,9 +122,9 @@ def reset_state(self, module):
     @torch.compiler.disable
     def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
         shared_state = self.state_manager.get_state()
-        if shared_state.head_block_residual is None:
+        if shared_state.head_block_residual_hidden_states is None:
             return True
-        prev_hidden_states_residual = shared_state.head_block_residual
+        prev_hidden_states_residual = shared_state.head_block_residual_hidden_states
         absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
         prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
         diff = (absmean / prev_hidden_states_absmean).item()
@@ -153,6 +137,7 @@ def __init__(self, state_manager: StateManager, is_tail: bool = False):
         self.state_manager = state_manager
         self.is_tail = is_tail
         self._metadata = None
+        self._output_cls = None
 
     def initialize_hook(self, module):
         unwrapped_module = unwrap_module(module)
@@ -166,37 +151,38 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
 
-        if self._metadata.return_encoder_hidden_states_index is not None:
+        original_encoder_hidden_states = None
+        try:
             original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                 "encoder_hidden_states", args, kwargs
             )
+        except ValueError:
+            # This is expected for models that don't use encoder_hidden_states in their forward definition
+            pass
 
         shared_state = self.state_manager.get_state()
 
         if shared_state.should_compute:
             output = self.fn_ref.original_forward(*args, **kwargs)
+            if self._output_cls is None:
+                self._output_cls = output.__class__
             if self.is_tail:
-                hidden_states_residual = encoder_hidden_states_residual = None
-                if isinstance(output, tuple):
-                    hidden_states_residual = (
-                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
-                    )
+                hidden_states_residual = output.hidden_states - shared_state.head_block_output_hidden_states
+                encoder_hidden_states_residual = None
+                if hasattr(output, "encoder_hidden_states"):
                     encoder_hidden_states_residual = (
-                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
+                        output.encoder_hidden_states - shared_state.head_block_output_encoder_hidden_states
                     )
-                else:
-                    hidden_states_residual = output - shared_state.head_block_output
-                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
+                shared_state.tail_block_residual_hidden_states = hidden_states_residual
+                shared_state.tail_block_residual_encoder_hidden_states = encoder_hidden_states_residual
             return output
 
-        if original_encoder_hidden_states is None:
-            return_output = original_hidden_states
-        else:
-            return_output = [None, None]
-            return_output[self._metadata.return_hidden_states_index] = original_hidden_states
-            return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
-            return_output = tuple(return_output)
+        assert self._output_cls is not None
+        return_output = self._output_cls()
+        return_output = return_output._replace(hidden_states=original_hidden_states)
+        if hasattr(return_output, "encoder_hidden_states"):
+            return_output = return_output._replace(encoder_hidden_states=original_encoder_hidden_states)
         return return_output
 
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 4561cbf505c9..fdc0b68c3ac0 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -26,7 +26,6 @@
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
-from ..metadata import TransformerBlockMetadata, register_transformer_block
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -35,13 +34,12 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class CogVideoXBlockOutput(NamedTuple):
+    hidden_states: torch.Tensor = None
+    encoder_hidden_states: torch.Tensor = None
+
+
 @maybe_allow_in_graph
-@register_transformer_block(
-    metadata=TransformerBlockMetadata(
-        return_hidden_states_index=0,
-        return_encoder_hidden_states_index=1,
-    )
-)
 class CogVideoXBlock(nn.Module):
     r"""
     Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
@@ -129,7 +127,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> CogVideoXBlockOutput:
         text_seq_length = encoder_hidden_states.size(1)
         attention_kwargs = attention_kwargs or {}
 
@@ -161,7 +159,7 @@
         hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
         encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
 
-        return hidden_states, encoder_hidden_states
+        return CogVideoXBlockOutput(hidden_states, encoder_hidden_states)
 
 
 class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 4ab26b90b326..2bf70c281045 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import math
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -36,6 +36,11 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class WanTransformerBlockOutput(NamedTuple):
+    hidden_states: torch.Tensor = None
+    encoder_hidden_states: torch.Tensor = None
+
+
 class WanAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -222,12 +227,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 @maybe_allow_in_graph
-@register_transformer_block(
-    TransformerBlockMetadata(
-        return_hidden_states_index=0,
-        return_encoder_hidden_states_index=None,
-    )
-)
+@register_transformer_block(TransformerBlockMetadata())
 class WanTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -285,7 +285,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> WanTransformerBlockOutput:
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
             self.scale_shift_table + temb.float()
         ).chunk(6, dim=1)
@@ -307,7 +307,7 @@
         ff_output = self.ffn(norm_hidden_states)
         hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
 
-        return hidden_states
+        return WanTransformerBlockOutput(hidden_states, encoder_hidden_states)
 
 
 class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
@@ -456,12 +456,14 @@ def forward(
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for block in self.blocks:
-                hidden_states = self._gradient_checkpointing_func(
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                     block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                 )
         else:
             for block in self.blocks:
-                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                )
 
         # 5. Output norm, projection & unpatchify
         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
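Note (not part of the patch): the caching hooks above now key off NamedTuple block outputs instead of positional metadata indices. The sketch below, which uses a hypothetical BlockOutput class as a stand-in for the CogVideoXBlockOutput / WanTransformerBlockOutput classes added in this diff, shows the NamedTuple behavior the new code relies on.

# Illustrative sketch only; BlockOutput is a stand-in for the output classes added above.
from typing import NamedTuple, Optional

import torch


class BlockOutput(NamedTuple):
    hidden_states: Optional[torch.Tensor] = None
    encoder_hidden_states: Optional[torch.Tensor] = None


output = BlockOutput(hidden_states=torch.randn(2, 4, 8))

# The hooks rebuild an output of the same class and fill fields via _replace,
# which returns a new NamedTuple instead of mutating in place.
cached = output.__class__()  # all fields default to None
cached = cached._replace(hidden_states=output.hidden_states + 1.0)

# hasattr() checks that the field exists on the NamedTuple class, so it is True
# even when the value is None; this is what gates the encoder_hidden_states paths.
assert hasattr(output, "encoder_hidden_states")

# NamedTuple outputs still unpack positionally, which keeps
# `hidden_states, encoder_hidden_states = block(...)` in the model forwards working.
hidden_states, encoder_hidden_states = output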