First Block Cache refactor #11594

Draft · wants to merge 1 commit into base: integrations/first-block-cache-2
101 changes: 43 additions & 58 deletions src/diffusers/hooks/first_block_cache.py
@@ -13,7 +13,6 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Tuple, Union

import torch

@@ -52,13 +51,19 @@ class FBCSharedBlockState(BaseState):
def __init__(self) -> None:
super().__init__()

self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_output_hidden_states: torch.Tensor = None
self.head_block_output_encoder_hidden_states: torch.Tensor = None
self.head_block_residual_hidden_states: torch.Tensor = None
self.tail_block_residual_hidden_states: torch.Tensor = None
self.tail_block_residual_encoder_hidden_states: torch.Tensor = None
self.should_compute: bool = True

def reset(self):
self.tail_block_residuals = None
self.head_block_output_hidden_states = None
self.head_block_output_encoder_hidden_states = None
self.head_block_residual_hidden_states = None
self.tail_block_residual_hidden_states = None
self.tail_block_residual_encoder_hidden_states = None
self.should_compute = True
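
The renamed fields drop the `Union`/`Tuple` typing in favour of one tensor per role, so the hooks no longer need positional indices to tell the image and text streams apart. For reviewers, a minimal sketch of how the new names map onto the old tuple-based state (illustrative class, not the diffusers implementation):

```python
# Minimal sketch (not the diffusers class): comments mark where each value used to live.
from typing import Optional

import torch


class SharedCacheStateSketch:
    def __init__(self) -> None:
        # was head_block_output (tensor case) / head_block_output[0] (tuple case)
        self.head_block_output_hidden_states: Optional[torch.Tensor] = None
        # was head_block_output[1]
        self.head_block_output_encoder_hidden_states: Optional[torch.Tensor] = None
        # was head_block_residual
        self.head_block_residual_hidden_states: Optional[torch.Tensor] = None
        # was tail_block_residuals[0]
        self.tail_block_residual_hidden_states: Optional[torch.Tensor] = None
        # was tail_block_residuals[1]
        self.tail_block_residual_encoder_hidden_states: Optional[torch.Tensor] = None
        self.should_compute: bool = True
```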


@@ -84,12 +89,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)

if is_output_tuple:
hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
else:
hidden_states_residual = output - original_hidden_states
hidden_states_residual = output.hidden_states - original_hidden_states

shared_state: FBCSharedBlockState = self.state_manager.get_state()
hidden_states = encoder_hidden_states = None
@@ -98,38 +98,22 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):

if not should_compute:
# Apply caching
if is_output_tuple:
hidden_states = (
shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
)
else:
hidden_states = shared_state.tail_block_residuals[0] + output

if self._metadata.return_encoder_hidden_states_index is not None:
assert is_output_tuple
return_output = output.__class__()
hidden_states = shared_state.tail_block_residual_hidden_states + output.hidden_states
return_output = return_output._replace(hidden_states=hidden_states)
if hasattr(output, "encoder_hidden_states"):
encoder_hidden_states = (
shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
shared_state.tail_block_residual_encoder_hidden_states + output.encoder_hidden_states
)

if is_output_tuple:
return_output = [None] * len(output)
return_output[self._metadata.return_hidden_states_index] = hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
return_output = tuple(return_output)
else:
return_output = hidden_states
output = return_output
return_output = return_output._replace(encoder_hidden_states=encoder_hidden_states)
else:
if is_output_tuple:
head_block_output = [None] * len(output)
head_block_output[0] = output[self._metadata.return_hidden_states_index]
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
else:
head_block_output = output
shared_state.head_block_output = head_block_output
shared_state.head_block_residual = hidden_states_residual
return_output = output
shared_state.head_block_output_hidden_states = output.hidden_states
if hasattr(output, "encoder_hidden_states"):
shared_state.head_block_output_encoder_hidden_states = output.encoder_hidden_states
shared_state.head_block_residual_hidden_states = hidden_states_residual

return output
return return_output

def reset_state(self, module):
self.state_manager.reset()
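
The rewritten branch assumes the wrapped block returns a NamedTuple-style output, so the hook can instantiate an empty copy with `output.__class__()`, fill fields immutably via `_replace`, and probe the optional text stream with `hasattr`. A runnable sketch of that pattern under those assumptions:

```python
# Sketch of the NamedTuple pattern the hook relies on (illustrative `BlockOutput`;
# the real output classes are defined per transformer block later in this diff).
from typing import NamedTuple, Optional

import torch


class BlockOutput(NamedTuple):
    hidden_states: Optional[torch.Tensor] = None
    encoder_hidden_states: Optional[torch.Tensor] = None


output = BlockOutput(torch.ones(2, 4), torch.zeros(2, 3))

# Build an empty instance of the same class, then fill fields immutably.
return_output = output.__class__()  # BlockOutput(None, None)
return_output = return_output._replace(hidden_states=output.hidden_states + 1)
if hasattr(output, "encoder_hidden_states"):  # the field exists on this output class
    return_output = return_output._replace(encoder_hidden_states=output.encoder_hidden_states)

print(return_output.hidden_states.shape)  # torch.Size([2, 4])
```
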
@@ -138,9 +122,9 @@ def reset_state(self, module):
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
shared_state = self.state_manager.get_state()
if shared_state.head_block_residual is None:
if shared_state.head_block_residual_hidden_states is None:
return True
prev_hidden_states_residual = shared_state.head_block_residual
prev_hidden_states_residual = shared_state.head_block_residual_hidden_states
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
diff = (absmean / prev_hidden_states_absmean).item()
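
The skip decision compares the current first-block residual against the previous one: the mean absolute difference is normalized by the previous residual's mean absolute magnitude, and the remaining blocks run only when that relative change exceeds the configured threshold. A hedged sketch of the criterion (the `threshold` value here is illustrative; the real hook reads it from its cache configuration):

```python
# Hedged sketch of the skip criterion; the 0.1 threshold is illustrative only.
from typing import Optional

import torch


def should_compute_remaining_blocks(
    current_residual: torch.Tensor,
    previous_residual: Optional[torch.Tensor],
    threshold: float = 0.1,
) -> bool:
    if previous_residual is None:
        return True  # first step: nothing to compare against, run the full model
    absmean = (current_residual - previous_residual).abs().mean()
    prev_absmean = previous_residual.abs().mean()
    relative_change = (absmean / prev_absmean).item()
    return relative_change > threshold


prev = torch.randn(1, 16, 64)
curr = prev + 0.01 * torch.randn_like(prev)  # nearly unchanged first-block residual
print(should_compute_remaining_blocks(curr, prev))  # expected False -> reuse the cache
```
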
@@ -153,6 +137,7 @@ def __init__(self, state_manager: StateManager, is_tail: bool = False):
self.state_manager = state_manager
self.is_tail = is_tail
self._metadata = None
self._output_cls = None

def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
@@ -166,37 +151,37 @@ def initialize_hook(self, module):

def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

original_encoder_hidden_states = None
if self._metadata.return_encoder_hidden_states_index is not None:
try:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
except ValueError:
# This is expected for models that don't use encoder_hidden_states in their forward definition
pass

shared_state = self.state_manager.get_state()

if shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self._output_cls is None:
self._output_cls = output.__class__
if self.is_tail:
hidden_states_residual = encoder_hidden_states_residual = None
if isinstance(output, tuple):
hidden_states_residual = (
output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
)
hidden_states_residual = output.hidden_states - shared_state.head_block_output_hidden_states
if hasattr(output, "encoder_hidden_states"):
encoder_hidden_states_residual = (
output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
output.encoder_hidden_states - shared_state.head_block_output_encoder_hidden_states
)
else:
hidden_states_residual = output - shared_state.head_block_output
shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
shared_state.tail_block_residual_hidden_states = hidden_states_residual
shared_state.tail_block_residual_encoder_hidden_states = encoder_hidden_states_residual
return output

if original_encoder_hidden_states is None:
return_output = original_hidden_states
else:
return_output = [None, None]
return_output[self._metadata.return_hidden_states_index] = original_hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
return_output = tuple(return_output)
assert self._output_cls is not None
return_output = self._output_cls()
return_output = return_output._replace(hidden_states=original_hidden_states)
if hasattr(return_output, "encoder_hidden_states"):
return_output = return_output._replace(encoder_hidden_states=original_encoder_hidden_states)
return return_output
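
One detail worth calling out: on skipped steps the tail hook never calls the wrapped block's forward, so it cannot ask a fresh output for its class; caching `_output_cls` from the first computed step fills that gap. A small sketch of the idea (illustrative names, not the hook itself):

```python
# Why the tail hook caches `_output_cls`: skipped steps still need to return an
# output of the same class as a computed step would have produced.
from typing import NamedTuple, Optional

import torch


class BlockOutput(NamedTuple):
    hidden_states: Optional[torch.Tensor] = None
    encoder_hidden_states: Optional[torch.Tensor] = None


class TailHookSketch:
    def __init__(self) -> None:
        self._output_cls = None

    def step(self, hidden_states: torch.Tensor, compute: bool):
        if compute:
            output = BlockOutput(hidden_states * 2, None)  # stand-in for the real forward
            self._output_cls = output.__class__            # remembered for skipped steps
            return output
        assert self._output_cls is not None, "a computed step must precede a skipped one"
        return self._output_cls()._replace(hidden_states=hidden_states)


hook = TailHookSketch()
x = torch.randn(1, 3)
print(hook.step(x, compute=True).hidden_states.shape)   # computed step
print(hook.step(x, compute=False).hidden_states.shape)  # skipped step reuses the cached class
```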


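Putting the two hooks together: the head block always runs, its residual decides whether the rest of the transformer is needed, and on skipped steps the cached tail residual is added back onto the fresh head output. A self-contained toy illustration of that control flow (not the diffusers hooks; the linear "blocks" and the 0.05 threshold are made up for the example):

```python
# Toy end-to-end illustration of first-block caching.
import torch
import torch.nn as nn


class ToyCachedModel(nn.Module):
    def __init__(self, width: int = 32, depth: int = 6, threshold: float = 0.05):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
        self.threshold = threshold
        self._prev_head_residual = None  # plays the role of head_block_residual_hidden_states
        self._tail_residual = None       # plays the role of tail_block_residual_hidden_states

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head_out = self.blocks[0](x)
        head_residual = head_out - x

        recompute = True
        if self._prev_head_residual is not None:
            rel = (head_residual - self._prev_head_residual).abs().mean() / self._prev_head_residual.abs().mean()
            recompute = rel.item() > self.threshold
        self._prev_head_residual = head_residual

        if not recompute and self._tail_residual is not None:
            # Skip blocks[1:]: add the cached "rest of the network" residual back on.
            return head_out + self._tail_residual

        h = head_out
        for block in self.blocks[1:]:
            h = block(h)
        self._tail_residual = h - head_out  # what the tail blocks contributed this step
        return h


model = ToyCachedModel()
x = torch.randn(2, 32)
full = model(x)            # first call computes every block and fills the cache
approx = model(x + 1e-4)   # near-identical input -> tail blocks are skipped
print(torch.allclose(full, approx, atol=1e-2))  # expected True
```
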
18 changes: 8 additions & 10 deletions src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import torch
from torch import nn
@@ -26,7 +26,6 @@
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..metadata import TransformerBlockMetadata, register_transformer_block
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -35,13 +34,12 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class CogVideoXBlockOutput(NamedTuple):
hidden_states: torch.Tensor = None
encoder_hidden_states: torch.Tensor = None


@maybe_allow_in_graph
@register_transformer_block(
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
)
)
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
@@ -129,7 +127,7 @@ def forward(
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
) -> CogVideoXBlockOutput:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}

@@ -161,7 +159,7 @@
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

return hidden_states, encoder_hidden_states
return CogVideoXBlockOutput(hidden_states, encoder_hidden_states)


class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
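
Because `CogVideoXBlock` now returns a NamedTuple rather than a bare tuple, call sites that unpack two values should be unaffected. A quick compatibility check, assuming the field layout shown above:

```python
# NamedTuple outputs stay tuple-compatible, so existing two-value unpacking and
# positional indexing keep working (illustrative tensors, same field layout).
from typing import NamedTuple, Optional

import torch


class CogVideoXBlockOutput(NamedTuple):
    hidden_states: Optional[torch.Tensor] = None
    encoder_hidden_states: Optional[torch.Tensor] = None


out = CogVideoXBlockOutput(torch.randn(1, 8, 16), torch.randn(1, 4, 16))

hidden_states, encoder_hidden_states = out  # tuple unpacking, as in existing call sites
assert hidden_states is out.hidden_states   # attribute access, as in the new hooks
assert out[0] is out.hidden_states          # index 0 still maps to hidden_states
```
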
24 changes: 13 additions & 11 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -13,7 +13,7 @@
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -36,6 +36,11 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class WanTransformerBlockOutput(NamedTuple):
hidden_states: torch.Tensor = None
encoder_hidden_states: torch.Tensor = None


class WanAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -222,12 +227,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


@maybe_allow_in_graph
@register_transformer_block(
TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
)
)
@register_transformer_block(TransformerBlockMetadata())
class WanTransformerBlock(nn.Module):
def __init__(
self,
@@ -285,7 +285,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
) -> WanTransformerBlockOutput:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
@@ -307,7 +307,7 @@ def forward(
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

return hidden_states
return WanTransformerBlockOutput(hidden_states, encoder_hidden_states)


class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
@@ -456,12 +456,14 @@ def forward(
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
)
else:
for block in self.blocks:
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
)

# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
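
With `WanTransformerBlock` returning both streams, the model loop now has to unpack two values, including on the gradient-checkpointing path. A small sketch of that unpacking with a toy block (the real block also takes `temb` and rotary embeddings):

```python
# Sketch of the two-value unpacking now required in the Wan forward loop.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states):
        # Returns both streams, like the refactored block.
        return self.proj(hidden_states), encoder_hidden_states


block = ToyBlock()
hidden_states = torch.randn(1, 4, 16, requires_grad=True)
encoder_hidden_states = torch.randn(1, 2, 16)

# Plain call: unpack two values.
hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states)

# Gradient-checkpointed call: the returned tuple is unpacked the same way.
hidden_states, encoder_hidden_states = checkpoint(
    block, hidden_states, encoder_hidden_states, use_reentrant=False
)
print(hidden_states.shape, encoder_hidden_states.shape)
```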