From 406b1062f8274b9551058fa1dc79ab62519770fc Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 31 Mar 2025 04:27:35 +0200 Subject: [PATCH 01/22] update --- docs/source/en/api/cache.md | 81 ++++--- src/diffusers/__init__.py | 4 + src/diffusers/hooks/__init__.py | 15 ++ src/diffusers/hooks/_common.py | 30 +++ src/diffusers/hooks/_helpers.py | 199 ++++++++++++++++ src/diffusers/hooks/first_block_cache.py | 220 ++++++++++++++++++ src/diffusers/models/cache_utils.py | 26 ++- .../models/transformers/transformer_ltx.py | 3 +- src/diffusers/utils/dummy_pt_objects.py | 19 ++ tests/pipelines/cogvideo/test_cogvideox.py | 7 +- tests/pipelines/flux/test_pipeline_flux.py | 4 +- .../hunyuan_video/test_hunyuan_video.py | 7 +- tests/pipelines/ltx/test_ltx.py | 8 +- tests/pipelines/mochi/test_mochi.py | 6 +- tests/pipelines/test_pipelines_common.py | 52 ++++- 15 files changed, 632 insertions(+), 49 deletions(-) create mode 100644 src/diffusers/hooks/_common.py create mode 100644 src/diffusers/hooks/_helpers.py create mode 100644 src/diffusers/hooks/first_block_cache.py diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index a6aa5445a845..a1d961cc2974 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -11,6 +11,50 @@ specific language governing permissions and limitations under the License. --> # Caching methods +## Faster Cache + +[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong. + +FasterCache is a method that speeds up inference in diffusion transformers by: +- Reusing attention states between successive inference steps, due to high similarity between them +- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output + +```python +import torch +from diffusers import CogVideoXPipeline, FasterCacheConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 681), + current_timestep_callback=lambda: pipe.current_timestep, + attention_weight_callback=lambda _: 0.3, + unconditional_batch_skip_range=5, + unconditional_batch_timestep_skip_range=(-1, 781), + tensor_format="BFCHW", +) +pipe.transformer.enable_cache(config) +``` + +## First Block Cache + +[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual. + +```python +import torch +from diffusers import CogVideoXPipeline, FirstBlockCacheConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos. +# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05. +config = FirstBlockCacheConfig(threshold=0.07) +pipe.transformer.enable_cache(config) +``` + ## Pyramid Attention Broadcast [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. @@ -38,45 +82,24 @@ config = PyramidAttentionBroadcastConfig( pipe.transformer.enable_cache(config) ``` -## Faster Cache +### CacheMixin -[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong. +[[autodoc]] CacheMixin -FasterCache is a method that speeds up inference in diffusion transformers by: -- Reusing attention states between successive inference steps, due to high similarity between them -- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output +### FasterCacheConfig -```python -import torch -from diffusers import CogVideoXPipeline, FasterCacheConfig +[[autodoc]] FasterCacheConfig -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) -pipe.to("cuda") +[[autodoc]] apply_faster_cache -config = FasterCacheConfig( - spatial_attention_block_skip_range=2, - spatial_attention_timestep_skip_range=(-1, 681), - current_timestep_callback=lambda: pipe.current_timestep, - attention_weight_callback=lambda _: 0.3, - unconditional_batch_skip_range=5, - unconditional_batch_timestep_skip_range=(-1, 781), - tensor_format="BFCHW", -) -pipe.transformer.enable_cache(config) -``` +### FirstBlockCacheConfig -### CacheMixin +[[autodoc]] FirstBlockCacheConfig -[[autodoc]] CacheMixin +[[autodoc]] apply_first_block_cache ### PyramidAttentionBroadcastConfig [[autodoc]] PyramidAttentionBroadcastConfig [[autodoc]] apply_pyramid_attention_broadcast - -### FasterCacheConfig - -[[autodoc]] FasterCacheConfig - -[[autodoc]] apply_faster_cache diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 656f9b27db90..2c7372baa678 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -132,9 +132,11 @@ _import_structure["hooks"].extend( [ "FasterCacheConfig", + "FirstBlockCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", "apply_faster_cache", + "apply_first_block_cache", "apply_pyramid_attention_broadcast", ] ) @@ -709,9 +711,11 @@ else: from .hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, apply_pyramid_attention_broadcast, ) from .models import ( diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b465..365bed371864 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -1,8 +1,23 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from ..utils import is_torch_available if is_torch_available(): from .faster_cache import FasterCacheConfig, apply_faster_cache + from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 000000000000..3be77dd4cedf --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 000000000000..606a58cd578e --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,199 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Type + +from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock +from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock +from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, +) +from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock +from ..models.transformers.transformer_mochi import MochiTransformerBlock +from ..models.transformers.transformer_wan import WanTransformerBlock + + +@dataclass +class TransformerBlockMetadata: + skip_block_output_fn: Callable[[Any], Any] + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_transformer_blocks_metadata(): + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return encoder_hidden_states, hidden_states + + +_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +# fmt: on + + +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py new file mode 100644 index 000000000000..1f1bfd6c8cf9 --- /dev/null +++ b/src/diffusers/hooks/first_block_cache.py @@ -0,0 +1,220 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Tuple, Union + +import torch + +from ..utils import get_logger +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook" +_FBC_BLOCK_HOOK = "fbc_block_hook" + + +@dataclass +class FirstBlockCacheConfig: + r""" + Configuration for [First Block + Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching). + + Args: + threshold (`float`, defaults to `0.05`): + The threshold to determine whether or not a forward pass through all layers of the model is required. A + higher threshold usually results in lower number of forward passes and faster inference, but might lead to + poorer generation quality. A lower threshold may not result in significant generation speedup. The + threshold is compared against the absmean difference of the residuals between the current and cached + outputs from the first transformer block. If the difference is below the threshold, the forward pass is + skipped. + """ + + threshold: float = 0.05 + + +class FBCSharedBlockState: + def __init__(self) -> None: + self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.head_block_residual: torch.Tensor = None + self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + def reset(self): + self.tail_block_residuals = None + self.should_compute = True + + +class FBCHeadBlockHook(ModelHook): + _is_stateful = True + + def __init__(self, shared_state: FBCSharedBlockState, threshold: float): + self.shared_state = shared_state + self.threshold = threshold + self._metadata = None + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) + original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + + output = self.fn_ref.original_forward(*args, **kwargs) + is_output_tuple = isinstance(output, tuple) + + hs_residual = output[self._metadata.return_hidden_states_index] - original_hs + hs, ehs = None, None + + should_compute = self._should_compute_remaining_blocks(hs_residual) + self.shared_state.should_compute = should_compute + + if not should_compute: + # Apply caching + logger.info("Skipping forward pass through remaining blocks") + + if is_output_tuple: + hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] + else: + hs = output + + if self._metadata.return_encoder_hidden_states_index is not None: + ehs = ( + self.shared_state.tail_block_residuals[1] + + output[self._metadata.return_encoder_hidden_states_index] + ) + + if is_output_tuple: + return_output = [None] * len(output) + return_output[self._metadata.return_hidden_states_index] = hs + return_output[self._metadata.return_encoder_hidden_states_index] = ehs + return_output = tuple(return_output) + else: + return_output = hs + return return_output + else: + logger.info("Computing forward pass through remaining blocks") + if is_output_tuple: + head_block_output = [None] * len(output) + head_block_output[0] = output[self._metadata.return_hidden_states_index] + head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index] + else: + head_block_output = output + self.shared_state.head_block_output = head_block_output + self.shared_state.head_block_residual = hs_residual + return output + + def reset_state(self, module): + self.shared_state.reset() + return module + + def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: + if self.shared_state.head_block_residual is None: + return True + prev_hs_residual = self.shared_state.head_block_residual + hs_absmean = (hs_residual - prev_hs_residual).abs().mean() + prev_hs_mean = prev_hs_residual.abs().mean() + diff = (hs_absmean / prev_hs_mean).item() + logger.info(f"Diff: {diff}, Threshold: {self.threshold}") + return diff > self.threshold + + +class FBCBlockHook(ModelHook): + def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False): + super().__init__() + self.shared_state = shared_state + self.is_tail = is_tail + self._metadata = None + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) + original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + original_ehs = None + if self._metadata.return_encoder_hidden_states_index is not None: + original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index] + + if self.shared_state.should_compute: + output = self.fn_ref.original_forward(*args, **kwargs) + if self.is_tail: + hs_residual, ehs_residual = None, None + if isinstance(output, tuple): + hs_residual = ( + output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0] + ) + ehs_residual = ( + output[self._metadata.return_encoder_hidden_states_index] + - self.shared_state.head_block_output[1] + ) + else: + if isinstance(self.shared_state.head_block_output, list): + # For cases where double blocks returning list is followed by single blocks returning single value (Flux) + hs_residual = output - self.shared_state.head_block_output[0] + else: + hs_residual = output - self.shared_state.head_block_output + self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) + return output + + output_count = len(outputs_if_skipped) if isinstance(outputs_if_skipped, tuple) else 1 + if output_count == 1: + return_output = original_hs + else: + return_output = [None] * output_count + return_output[self._metadata.return_hidden_states_index] = original_hs + return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs + return return_output + + +def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: + shared_state = FBCSharedBlockState() + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for block in submodule: + remaining_blocks.append((name, block)) + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'") + apply_fbc_head_block_hook(head_block, shared_state, config.threshold) + + for name, block in remaining_blocks: + logger.debug(f"Apply FBCBlockHook to '{name}'") + apply_fbc_block_hook(block, shared_state) + + logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'") + apply_fbc_block_hook(tail_block, shared_state, is_tail=True) + + +def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCHeadBlockHook(state, threshold) + registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) + + +def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCBlockHook(state, is_tail) + registry.register_hook(hook, _FBC_BLOCK_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 79bd8dc0b254..6d0192239ec5 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -25,6 +25,7 @@ class CacheMixin: Supported caching techniques: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) + - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) """ _cache_config = None @@ -62,8 +63,10 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, apply_pyramid_attention_broadcast, ) @@ -72,31 +75,36 @@ def enable_cache(self, config) -> None: f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first." ) - if isinstance(config, PyramidAttentionBroadcastConfig): - apply_pyramid_attention_broadcast(self, config) - elif isinstance(config, FasterCacheConfig): + if isinstance(config, FasterCacheConfig): apply_faster_cache(self, config) + elif isinstance(config, FirstBlockCacheConfig): + apply_first_block_cache(self, config) + elif isinstance(config, PyramidAttentionBroadcastConfig): + apply_pyramid_attention_broadcast(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") return - if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) - registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) - elif isinstance(self._cache_config, FasterCacheConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) + registry = HookRegistry.check_if_exists_or_initialize(self) + if isinstance(self._cache_config, FasterCacheConfig): registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True) registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, FirstBlockCacheConfig): + registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True) + registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): + registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index c1f2df587927..2ae2418098f6 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -26,6 +26,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -298,7 +299,7 @@ def forward( @maybe_allow_in_graph -class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): +class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin): r""" A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6edbd737e32c..dfbac9512e91 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,6 +17,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FirstBlockCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HookRegistry(metaclass=DummyObject): _backends = ["torch"] @@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs): requires_backends(apply_faster_cache, ["torch"]) +def apply_first_block_cache(*args, **kwargs): + requires_backends(apply_first_block_cache, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 388dc9ef7ec4..385984f0b497 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -32,6 +32,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, @@ -44,7 +45,11 @@ class CogVideoXPipelineFastTests( - PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 6a560367a5b8..b9795fc20b1f 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -25,6 +25,7 @@ from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, @@ -34,11 +35,12 @@ class FluxPipelineFastTests( - unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index aa4f045966c3..e6587520c932 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -33,6 +33,7 @@ from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np, @@ -43,7 +44,11 @@ class HunyuanVideoPipelineFastTests( - PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 4f72729fc9ce..1f94b746f12f 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -23,13 +23,13 @@ from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase): pipeline_class = LTXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_layerwise_casting = True test_group_offloading = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LTXVideoTransformer3DModel( in_channels=8, @@ -59,7 +59,7 @@ def get_dummy_components(self): num_attention_heads=4, attention_head_dim=8, cross_attention_dim=32, - num_layers=1, + num_layers=num_layers, caption_channels=32, ) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index ea2d015af52a..ce052962e511 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -33,13 +33,15 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np +from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase): +class MochiPipelineFastTests( + PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase +): pipeline_class = MochiPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d069def66ecf..08029419de3b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -33,6 +33,7 @@ ) from diffusers.hooks import apply_group_offloading from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook +from diffusers.hooks.first_block_cache import FirstBlockCacheConfig from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -2608,7 +2609,7 @@ def run_forward(pipe): self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep pipe = create_pipe() pipe.transformer.enable_cache(self.faster_cache_config) - output = run_forward(pipe).flatten().flatten() + output = run_forward(pipe).flatten() image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:])) # Run inference with FasterCache disabled @@ -2715,6 +2716,55 @@ def faster_cache_state_check_callback(pipe, i, t, kwargs): self.assertTrue(state.cache is None, "Cache should be reset to None.") +# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out +# of the box once there is better cache support/implementation +class FirstBlockCacheTesterMixin: + # threshold is intentionally set higher than usual values since we're testing with random unconverged models + # that will not satisfy the expected properties of the denoiser for caching to be effective + first_block_cache_config = FirstBlockCacheConfig(threshold=0.8) + + def test_first_block_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + return pipe(**inputs)[0] + + # Run inference without FirstBlockCache + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # Run inference with FirstBlockCache enabled + pipe = create_pipe() + pipe.transformer.enable_cache(self.first_block_cache_config) + output = run_forward(pipe).flatten() + image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:])) + + # Run inference with FirstBlockCache disabled + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose( + original_image_slice, image_slice_fbc_enabled, atol=expected_atol + ), "FirstBlockCache outputs should not differ much." + assert np.allclose( + original_image_slice, image_slice_fbc_disabled, atol=1e-4 + ), "Outputs from normal inference and after disabling cache should not differ." + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. From dd69b418349bb923155d11371185a0424a1c0041 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Apr 2025 01:28:09 +0200 Subject: [PATCH 02/22] modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code) --- src/diffusers/hooks/_helpers.py | 6 +++--- src/diffusers/hooks/first_block_cache.py | 6 +----- .../models/transformers/transformer_flux.py | 21 ++++++++++--------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 606a58cd578e..253ca88059e5 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -84,8 +84,8 @@ def _register_transformer_blocks_metadata(): model_class=FluxSingleTransformerBlock, metadata=TransformerBlockMetadata( skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=None, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, ), ) @@ -185,7 +185,7 @@ def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___en _skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states -_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states _skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 1f1bfd6c8cf9..b440af0faed2 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -166,11 +166,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): - self.shared_state.head_block_output[1] ) else: - if isinstance(self.shared_state.head_block_output, list): - # For cases where double blocks returning list is followed by single blocks returning single value (Flux) - hs_residual = output - self.shared_state.head_block_output[0] - else: - hs_residual = output - self.shared_state.head_block_output + hs_residual = output - self.shared_state.head_block_output self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) return output diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 87537890d246..b0fb3900f657 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, def forward( self, hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) @@ -100,7 +104,8 @@ def forward( if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) - return hidden_states + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states @maybe_allow_in_graph @@ -508,20 +513,21 @@ def forward( ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, @@ -531,12 +537,7 @@ def forward( if controlnet_single_block_samples is not None: interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] - + controlnet_single_block_samples[index_block // interval_control] - ) - - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) From 7ab424a15a9f89fc7679cd72a22a1a756959be4e Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Apr 2025 01:39:00 +0200 Subject: [PATCH 03/22] remove debug logs --- src/diffusers/hooks/first_block_cache.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index b440af0faed2..cdc08b4a4c9f 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -87,8 +87,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if not should_compute: # Apply caching - logger.info("Skipping forward pass through remaining blocks") - if is_output_tuple: hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] else: @@ -109,7 +107,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return_output = hs return return_output else: - logger.info("Computing forward pass through remaining blocks") if is_output_tuple: head_block_output = [None] * len(output) head_block_output[0] = output[self._metadata.return_hidden_states_index] @@ -131,7 +128,6 @@ def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: hs_absmean = (hs_residual - prev_hs_residual).abs().mean() prev_hs_mean = prev_hs_residual.abs().mean() diff = (hs_absmean / prev_hs_mean).item() - logger.info(f"Diff: {diff}, Threshold: {self.threshold}") return diff > self.threshold From d71fe55895c4503b313a2d0c0740b50d51e5eb31 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Apr 2025 17:06:45 +0200 Subject: [PATCH 04/22] update --- src/diffusers/hooks/first_block_cache.py | 14 +++++++++----- .../models/transformers/transformer_wan.py | 3 ++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index cdc08b4a4c9f..f1b150ac75d7 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -105,7 +105,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return_output = tuple(return_output) else: return_output = hs - return return_output + output = return_output else: if is_output_tuple: head_block_output = [None] * len(output) @@ -115,12 +115,14 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): head_block_output = output self.shared_state.head_block_output = head_block_output self.shared_state.head_block_residual = hs_residual - return output + + return output def reset_state(self, module): self.shared_state.reset() return module + @torch.compiler.disable def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: if self.shared_state.head_block_residual is None: return True @@ -144,6 +146,8 @@ def initialize_hook(self, module): def new_forward(self, module: torch.nn.Module, *args, **kwargs): outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) + if not isinstance(outputs_if_skipped, tuple): + outputs_if_skipped = (outputs_if_skipped,) original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] original_ehs = None if self._metadata.return_encoder_hidden_states_index is not None: @@ -166,7 +170,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) return output - output_count = len(outputs_if_skipped) if isinstance(outputs_if_skipped, tuple) else 1 + output_count = len(outputs_if_skipped) if output_count == 1: return_output = original_hs else: @@ -183,8 +187,8 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf for name, submodule in module.named_children(): if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): continue - for block in submodule: - remaining_blocks.append((name, block)) + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) head_block_name, head_block = remaining_blocks.pop(0) tail_block_name, tail_block = remaining_blocks.pop(-1) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 4eb4add37601..aa03e97093aa 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -288,7 +289,7 @@ def forward( return hidden_states -class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in the Wan model. From 2557238b4d33ea60b6c5e1829c065a132aa9c9aa Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Apr 2025 19:40:23 +0200 Subject: [PATCH 05/22] cache context for different batches of data --- src/diffusers/hooks/first_block_cache.py | 7 +- src/diffusers/hooks/hooks.py | 88 +++++++++++++++++++ src/diffusers/models/cache_utils.py | 20 +++++ .../pipelines/cogview4/pipeline_cogview4.py | 4 +- .../hunyuan_video/pipeline_hunyuan_video.py | 4 +- src/diffusers/pipelines/wan/pipeline_wan.py | 4 +- 6 files changed, 122 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index f1b150ac75d7..306825800e76 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -20,7 +20,7 @@ from ..utils import get_logger from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from ._helpers import TransformerBlockRegistry -from .hooks import HookRegistry, ModelHook +from .hooks import BaseMarkedState, HookRegistry, ModelHook logger = get_logger(__name__) # pylint: disable=invalid-name @@ -48,8 +48,10 @@ class FirstBlockCacheConfig: threshold: float = 0.05 -class FBCSharedBlockState: +class FBCSharedBlockState(BaseMarkedState): def __init__(self) -> None: + super().__init__() + self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None self.head_block_residual: torch.Tensor = None self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None @@ -130,6 +132,7 @@ def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: hs_absmean = (hs_residual - prev_hs_residual).abs().mean() prev_hs_mean = prev_hs_residual.abs().mean() diff = (hs_absmean / prev_hs_mean).item() + print("diff:", self.shared_state._mark_name, diff, flush=True) return diff > self.threshold diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 3b2e4ed91c2f..9e8128d0bb18 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -23,6 +23,70 @@ logger = get_logger(__name__) # pylint: disable=invalid-name +class BaseState: + def reset(self, *args, **kwargs) -> None: + raise NotImplementedError( + "BaseState::reset is not implemented. Please implement this method in the derived class." + ) + + +class BaseMarkedState(BaseState): + def __init__(self, init_args=None, init_kwargs=None): + super().__init__() + + self._init_args = init_args if init_args is not None else () + self._init_kwargs = init_kwargs if init_kwargs is not None else {} + self._mark_name = None + self._state_cache = {} + + def get_current_state(self) -> "BaseMarkedState": + if self._mark_name is None: + # If no mark name is set, simply return a dummy object since we're not going to be using it + return self + if self._mark_name not in self._state_cache.keys(): + self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs) + return self._state_cache[self._mark_name] + + def mark_batch(self, name: str) -> None: + self._mark_name = name + + def reset(self, *args, **kwargs) -> None: + for name, state in list(self._state_cache.items()): + state.reset(*args, **kwargs) + self._state_cache.pop(name) + self._mark_name = None + + def __getattribute__(self, name): + if name in ( + "get_current_state", + "mark_batch", + "reset", + "_init_args", + "_init_kwargs", + "_mark_name", + "_state_cache", + ) or _is_dunder_method(name): + return object.__getattribute__(self, name) + else: + current_state = BaseMarkedState.get_current_state(self) + return object.__getattribute__(current_state, name) + + def __setattr__(self, name, value): + if name in ( + "get_current_state", + "mark_batch", + "reset", + "_init_args", + "_init_kwargs", + "_mark_name", + "_state_cache", + ) or _is_dunder_method(name): + object.__setattr__(self, name, value) + else: + current_state = BaseMarkedState.get_current_state(self) + object.__setattr__(current_state, name, value) + + class ModelHook: r""" A hook that contains callbacks to be executed just before and after the forward method of a model. @@ -99,6 +163,14 @@ def reset_state(self, module: torch.nn.Module): raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") return module + def _mark_state(self, module: torch.nn.Module, name: str) -> None: + # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch` on them. + for attr_name in dir(self): + attr = getattr(self, attr_name) + if isinstance(attr, BaseMarkedState): + attr.mark_batch(name) + return module + class HookFunctionReference: def __init__(self) -> None: @@ -223,6 +295,18 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry module._diffusers_hook = cls(module) return module._diffusers_hook + def _mark_state(self, name: str) -> None: + for hook_name in reversed(self._hook_order): + hook = self.hooks[hook_name] + if hook._is_stateful: + hook._mark_state(self._module_ref, name) + + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._mark_state(name) + def __repr__(self) -> str: registry_repr = "" for i, hook_name in enumerate(self._hook_order): @@ -234,3 +318,7 @@ def __repr__(self) -> str: if i < len(self._hook_order) - 1: registry_repr += "\n" return f"HookRegistry(\n{registry_repr}\n)" + + +def _is_dunder_method(name: str) -> bool: + return name.startswith("__") and name.endswith("__") and name in dir(object) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 6d0192239ec5..6c4bcb301d70 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager + from ..utils.logging import get_logger @@ -114,3 +116,21 @@ def _reset_stateful_cache(self, recurse: bool = True) -> None: from ..hooks import HookRegistry HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) + + @contextmanager + def _cache_context(self): + r"""Context manager that provides additional methods for cache management.""" + cache_context = _CacheContextManager(self) + yield cache_context + + +class _CacheContextManager: + def __init__(self, model: CacheMixin): + self.model = model + + def mark_state(self, name: str) -> None: + from ..hooks import HookRegistry + + if self.model.is_cache_enabled: + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry._mark_state(name) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index c27a1a19774d..6cf74ac5d942 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -610,7 +610,7 @@ def __call__( transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -621,6 +621,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) + cc.mark_state("cond") noise_pred_cond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, @@ -634,6 +635,7 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: + cc.mark_state("uncond") noise_pred_uncond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=negative_prompt_embeds, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 3cb91b3782f2..b36de61c02ef 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -683,7 +683,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -693,6 +693,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -705,6 +706,7 @@ def __call__( )[0] if do_true_cfg: + cc.mark_state("uncond") neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 6fab997e6660..733d79b5ac2c 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -512,7 +512,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -521,6 +521,7 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -530,6 +531,7 @@ def __call__( )[0] if self.do_classifier_free_guidance: + cc.mark_state("uncond") noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, From 0e232ac8c0a922e2caf137641b29b7b2cc59a529 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 00:38:11 +0200 Subject: [PATCH 06/22] fix hs residual bug for single return outputs; support ltx --- src/diffusers/hooks/first_block_cache.py | 11 +++++++---- src/diffusers/pipelines/ltx/pipeline_ltx.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 306825800e76..1293ded558f0 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -81,9 +81,12 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): output = self.fn_ref.original_forward(*args, **kwargs) is_output_tuple = isinstance(output, tuple) - hs_residual = output[self._metadata.return_hidden_states_index] - original_hs - hs, ehs = None, None + if is_output_tuple: + hs_residual = output[self._metadata.return_hidden_states_index] - original_hs + else: + hs_residual = output - original_hs + hs, ehs = None, None should_compute = self._should_compute_remaining_blocks(hs_residual) self.shared_state.should_compute = should_compute @@ -92,9 +95,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if is_output_tuple: hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] else: - hs = output + hs = self.shared_state.tail_block_residuals[0] + output if self._metadata.return_encoder_hidden_states_index is not None: + assert is_output_tuple ehs = ( self.shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index] @@ -132,7 +136,6 @@ def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: hs_absmean = (hs_residual - prev_hs_residual).abs().mean() prev_hs_mean = prev_hs_residual.abs().mean() diff = (hs_absmean / prev_hs_mean).item() - print("diff:", self.shared_state._mark_name, diff, flush=True) return diff > self.threshold diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index f7b0811d1a22..316fc4d6b722 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -701,7 +701,7 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -712,6 +712,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, From 41b0c473d2c8da7eef17abf3a1290878ec509f1b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 01:20:53 +0200 Subject: [PATCH 07/22] fix controlnet flux --- src/diffusers/models/controlnets/controlnet_flux.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 51c34b7fe965..04ab72e82a03 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -343,25 +343,25 @@ def forward( ) block_samples = block_samples + (hidden_states,) - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + single_block_samples = single_block_samples + (hidden_states,) # controlnet block controlnet_block_samples = () From 1f33ca276d064b258dc67b285fad5c6c80f43a98 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 01:21:09 +0200 Subject: [PATCH 08/22] support flux, ltx i2v, ltx condition --- src/diffusers/pipelines/flux/pipeline_flux.py | 5 ++++- src/diffusers/pipelines/ltx/pipeline_ltx_condition.py | 3 ++- src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 3 ++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 862c279cfaf3..a7195d3a679d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -906,7 +906,7 @@ def __call__( ) # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -917,6 +917,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -932,6 +933,8 @@ def __call__( if do_true_cfg: if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + cc.mark_state("uncond") neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index e7f3666cb2c7..e3b49cb673a3 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -1061,7 +1061,7 @@ def __call__( self._num_timesteps = len(timesteps) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -1090,6 +1090,7 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 6c4214fe1b26..9ee96e6a3954 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -771,7 +771,7 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -783,6 +783,7 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, From c76e1cc17e451724848a96d7f3bbf6c8aa184267 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 21:52:33 +0200 Subject: [PATCH 09/22] update --- src/diffusers/hooks/first_block_cache.py | 5 +++-- src/diffusers/hooks/hooks.py | 17 ++++++++++------- src/diffusers/utils/torch_utils.py | 5 +++++ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 1293ded558f0..7863a1268843 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -18,6 +18,7 @@ import torch from ..utils import get_logger +from ..utils.torch_utils import unwrap_module from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from ._helpers import TransformerBlockRegistry from .hooks import BaseMarkedState, HookRegistry, ModelHook @@ -71,7 +72,7 @@ def __init__(self, shared_state: FBCSharedBlockState, threshold: float): self._metadata = None def initialize_hook(self, module): - self._metadata = TransformerBlockRegistry.get(module.__class__) + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -147,7 +148,7 @@ def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False): self._metadata = None def initialize_hook(self, module): - self._metadata = TransformerBlockRegistry.get(module.__class__) + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 9e8128d0bb18..c42592783d91 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -18,6 +18,7 @@ import torch from ..utils.logging import get_logger +from ..utils.torch_utils import unwrap_module logger = get_logger(__name__) # pylint: disable=invalid-name @@ -47,7 +48,7 @@ def get_current_state(self) -> "BaseMarkedState": self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs) return self._state_cache[self._mark_name] - def mark_batch(self, name: str) -> None: + def mark_state(self, name: str) -> None: self._mark_name = name def reset(self, *args, **kwargs) -> None: @@ -59,7 +60,7 @@ def reset(self, *args, **kwargs) -> None: def __getattribute__(self, name): if name in ( "get_current_state", - "mark_batch", + "mark_state", "reset", "_init_args", "_init_kwargs", @@ -74,7 +75,7 @@ def __getattribute__(self, name): def __setattr__(self, name, value): if name in ( "get_current_state", - "mark_batch", + "mark_state", "reset", "_init_args", "_init_kwargs", @@ -164,11 +165,11 @@ def reset_state(self, module: torch.nn.Module): return module def _mark_state(self, module: torch.nn.Module, name: str) -> None: - # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch` on them. + # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them. for attr_name in dir(self): attr = getattr(self, attr_name) if isinstance(attr, BaseMarkedState): - attr.mark_batch(name) + attr.mark_state(name) return module @@ -283,9 +284,10 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None: hook.reset_state(self._module_ref) if recurse: - for module_name, module in self._module_ref.named_modules(): + for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue + module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): module._diffusers_hook.reset_stateful_hooks(recurse=False) @@ -301,9 +303,10 @@ def _mark_state(self, name: str) -> None: if hook._is_stateful: hook._mark_state(self._module_ref, name) - for module_name, module in self._module_ref.named_modules(): + for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue + module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): module._diffusers_hook._mark_state(name) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3c8911773e39..06f9981f0138 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). From 46619ea717b4e47b1f8fa83ff6ab1e7ef1917c69 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 5 Apr 2025 00:23:10 +0200 Subject: [PATCH 10/22] update --- src/diffusers/hooks/first_block_cache.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 7863a1268843..81cb3f1f9a08 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -77,7 +77,11 @@ def initialize_hook(self, module): def new_forward(self, module: torch.nn.Module, *args, **kwargs): outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) - original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + + if isinstance(outputs_if_skipped, tuple): + original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + else: + original_hs = outputs_if_skipped output = self.fn_ref.original_forward(*args, **kwargs) is_output_tuple = isinstance(output, tuple) @@ -200,14 +204,14 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf head_block_name, head_block = remaining_blocks.pop(0) tail_block_name, tail_block = remaining_blocks.pop(-1) - logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'") + logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'") apply_fbc_head_block_hook(head_block, shared_state, config.threshold) for name, block in remaining_blocks: - logger.debug(f"Apply FBCBlockHook to '{name}'") + logger.debug(f"Applying FBCBlockHook to '{name}'") apply_fbc_block_hook(block, shared_state) - logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'") + logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'") apply_fbc_block_hook(tail_block, shared_state, is_tail=True) From ff5f2ee5059cc719234911b83a841c44b48ba0c7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 5 Apr 2025 03:53:34 +0530 Subject: [PATCH 11/22] Update docs/source/en/api/cache.md --- docs/source/en/api/cache.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index a1d961cc2974..fc16a0ffcd4d 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -50,7 +50,7 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch pipe.to("cuda") # Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos. -# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05. +# Smaller values between 0.02-0.20 are recommended based on the model being used. The default value is 0.05. config = FirstBlockCacheConfig(threshold=0.07) pipe.transformer.enable_cache(config) ``` From ca715a9771fd87dbc5f7d7a4397c755fd6ed0b07 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 14 Apr 2025 18:04:13 +0530 Subject: [PATCH 12/22] Update src/diffusers/hooks/hooks.py Co-authored-by: Dhruv Nair --- src/diffusers/hooks/hooks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index c42592783d91..17be6858740f 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -58,14 +58,15 @@ def reset(self, *args, **kwargs) -> None: self._mark_name = None def __getattribute__(self, name): - if name in ( + direct_attrs = ( "get_current_state", "mark_state", "reset", "_init_args", "_init_kwargs", "_mark_name", - "_state_cache", + "_state_cache",) + if name in direct_attrs or _is_dunder_method(name): ) or _is_dunder_method(name): return object.__getattribute__(self, name) else: From 3dde07a64791b55192a4a88b44dbee9fb0a56c57 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 13:36:35 +0200 Subject: [PATCH 13/22] address review comments pt. 1 --- src/diffusers/hooks/_helpers.py | 74 +++++++++++++++++++++- src/diffusers/hooks/first_block_cache.py | 80 ++++++++++++------------ src/diffusers/hooks/hooks.py | 4 +- tests/pipelines/test_pipelines_common.py | 12 ++-- 4 files changed, 122 insertions(+), 48 deletions(-) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 253ca88059e5..9043ffc41838 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -15,8 +15,10 @@ from dataclasses import dataclass from typing import Any, Callable, Type +from ..models.attention import BasicTransformerBlock +from ..models.attention_processor import AttnProcessor2_0 from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock -from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock +from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from ..models.transformers.transformer_hunyuan_video import ( HunyuanVideoSingleTransformerBlock, @@ -29,6 +31,11 @@ from ..models.transformers.transformer_wan import WanTransformerBlock +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + @dataclass class TransformerBlockMetadata: skip_block_output_fn: Callable[[Any], Any] @@ -36,6 +43,20 @@ class TransformerBlockMetadata: return_encoder_hidden_states_index: int = None +class AttentionProcessorRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> AttentionProcessorMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + class TransformerBlockRegistry: _registry = {} @@ -50,7 +71,35 @@ def get(cls, model_class: Type) -> TransformerBlockMetadata: return cls._registry[model_class] +def _register_attention_processors_metadata(): + # AttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=AttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0, + ), + ) + + # CogView4AttnProcessor + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + def _register_transformer_blocks_metadata(): + # BasicTransformerBlock + TransformerBlockRegistry.register( + model_class=BasicTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + # CogVideoX TransformerBlockRegistry.register( model_class=CogVideoXBlock, @@ -155,6 +204,27 @@ def _register_transformer_blocks_metadata(): # fmt: off +def _skip_attention___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states + + def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): hidden_states = kwargs.get("hidden_states", None) if hidden_states is None and len(args) > 0: @@ -182,6 +252,7 @@ def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___en return encoder_hidden_states, hidden_states +_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states _skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states @@ -196,4 +267,5 @@ def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___en # fmt: on +_register_attention_processors_metadata() _register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 81cb3f1f9a08..6ce4015b6376 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -39,11 +39,11 @@ class FirstBlockCacheConfig: Args: threshold (`float`, defaults to `0.05`): The threshold to determine whether or not a forward pass through all layers of the model is required. A - higher threshold usually results in lower number of forward passes and faster inference, but might lead to - poorer generation quality. A lower threshold may not result in significant generation speedup. The - threshold is compared against the absmean difference of the residuals between the current and cached - outputs from the first transformer block. If the difference is below the threshold, the forward pass is - skipped. + higher threshold usually results in a forward pass through a lower number of layers and faster inference, + but might lead to poorer generation quality. A lower threshold may not result in significant generation + speedup. The threshold is compared against the absmean difference of the residuals between the current and + cached outputs from the first transformer block. If the difference is below the threshold, the forward pass + is skipped. """ threshold: float = 0.05 @@ -79,43 +79,45 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) if isinstance(outputs_if_skipped, tuple): - original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + original_hidden_states = outputs_if_skipped[self._metadata.return_hidden_states_index] else: - original_hs = outputs_if_skipped + original_hidden_states = outputs_if_skipped output = self.fn_ref.original_forward(*args, **kwargs) is_output_tuple = isinstance(output, tuple) if is_output_tuple: - hs_residual = output[self._metadata.return_hidden_states_index] - original_hs + hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states else: - hs_residual = output - original_hs + hidden_states_residual = output - original_hidden_states - hs, ehs = None, None - should_compute = self._should_compute_remaining_blocks(hs_residual) + hidden_states, encoder_hidden_states = None, None + should_compute = self._should_compute_remaining_blocks(hidden_states_residual) self.shared_state.should_compute = should_compute if not should_compute: # Apply caching if is_output_tuple: - hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] + hidden_states = ( + self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] + ) else: - hs = self.shared_state.tail_block_residuals[0] + output + hidden_states = self.shared_state.tail_block_residuals[0] + output if self._metadata.return_encoder_hidden_states_index is not None: assert is_output_tuple - ehs = ( + encoder_hidden_states = ( self.shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index] ) if is_output_tuple: return_output = [None] * len(output) - return_output[self._metadata.return_hidden_states_index] = hs - return_output[self._metadata.return_encoder_hidden_states_index] = ehs + return_output[self._metadata.return_hidden_states_index] = hidden_states + return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states return_output = tuple(return_output) else: - return_output = hs + return_output = hidden_states output = return_output else: if is_output_tuple: @@ -125,7 +127,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): else: head_block_output = output self.shared_state.head_block_output = head_block_output - self.shared_state.head_block_residual = hs_residual + self.shared_state.head_block_residual = hidden_states_residual return output @@ -134,13 +136,13 @@ def reset_state(self, module): return module @torch.compiler.disable - def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: + def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool: if self.shared_state.head_block_residual is None: return True - prev_hs_residual = self.shared_state.head_block_residual - hs_absmean = (hs_residual - prev_hs_residual).abs().mean() - prev_hs_mean = prev_hs_residual.abs().mean() - diff = (hs_absmean / prev_hs_mean).item() + prev_hidden_states_residual = self.shared_state.head_block_residual + absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean() + prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean() + diff = (absmean / prev_hidden_states_absmean).item() return diff > self.threshold @@ -159,35 +161,35 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) if not isinstance(outputs_if_skipped, tuple): outputs_if_skipped = (outputs_if_skipped,) - original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] - original_ehs = None + original_hidden_states = outputs_if_skipped[self._metadata.return_hidden_states_index] + original_encoder_hidden_states = None if self._metadata.return_encoder_hidden_states_index is not None: - original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index] + original_encoder_hidden_states = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index] if self.shared_state.should_compute: output = self.fn_ref.original_forward(*args, **kwargs) if self.is_tail: - hs_residual, ehs_residual = None, None + hidden_states_residual = encoder_hidden_states_residual = None if isinstance(output, tuple): - hs_residual = ( + hidden_states_residual = ( output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0] ) - ehs_residual = ( + encoder_hidden_states_residual = ( output[self._metadata.return_encoder_hidden_states_index] - self.shared_state.head_block_output[1] ) else: - hs_residual = output - self.shared_state.head_block_output - self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) + hidden_states_residual = output - self.shared_state.head_block_output + self.shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual) return output output_count = len(outputs_if_skipped) if output_count == 1: - return_output = original_hs + return_output = original_hidden_states else: return_output = [None] * output_count - return_output[self._metadata.return_hidden_states_index] = original_hs - return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs + return_output[self._metadata.return_hidden_states_index] = original_hidden_states + return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states return return_output @@ -205,23 +207,23 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf tail_block_name, tail_block = remaining_blocks.pop(-1) logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'") - apply_fbc_head_block_hook(head_block, shared_state, config.threshold) + _apply_fbc_head_block_hook(head_block, shared_state, config.threshold) for name, block in remaining_blocks: logger.debug(f"Applying FBCBlockHook to '{name}'") - apply_fbc_block_hook(block, shared_state) + _apply_fbc_block_hook(block, shared_state) logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'") - apply_fbc_block_hook(tail_block, shared_state, is_tail=True) + _apply_fbc_block_hook(tail_block, shared_state, is_tail=True) -def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None: +def _apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None: registry = HookRegistry.check_if_exists_or_initialize(block) hook = FBCHeadBlockHook(state, threshold) registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) -def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None: +def _apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None: registry = HookRegistry.check_if_exists_or_initialize(block) hook = FBCBlockHook(state, is_tail) registry.register_hook(hook, _FBC_BLOCK_HOOK) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 17be6858740f..814529d0b275 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -65,9 +65,9 @@ def __getattribute__(self, name): "_init_args", "_init_kwargs", "_mark_name", - "_state_cache",) + "_state_cache", + ) if name in direct_attrs or _is_dunder_method(name): - ) or _is_dunder_method(name): return object.__getattribute__(self, name) else: current_state = BaseMarkedState.get_current_state(self) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index b0a850ae3f8e..8480eaedc185 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2780,12 +2780,12 @@ def run_forward(pipe): output = run_forward(pipe).flatten() image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:])) - assert np.allclose( - original_image_slice, image_slice_fbc_enabled, atol=expected_atol - ), "FirstBlockCache outputs should not differ much." - assert np.allclose( - original_image_slice, image_slice_fbc_disabled, atol=1e-4 - ), "Outputs from normal inference and after disabling cache should not differ." + assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), ( + "FirstBlockCache outputs should not differ much." + ) + assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), ( + "Outputs from normal inference and after disabling cache should not differ." + ) # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. From f731664773d4dd79471b2f6befd4c3aaa3b4bb85 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 14:03:01 +0200 Subject: [PATCH 14/22] address review comments pt. 2 --- src/diffusers/hooks/first_block_cache.py | 4 +- src/diffusers/hooks/hooks.py | 63 +++++++++---------- src/diffusers/models/cache_utils.py | 2 +- .../pipelines/cogview4/pipeline_cogview4.py | 4 +- src/diffusers/pipelines/flux/pipeline_flux.py | 4 +- .../hunyuan_video/pipeline_hunyuan_video.py | 4 +- src/diffusers/pipelines/ltx/pipeline_ltx.py | 2 +- .../pipelines/ltx/pipeline_ltx_condition.py | 2 +- .../pipelines/ltx/pipeline_ltx_image2video.py | 2 +- src/diffusers/pipelines/wan/pipeline_wan.py | 4 +- 10 files changed, 43 insertions(+), 48 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 6ce4015b6376..b232e6465c4a 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -21,7 +21,7 @@ from ..utils.torch_utils import unwrap_module from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from ._helpers import TransformerBlockRegistry -from .hooks import BaseMarkedState, HookRegistry, ModelHook +from .hooks import ContextAwareState, HookRegistry, ModelHook logger = get_logger(__name__) # pylint: disable=invalid-name @@ -49,7 +49,7 @@ class FirstBlockCacheConfig: threshold: float = 0.05 -class FBCSharedBlockState(BaseMarkedState): +class FBCSharedBlockState(ContextAwareState): def __init__(self) -> None: super().__init__() diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 814529d0b275..16e80add846f 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -31,61 +31,56 @@ def reset(self, *args, **kwargs) -> None: ) -class BaseMarkedState(BaseState): +class ContextAwareState(BaseState): def __init__(self, init_args=None, init_kwargs=None): super().__init__() self._init_args = init_args if init_args is not None else () self._init_kwargs = init_kwargs if init_kwargs is not None else {} - self._mark_name = None + self._current_context = None self._state_cache = {} - def get_current_state(self) -> "BaseMarkedState": - if self._mark_name is None: - # If no mark name is set, simply return a dummy object since we're not going to be using it + def get_state(self) -> "ContextAwareState": + if self._current_context is None: + # If no context is set, simply return a dummy object since we're not going to be using it return self - if self._mark_name not in self._state_cache.keys(): - self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs) - return self._state_cache[self._mark_name] + if self._current_context not in self._state_cache.keys(): + self._state_cache[self._current_context] = ContextAwareState._create_state( + self.__class__, self._init_args, self._init_kwargs + ) + return self._state_cache[self._current_context] - def mark_state(self, name: str) -> None: - self._mark_name = name + def set_context(self, name: str) -> None: + self._current_context = name def reset(self, *args, **kwargs) -> None: for name, state in list(self._state_cache.items()): state.reset(*args, **kwargs) self._state_cache.pop(name) - self._mark_name = None + self._current_context = None + + @staticmethod + def _create_state(cls, init_args, init_kwargs) -> "ContextAwareState": + return cls(*init_args, **init_kwargs) def __getattribute__(self, name): - direct_attrs = ( - "get_current_state", - "mark_state", - "reset", - "_init_args", - "_init_kwargs", - "_mark_name", - "_state_cache", - ) + # fmt: off + direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state") + # fmt: on if name in direct_attrs or _is_dunder_method(name): return object.__getattribute__(self, name) else: - current_state = BaseMarkedState.get_current_state(self) + current_state = ContextAwareState.get_state(self) return object.__getattribute__(current_state, name) def __setattr__(self, name, value): - if name in ( - "get_current_state", - "mark_state", - "reset", - "_init_args", - "_init_kwargs", - "_mark_name", - "_state_cache", - ) or _is_dunder_method(name): + # fmt: off + direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state") + # fmt: on + if name in direct_attrs or _is_dunder_method(name): object.__setattr__(self, name, value) else: - current_state = BaseMarkedState.get_current_state(self) + current_state = ContextAwareState.get_state(self) object.__setattr__(current_state, name, value) @@ -166,11 +161,11 @@ def reset_state(self, module: torch.nn.Module): return module def _mark_state(self, module: torch.nn.Module, name: str) -> None: - # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them. + # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them. for attr_name in dir(self): attr = getattr(self, attr_name) - if isinstance(attr, BaseMarkedState): - attr.mark_state(name) + if isinstance(attr, ContextAwareState): + attr.set_context(name) return module diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 6c4bcb301d70..7ff9f6b84c6c 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -128,7 +128,7 @@ class _CacheContextManager: def __init__(self, model: CacheMixin): self.model = model - def mark_state(self, name: str) -> None: + def set_context(self, name: str) -> None: from ..hooks import HookRegistry if self.model.is_cache_enabled: diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 876ca922a670..46cb39e2a41d 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -619,7 +619,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) - cc.mark_state("cond") + cc.set_context("cond") noise_pred_cond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, @@ -633,7 +633,7 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: - cc.mark_state("uncond") + cc.set_context("uncond") noise_pred_uncond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=negative_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index a7195d3a679d..e9155fd640ee 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -917,7 +917,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - cc.mark_state("cond") + cc.set_context("cond") noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -934,7 +934,7 @@ def __call__( if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - cc.mark_state("uncond") + cc.set_context("uncond") neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index b36de61c02ef..2355b9010068 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -693,7 +693,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - cc.mark_state("cond") + cc.set_context("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -706,7 +706,7 @@ def __call__( )[0] if do_true_cfg: - cc.mark_state("uncond") + cc.set_context("uncond") neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 2fa9fa53e8f0..3f3881e49f6d 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -719,7 +719,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - cc.mark_state("cond_uncond") + cc.set_context("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index aa7e8eb5597c..5458b473b51e 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -1105,7 +1105,7 @@ def __call__( if is_conditioning_image_or_video: timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) - cc.mark_state("cond_uncond") + cc.set_context("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index e9d2566a9bf1..1317acd8bafa 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -792,7 +792,7 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) - cc.mark_state("cond_uncond") + cc.set_context("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index e38cb34a660e..78de45f4ab58 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -530,7 +530,7 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - cc.mark_state("cond") + cc.set_context("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -540,7 +540,7 @@ def __call__( )[0] if self.do_classifier_free_guidance: - cc.mark_state("uncond") + cc.set_context("uncond") noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, From 169bb0df9ce724f2adc5d11c76454541a515a685 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 16 Apr 2025 14:24:08 +0200 Subject: [PATCH 15/22] cache context refacotr; address review pt. 3 --- src/diffusers/hooks/hooks.py | 8 ++-- src/diffusers/models/cache_utils.py | 20 +++------ .../pipelines/cogview4/pipeline_cogview4.py | 35 +++++++-------- src/diffusers/pipelines/flux/pipeline_flux.py | 44 +++++++++---------- .../hunyuan_video/pipeline_hunyuan_video.py | 38 ++++++++-------- src/diffusers/pipelines/ltx/pipeline_ltx.py | 28 ++++++------ .../pipelines/ltx/pipeline_ltx_condition.py | 22 +++++----- .../pipelines/ltx/pipeline_ltx_image2video.py | 28 ++++++------ src/diffusers/pipelines/wan/pipeline_wan.py | 28 ++++++------ 9 files changed, 122 insertions(+), 129 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 16e80add846f..4ca5761f75ab 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -160,7 +160,7 @@ def reset_state(self, module: torch.nn.Module): raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") return module - def _mark_state(self, module: torch.nn.Module, name: str) -> None: + def _set_context(self, module: torch.nn.Module, name: str) -> None: # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them. for attr_name in dir(self): attr = getattr(self, attr_name) @@ -293,18 +293,18 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry module._diffusers_hook = cls(module) return module._diffusers_hook - def _mark_state(self, name: str) -> None: + def _set_context(self, name: Optional[str] = None) -> None: for hook_name in reversed(self._hook_order): hook = self.hooks[hook_name] if hook._is_stateful: - hook._mark_state(self._module_ref, name) + hook._set_context(self._module_ref, name) for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): - module._diffusers_hook._mark_state(name) + module._diffusers_hook._set_context(name) def __repr__(self) -> str: registry_repr = "" diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 7ff9f6b84c6c..b251850cedbd 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -118,19 +118,13 @@ def _reset_stateful_cache(self, recurse: bool = True) -> None: HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) @contextmanager - def _cache_context(self): + def cache_context(self, name: str): r"""Context manager that provides additional methods for cache management.""" - cache_context = _CacheContextManager(self) - yield cache_context - - -class _CacheContextManager: - def __init__(self, model: CacheMixin): - self.model = model - - def set_context(self, name: str) -> None: from ..hooks import HookRegistry - if self.model.is_cache_enabled: - registry = HookRegistry.check_if_exists_or_initialize(self.model) - registry._mark_state(name) + if self.is_cache_enabled: + registry = HookRegistry.check_if_exists_or_initialize(self) + registry._set_context(name) + yield + if self.is_cache_enabled: + registry._set_context(None) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 46cb39e2a41d..c3a6d7991b66 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -608,7 +608,7 @@ def __call__( transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -619,24 +619,10 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) - cc.set_context("cond") - noise_pred_cond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - cc.set_context("uncond") - noise_pred_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred_cond = self.transformer( hidden_states=latent_model_input, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, timestep=timestep, original_size=original_size, target_size=target_size, @@ -645,6 +631,19 @@ def __call__( return_dict=False, )[0] + # perform guidance + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) else: noise_pred = noise_pred_cond diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index e9155fd640ee..cfd0eb271568 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -906,7 +906,7 @@ def __call__( ) # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -917,35 +917,35 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - cc.set_context("cond") - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - - cc.set_context("uncond") - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 2355b9010068..5e60b29c31d1 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -683,7 +683,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -693,30 +693,30 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - cc.set_context("cond") - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - pooled_projections=pooled_prompt_embeds, - guidance=guidance, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - cc.set_context("uncond") - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_attention_mask=negative_prompt_attention_mask, - pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, guidance=guidance, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 3f3881e49f6d..81df4ca9382f 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -706,7 +706,7 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -719,19 +719,19 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - cc.set_context("cond_uncond") - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - rope_interpolation_scale=rope_interpolation_scale, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() if self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 5458b473b51e..481ed0fd55b5 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -1072,7 +1072,7 @@ def __call__( self._num_timesteps = len(timesteps) # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -1105,16 +1105,16 @@ def __call__( if is_conditioning_image_or_video: timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) - cc.set_context("cond_uncond") - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - video_coords=video_coords, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + video_coords=video_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 1317acd8bafa..acd500f9fb90 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -778,7 +778,7 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -792,19 +792,19 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) - cc.set_context("cond_uncond") - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - rope_interpolation_scale=rope_interpolation_scale, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() if self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 78de45f4ab58..59d07fa24f58 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -521,7 +521,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -530,24 +530,24 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - cc.set_context("cond") - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - cc.set_context("uncond") - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 From 0a44380a36f00d0ed0a67f0f8cb984d4c86ea8f7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 14 May 2025 12:14:24 +0200 Subject: [PATCH 16/22] address review comments --- src/diffusers/hooks/first_block_cache.py | 62 ++++++++++++------------ src/diffusers/hooks/hooks.py | 46 ++++-------------- 2 files changed, 41 insertions(+), 67 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index b232e6465c4a..31ee08c34d9d 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -21,7 +21,7 @@ from ..utils.torch_utils import unwrap_module from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from ._helpers import TransformerBlockRegistry -from .hooks import ContextAwareState, HookRegistry, ModelHook +from .hooks import BaseState, HookRegistry, ModelHook, StateManager logger = get_logger(__name__) # pylint: disable=invalid-name @@ -49,7 +49,7 @@ class FirstBlockCacheConfig: threshold: float = 0.05 -class FBCSharedBlockState(ContextAwareState): +class FBCSharedBlockState(BaseState): def __init__(self) -> None: super().__init__() @@ -66,8 +66,8 @@ def reset(self): class FBCHeadBlockHook(ModelHook): _is_stateful = True - def __init__(self, shared_state: FBCSharedBlockState, threshold: float): - self.shared_state = shared_state + def __init__(self, state_manager: StateManager, threshold: float): + self.state_manager = state_manager self.threshold = threshold self._metadata = None @@ -91,24 +91,24 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): else: hidden_states_residual = output - original_hidden_states + shared_state: FBCSharedBlockState = self.state_manager.get_state() hidden_states, encoder_hidden_states = None, None should_compute = self._should_compute_remaining_blocks(hidden_states_residual) - self.shared_state.should_compute = should_compute + shared_state.should_compute = should_compute if not should_compute: # Apply caching if is_output_tuple: hidden_states = ( - self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] + shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] ) else: - hidden_states = self.shared_state.tail_block_residuals[0] + output + hidden_states = shared_state.tail_block_residuals[0] + output if self._metadata.return_encoder_hidden_states_index is not None: assert is_output_tuple encoder_hidden_states = ( - self.shared_state.tail_block_residuals[1] - + output[self._metadata.return_encoder_hidden_states_index] + shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index] ) if is_output_tuple: @@ -126,20 +126,21 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index] else: head_block_output = output - self.shared_state.head_block_output = head_block_output - self.shared_state.head_block_residual = hidden_states_residual + shared_state.head_block_output = head_block_output + shared_state.head_block_residual = hidden_states_residual return output def reset_state(self, module): - self.shared_state.reset() + self.state_manager.reset() return module @torch.compiler.disable def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool: - if self.shared_state.head_block_residual is None: + shared_state = self.state_manager.get_state() + if shared_state.head_block_residual is None: return True - prev_hidden_states_residual = self.shared_state.head_block_residual + prev_hidden_states_residual = shared_state.head_block_residual absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean() prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean() diff = (absmean / prev_hidden_states_absmean).item() @@ -147,9 +148,9 @@ def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) class FBCBlockHook(ModelHook): - def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False): + def __init__(self, state_manager: StateManager, is_tail: bool = False): super().__init__() - self.shared_state = shared_state + self.state_manager = state_manager self.is_tail = is_tail self._metadata = None @@ -166,21 +167,22 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self._metadata.return_encoder_hidden_states_index is not None: original_encoder_hidden_states = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index] - if self.shared_state.should_compute: + shared_state = self.state_manager.get_state() + + if shared_state.should_compute: output = self.fn_ref.original_forward(*args, **kwargs) if self.is_tail: hidden_states_residual = encoder_hidden_states_residual = None if isinstance(output, tuple): hidden_states_residual = ( - output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0] + output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0] ) encoder_hidden_states_residual = ( - output[self._metadata.return_encoder_hidden_states_index] - - self.shared_state.head_block_output[1] + output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1] ) else: - hidden_states_residual = output - self.shared_state.head_block_output - self.shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual) + hidden_states_residual = output - shared_state.head_block_output + shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual) return output output_count = len(outputs_if_skipped) @@ -194,7 +196,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: - shared_state = FBCSharedBlockState() + state_manager = StateManager(FBCSharedBlockState, (), {}) remaining_blocks = [] for name, submodule in module.named_children(): @@ -207,23 +209,23 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf tail_block_name, tail_block = remaining_blocks.pop(-1) logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'") - _apply_fbc_head_block_hook(head_block, shared_state, config.threshold) + _apply_fbc_head_block_hook(head_block, state_manager, config.threshold) for name, block in remaining_blocks: logger.debug(f"Applying FBCBlockHook to '{name}'") - _apply_fbc_block_hook(block, shared_state) + _apply_fbc_block_hook(block, state_manager) logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'") - _apply_fbc_block_hook(tail_block, shared_state, is_tail=True) + _apply_fbc_block_hook(tail_block, state_manager, is_tail=True) -def _apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None: +def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None: registry = HookRegistry.check_if_exists_or_initialize(block) - hook = FBCHeadBlockHook(state, threshold) + hook = FBCHeadBlockHook(state_manager, threshold) registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) -def _apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None: +def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None: registry = HookRegistry.check_if_exists_or_initialize(block) - hook = FBCBlockHook(state, is_tail) + hook = FBCBlockHook(state_manager, is_tail) registry.register_hook(hook, _FBC_BLOCK_HOOK) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 4ca5761f75ab..3b39829fc5bf 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -31,23 +31,19 @@ def reset(self, *args, **kwargs) -> None: ) -class ContextAwareState(BaseState): - def __init__(self, init_args=None, init_kwargs=None): - super().__init__() - +class StateManager: + def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None): + self._state_cls = state_cls self._init_args = init_args if init_args is not None else () self._init_kwargs = init_kwargs if init_kwargs is not None else {} - self._current_context = None self._state_cache = {} + self._current_context = None - def get_state(self) -> "ContextAwareState": + def get_state(self): if self._current_context is None: - # If no context is set, simply return a dummy object since we're not going to be using it - return self + raise ValueError("No context is set. Please set a context before retrieving the state.") if self._current_context not in self._state_cache.keys(): - self._state_cache[self._current_context] = ContextAwareState._create_state( - self.__class__, self._init_args, self._init_kwargs - ) + self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs) return self._state_cache[self._current_context] def set_context(self, name: str) -> None: @@ -59,30 +55,6 @@ def reset(self, *args, **kwargs) -> None: self._state_cache.pop(name) self._current_context = None - @staticmethod - def _create_state(cls, init_args, init_kwargs) -> "ContextAwareState": - return cls(*init_args, **init_kwargs) - - def __getattribute__(self, name): - # fmt: off - direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state") - # fmt: on - if name in direct_attrs or _is_dunder_method(name): - return object.__getattribute__(self, name) - else: - current_state = ContextAwareState.get_state(self) - return object.__getattribute__(current_state, name) - - def __setattr__(self, name, value): - # fmt: off - direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state") - # fmt: on - if name in direct_attrs or _is_dunder_method(name): - object.__setattr__(self, name, value) - else: - current_state = ContextAwareState.get_state(self) - object.__setattr__(current_state, name, value) - class ModelHook: r""" @@ -161,10 +133,10 @@ def reset_state(self, module: torch.nn.Module): return module def _set_context(self, module: torch.nn.Module, name: str) -> None: - # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them. + # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them. for attr_name in dir(self): attr = getattr(self, attr_name) - if isinstance(attr, ContextAwareState): + if isinstance(attr, StateManager): attr.set_context(name) return module From fb229b54bb53a14aef0871aa4f6643016f36742c Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 14 May 2025 14:19:48 +0200 Subject: [PATCH 17/22] metadata registration with decorators instead of centralized --- src/diffusers/hooks/_helpers.py | 271 ------------------ src/diffusers/hooks/first_block_cache.py | 26 +- src/diffusers/models/attention.py | 7 + src/diffusers/models/metadata.py | 63 ++++ .../transformers/cogvideox_transformer_3d.py | 7 + .../transformers/transformer_cogview4.py | 9 + .../models/transformers/transformer_flux.py | 13 + .../transformers/transformer_hunyuan_video.py | 31 ++ .../models/transformers/transformer_ltx.py | 7 + .../models/transformers/transformer_mochi.py | 7 + .../models/transformers/transformer_wan.py | 9 + 11 files changed, 163 insertions(+), 287 deletions(-) delete mode 100644 src/diffusers/hooks/_helpers.py create mode 100644 src/diffusers/models/metadata.py diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py deleted file mode 100644 index 9043ffc41838..000000000000 --- a/src/diffusers/hooks/_helpers.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Any, Callable, Type - -from ..models.attention import BasicTransformerBlock -from ..models.attention_processor import AttnProcessor2_0 -from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock -from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock -from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock -from ..models.transformers.transformer_hunyuan_video import ( - HunyuanVideoSingleTransformerBlock, - HunyuanVideoTokenReplaceSingleTransformerBlock, - HunyuanVideoTokenReplaceTransformerBlock, - HunyuanVideoTransformerBlock, -) -from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock -from ..models.transformers.transformer_mochi import MochiTransformerBlock -from ..models.transformers.transformer_wan import WanTransformerBlock - - -@dataclass -class AttentionProcessorMetadata: - skip_processor_output_fn: Callable[[Any], Any] - - -@dataclass -class TransformerBlockMetadata: - skip_block_output_fn: Callable[[Any], Any] - return_hidden_states_index: int = None - return_encoder_hidden_states_index: int = None - - -class AttentionProcessorRegistry: - _registry = {} - - @classmethod - def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): - cls._registry[model_class] = metadata - - @classmethod - def get(cls, model_class: Type) -> AttentionProcessorMetadata: - if model_class not in cls._registry: - raise ValueError(f"Model class {model_class} not registered.") - return cls._registry[model_class] - - -class TransformerBlockRegistry: - _registry = {} - - @classmethod - def register(cls, model_class: Type, metadata: TransformerBlockMetadata): - cls._registry[model_class] = metadata - - @classmethod - def get(cls, model_class: Type) -> TransformerBlockMetadata: - if model_class not in cls._registry: - raise ValueError(f"Model class {model_class} not registered.") - return cls._registry[model_class] - - -def _register_attention_processors_metadata(): - # AttnProcessor2_0 - AttentionProcessorRegistry.register( - model_class=AttnProcessor2_0, - metadata=AttentionProcessorMetadata( - skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0, - ), - ) - - # CogView4AttnProcessor - AttentionProcessorRegistry.register( - model_class=CogView4AttnProcessor, - metadata=AttentionProcessorMetadata( - skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, - ), - ) - - -def _register_transformer_blocks_metadata(): - # BasicTransformerBlock - TransformerBlockRegistry.register( - model_class=BasicTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=None, - ), - ) - - # CogVideoX - TransformerBlockRegistry.register( - model_class=CogVideoXBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=1, - ), - ) - - # CogView4 - TransformerBlockRegistry.register( - model_class=CogView4TransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=1, - ), - ) - - # Flux - TransformerBlockRegistry.register( - model_class=FluxTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, - return_hidden_states_index=1, - return_encoder_hidden_states_index=0, - ), - ) - TransformerBlockRegistry.register( - model_class=FluxSingleTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, - return_hidden_states_index=1, - return_encoder_hidden_states_index=0, - ), - ) - - # HunyuanVideo - TransformerBlockRegistry.register( - model_class=HunyuanVideoTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=1, - ), - ) - TransformerBlockRegistry.register( - model_class=HunyuanVideoSingleTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=1, - ), - ) - TransformerBlockRegistry.register( - model_class=HunyuanVideoTokenReplaceTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=1, - ), - ) - TransformerBlockRegistry.register( - model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=1, - ), - ) - - # LTXVideo - TransformerBlockRegistry.register( - model_class=LTXVideoTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=None, - ), - ) - - # Mochi - TransformerBlockRegistry.register( - model_class=MochiTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=1, - ), - ) - - # Wan - TransformerBlockRegistry.register( - model_class=WanTransformerBlock, - metadata=TransformerBlockMetadata( - skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=None, - ), - ) - - -# fmt: off -def _skip_attention___ret___hidden_states(self, *args, **kwargs): - hidden_states = kwargs.get("hidden_states", None) - if hidden_states is None and len(args) > 0: - hidden_states = args[0] - return hidden_states - - -def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): - hidden_states = kwargs.get("hidden_states", None) - encoder_hidden_states = kwargs.get("encoder_hidden_states", None) - if hidden_states is None and len(args) > 0: - hidden_states = args[0] - if encoder_hidden_states is None and len(args) > 1: - encoder_hidden_states = args[1] - return hidden_states, encoder_hidden_states - - -_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states -_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states - - -def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): - hidden_states = kwargs.get("hidden_states", None) - if hidden_states is None and len(args) > 0: - hidden_states = args[0] - return hidden_states - - -def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): - hidden_states = kwargs.get("hidden_states", None) - encoder_hidden_states = kwargs.get("encoder_hidden_states", None) - if hidden_states is None and len(args) > 0: - hidden_states = args[0] - if encoder_hidden_states is None and len(args) > 1: - encoder_hidden_states = args[1] - return hidden_states, encoder_hidden_states - - -def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): - hidden_states = kwargs.get("hidden_states", None) - encoder_hidden_states = kwargs.get("encoder_hidden_states", None) - if hidden_states is None and len(args) > 0: - hidden_states = args[0] - if encoder_hidden_states is None and len(args) > 1: - encoder_hidden_states = args[1] - return encoder_hidden_states, hidden_states - - -_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states -_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states -_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states -_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states -_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states -_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states -# fmt: on - - -_register_attention_processors_metadata() -_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 31ee08c34d9d..a7a415ca51fb 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -17,10 +17,10 @@ import torch +from ..models.metadata import TransformerBlockRegistry from ..utils import get_logger from ..utils.torch_utils import unwrap_module from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS -from ._helpers import TransformerBlockRegistry from .hooks import BaseState, HookRegistry, ModelHook, StateManager @@ -76,12 +76,7 @@ def initialize_hook(self, module): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) - - if isinstance(outputs_if_skipped, tuple): - original_hidden_states = outputs_if_skipped[self._metadata.return_hidden_states_index] - else: - original_hidden_states = outputs_if_skipped + original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) output = self.fn_ref.original_forward(*args, **kwargs) is_output_tuple = isinstance(output, tuple) @@ -92,7 +87,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): hidden_states_residual = output - original_hidden_states shared_state: FBCSharedBlockState = self.state_manager.get_state() - hidden_states, encoder_hidden_states = None, None + hidden_states = encoder_hidden_states = None should_compute = self._should_compute_remaining_blocks(hidden_states_residual) shared_state.should_compute = should_compute @@ -159,13 +154,12 @@ def initialize_hook(self, module): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) - if not isinstance(outputs_if_skipped, tuple): - outputs_if_skipped = (outputs_if_skipped,) - original_hidden_states = outputs_if_skipped[self._metadata.return_hidden_states_index] + original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) original_encoder_hidden_states = None if self._metadata.return_encoder_hidden_states_index is not None: - original_encoder_hidden_states = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index] + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) shared_state = self.state_manager.get_state() @@ -185,13 +179,13 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual) return output - output_count = len(outputs_if_skipped) - if output_count == 1: + if original_encoder_hidden_states is None: return_output = original_hidden_states else: - return_output = [None] * output_count + return_output = [None, None] return_output[self._metadata.return_hidden_states_index] = original_hidden_states return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states + return_output = tuple(return_output) return return_output diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 93b11c2b43f0..dfcfc273045c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -22,6 +22,7 @@ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding +from .metadata import TransformerBlockMetadata, TransformerBlockRegistry from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX @@ -258,6 +259,12 @@ def forward( @maybe_allow_in_graph +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ) +) class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. diff --git a/src/diffusers/models/metadata.py b/src/diffusers/models/metadata.py new file mode 100644 index 000000000000..9b13e52fc01c --- /dev/null +++ b/src/diffusers/models/metadata.py @@ -0,0 +1,63 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Dict, Type + + +@dataclass +class TransformerBlockMetadata: + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + _cls: Type = None + _cached_parameter_indices: Dict[str, int] = None + + def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): + kwargs = kwargs or {} + if identifier in kwargs: + return kwargs[identifier] + if self._cached_parameter_indices is not None: + return args[self._cached_parameter_indices[identifier]] + if self._cls is None: + raise ValueError("Model class is not set for metadata.") + parameters = list(inspect.signature(self._cls.forward).parameters.keys()) + parameters = parameters[1:] # skip `self` + self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)} + if identifier not in self._cached_parameter_indices: + raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.") + index = self._cached_parameter_indices[identifier] + if index >= len(args): + raise ValueError(f"Expected {index} arguments but got {len(args)}.") + return args[index] + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, metadata: TransformerBlockMetadata): + def inner(model_class: Type): + metadata._cls = model_class + cls._registry[model_class] = metadata + return model_class + + return inner + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 6b4f38dc04a1..d3e596b1af4a 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -26,6 +26,7 @@ from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -35,6 +36,12 @@ @maybe_allow_in_graph +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ) +) class CogVideoXBlock(nn.Module): r""" Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index aef368f91ac0..c3d40b874941 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -21,10 +21,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import CogView3CombinedTimestepSizeEmbeddings +from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous @@ -453,6 +455,13 @@ def __call__( return hidden_states, encoder_hidden_states +@maybe_allow_in_graph +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ) +) class CogView4TransformerBlock(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index c9abe06b42fd..f66d5f982bd1 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -34,6 +34,7 @@ ) from ..cache_utils import CacheMixin from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -43,6 +44,12 @@ @maybe_allow_in_graph +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ) +) class FluxSingleTransformerBlock(nn.Module): def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): super().__init__() @@ -109,6 +116,12 @@ def forward( @maybe_allow_in_graph +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ) +) class FluxTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index d9100b2f54d0..6e5b107f9a14 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -33,6 +33,7 @@ Timesteps, get_1d_rotary_pos_embed, ) +from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm @@ -310,6 +311,12 @@ def forward( return conditioning, token_replace_emb +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ) +) class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): def __init__( self, @@ -489,6 +496,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ) +) class HunyuanVideoSingleTransformerBlock(nn.Module): def __init__( self, @@ -565,6 +578,12 @@ def forward( return hidden_states, encoder_hidden_states +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ) +) class HunyuanVideoTransformerBlock(nn.Module): def __init__( self, @@ -644,6 +663,12 @@ def forward( return hidden_states, encoder_hidden_states +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ) +) class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module): def __init__( self, @@ -724,6 +749,12 @@ def forward( return hidden_states, encoder_hidden_states +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ) +) class HunyuanVideoTokenReplaceTransformerBlock(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2ae2418098f6..8a8409f1bff7 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -28,6 +28,7 @@ from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection +from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm @@ -196,6 +197,12 @@ def forward( @maybe_allow_in_graph +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ) +) class LTXVideoTransformerBlock(nn.Module): r""" Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video). diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index e6532f080d72..21481426010d 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -27,6 +27,7 @@ from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 from ..cache_utils import CacheMixin from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed +from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, RMSNorm @@ -116,6 +117,12 @@ def forward( @maybe_allow_in_graph +@TransformerBlockRegistry.register( + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ) +) class MochiTransformerBlock(nn.Module): r""" Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index c78d72dc4a2c..ec607c512633 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -22,10 +22,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -219,6 +221,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return freqs +@maybe_allow_in_graph +@TransformerBlockRegistry.register( + TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ) +) class WanTransformerBlock(nn.Module): def __init__( self, From 367fdef96db8a4448930071f05a293b859cf091f Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 15 May 2025 21:41:45 +0200 Subject: [PATCH 18/22] support cogvideox --- .../pipelines/cogvideo/pipeline_cogvideox.py | 17 +++++++++-------- .../pipeline_cogvideox_fun_control.py | 17 +++++++++-------- .../pipeline_cogvideox_image2video.py | 19 ++++++++++--------- .../pipeline_cogvideox_video2video.py | 17 +++++++++-------- 4 files changed, 37 insertions(+), 33 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 99ae9025cd3e..4636844cd89b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -718,14 +718,15 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index e37574ec9cb2..121c4280f0bd 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -784,14 +784,15 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 59d7c4cad547..25e62cb2b34e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -831,15 +831,16 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - ofs=ofs_emb, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + ofs=ofs_emb, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index c4dc7e574f7e..46deec9b99fa 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -799,14 +799,15 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - image_rotary_emb=image_rotary_emb, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance From 495fddb8aec9d5c41ca5b45ce86c0c6301f17e16 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 15 May 2025 21:42:19 +0200 Subject: [PATCH 19/22] support mochi --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index d1f88b02c5cc..dc83d3bc1da1 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -671,14 +671,15 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] # Mochi CFG + Sampling runs in FP32 noise_pred = noise_pred.to(torch.float32) From 153cf0c3934ab2639f59d0fbcb716fd964d1403b Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 16 May 2025 12:33:52 +0200 Subject: [PATCH 20/22] fix --- src/diffusers/models/cache_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index b251850cedbd..f646701ba094 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -122,9 +122,9 @@ def cache_context(self, name: str): r"""Context manager that provides additional methods for cache management.""" from ..hooks import HookRegistry - if self.is_cache_enabled: - registry = HookRegistry.check_if_exists_or_initialize(self) - registry._set_context(name) + registry = HookRegistry.check_if_exists_or_initialize(self) + registry._set_context(name) + yield - if self.is_cache_enabled: - registry._set_context(None) + + registry._set_context(None) From a5fe2bd4fd825bf64c2c7a717f031bfda700e36c Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 16 May 2025 12:37:01 +0200 Subject: [PATCH 21/22] remove unused function --- src/diffusers/hooks/hooks.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 3b39829fc5bf..4b8810ae68e2 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -289,7 +289,3 @@ def __repr__(self) -> str: if i < len(self._hook_order) - 1: registry_repr += "\n" return f"HookRegistry(\n{registry_repr}\n)" - - -def _is_dunder_method(name: str) -> bool: - return name.startswith("__") and name.endswith("__") and name in dir(object) From b8317da20f023fe6022d34e8ca8bb91264596ea5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 16 May 2025 12:53:36 +0200 Subject: [PATCH 22/22] remove central registry based on review --- src/diffusers/hooks/first_block_cache.py | 17 ++++++++++--- src/diffusers/models/attention.py | 4 ++-- src/diffusers/models/metadata.py | 24 ++++++------------- .../transformers/cogvideox_transformer_3d.py | 4 ++-- .../transformers/transformer_cogview4.py | 4 ++-- .../models/transformers/transformer_flux.py | 6 ++--- .../transformers/transformer_hunyuan_video.py | 12 +++++----- .../models/transformers/transformer_ltx.py | 4 ++-- .../models/transformers/transformer_mochi.py | 4 ++-- .../models/transformers/transformer_wan.py | 4 ++-- 10 files changed, 42 insertions(+), 41 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index a7a415ca51fb..e2e27048cc61 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -17,7 +17,6 @@ import torch -from ..models.metadata import TransformerBlockRegistry from ..utils import get_logger from ..utils.torch_utils import unwrap_module from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS @@ -72,7 +71,13 @@ def __init__(self, state_manager: StateManager, threshold: float): self._metadata = None def initialize_hook(self, module): - self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + unwrapped_module = unwrap_module(module) + if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"): + raise ValueError( + f"Module {unwrapped_module} does not have any registered metadata. " + "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`." + ) + self._metadata = unwrapped_module._diffusers_transformer_block_metadata return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -150,7 +155,13 @@ def __init__(self, state_manager: StateManager, is_tail: bool = False): self._metadata = None def initialize_hook(self, module): - self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + unwrapped_module = unwrap_module(module) + if not hasattr(unwrapped_module, "_diffusers_transformer_block_metadata"): + raise ValueError( + f"Module {unwrapped_module} does not have any registered metadata. " + "Make sure to register the metadata using `diffusers.models.metadata.register_transformer_block`." + ) + self._metadata = unwrapped_module._diffusers_transformer_block_metadata return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index dfcfc273045c..b2bc08beff1f 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -22,7 +22,7 @@ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding -from .metadata import TransformerBlockMetadata, TransformerBlockRegistry +from .metadata import TransformerBlockMetadata, register_transformer_block from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX @@ -259,7 +259,7 @@ def forward( @maybe_allow_in_graph -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=None, diff --git a/src/diffusers/models/metadata.py b/src/diffusers/models/metadata.py index 9b13e52fc01c..6da190ac307f 100644 --- a/src/diffusers/models/metadata.py +++ b/src/diffusers/models/metadata.py @@ -44,20 +44,10 @@ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None) return args[index] -class TransformerBlockRegistry: - _registry = {} - - @classmethod - def register(cls, metadata: TransformerBlockMetadata): - def inner(model_class: Type): - metadata._cls = model_class - cls._registry[model_class] = metadata - return model_class - - return inner - - @classmethod - def get(cls, model_class: Type) -> TransformerBlockMetadata: - if model_class not in cls._registry: - raise ValueError(f"Model class {model_class} not registered.") - return cls._registry[model_class] +def register_transformer_block(metadata: TransformerBlockMetadata): + def inner(model_class: Type): + metadata._cls = model_class + model_class._diffusers_transformer_block_metadata = metadata + return model_class + + return inner diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index d3e596b1af4a..4561cbf505c9 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -26,7 +26,7 @@ from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps -from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry +from ..metadata import TransformerBlockMetadata, register_transformer_block from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero @@ -36,7 +36,7 @@ @maybe_allow_in_graph -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=1, diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index c3d40b874941..8103b9dd839c 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -26,7 +26,7 @@ from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import CogView3CombinedTimestepSizeEmbeddings -from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry +from ..metadata import TransformerBlockMetadata, register_transformer_block from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous @@ -456,7 +456,7 @@ def __call__( @maybe_allow_in_graph -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=1, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index f66d5f982bd1..3be0ba9d16dd 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -34,7 +34,7 @@ ) from ..cache_utils import CacheMixin from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed -from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry +from ..metadata import TransformerBlockMetadata, register_transformer_block from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle @@ -44,7 +44,7 @@ @maybe_allow_in_graph -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=1, return_encoder_hidden_states_index=0, @@ -116,7 +116,7 @@ def forward( @maybe_allow_in_graph -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=1, return_encoder_hidden_states_index=0, diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 6e5b107f9a14..1554ac129bfc 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -33,7 +33,7 @@ Timesteps, get_1d_rotary_pos_embed, ) -from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry +from ..metadata import TransformerBlockMetadata, register_transformer_block from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm @@ -311,7 +311,7 @@ def forward( return conditioning, token_replace_emb -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=None, @@ -496,7 +496,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=1, @@ -578,7 +578,7 @@ def forward( return hidden_states, encoder_hidden_states -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=1, @@ -663,7 +663,7 @@ def forward( return hidden_states, encoder_hidden_states -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=1, @@ -749,7 +749,7 @@ def forward( return hidden_states, encoder_hidden_states -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=1, diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 8a8409f1bff7..042881524e77 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -28,7 +28,7 @@ from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection -from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry +from ..metadata import TransformerBlockMetadata, register_transformer_block from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm @@ -197,7 +197,7 @@ def forward( @maybe_allow_in_graph -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=None, diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 21481426010d..f875103c2699 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -27,7 +27,7 @@ from ..attention_processor import MochiAttention, MochiAttnProcessor2_0 from ..cache_utils import CacheMixin from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed -from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry +from ..metadata import TransformerBlockMetadata, register_transformer_block from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, RMSNorm @@ -117,7 +117,7 @@ def forward( @maybe_allow_in_graph -@TransformerBlockRegistry.register( +@register_transformer_block( metadata=TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=1, diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index ec607c512633..4ab26b90b326 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -27,7 +27,7 @@ from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed -from ..metadata import TransformerBlockMetadata, TransformerBlockRegistry +from ..metadata import TransformerBlockMetadata, register_transformer_block from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -222,7 +222,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @maybe_allow_in_graph -@TransformerBlockRegistry.register( +@register_transformer_block( TransformerBlockMetadata( return_hidden_states_index=0, return_encoder_hidden_states_index=None,