From 6c40c186ec7649f7498a58ee61d62a0db81e2bb1 Mon Sep 17 00:00:00 2001 From: Robey Holderith Date: Sat, 26 Jul 2025 08:56:01 -0700 Subject: [PATCH 1/2] Add cache_threshold to samplers.json -- use it if block caching is enabled. --- ai_diffusion/api.py | 1 + ai_diffusion/comfy_workflow.py | 4 ++-- ai_diffusion/style.py | 1 + ai_diffusion/workflow.py | 17 +++++++++-------- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/ai_diffusion/api.py b/ai_diffusion/api.py index 06b4f48de..5ac0fa821 100644 --- a/ai_diffusion/api.py +++ b/ai_diffusion/api.py @@ -72,6 +72,7 @@ class SamplingInput: scheduler: str cfg_scale: float total_steps: int + cache_threshold: float = 0 start_step: int = 0 seed: int = 0 diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index 8f8a05eed..53d538802 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -1067,13 +1067,13 @@ def estimate_pose(self, image: Output, resolution: int): mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx" return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls) - def apply_first_block_cache(self, model: Output, arch: Arch): + def apply_first_block_cache(self, model: Output, arch: Arch, threshold: float): return self.add( "ApplyFBCacheOnModel", 1, model=model, object_to_patch="diffusion_model", - residual_diff_threshold=0.2 if arch.is_sdxl_like else 0.12, + residual_diff_threshold=threshold or (0.2 if arch.is_sdxl_like else 0.12), start=0.0, end=1.0, max_consecutive_cache_hits=-1, diff --git a/ai_diffusion/style.py b/ai_diffusion/style.py index 955f94a72..a4bb1186d 100644 --- a/ai_diffusion/style.py +++ b/ai_diffusion/style.py @@ -321,6 +321,7 @@ class SamplerPreset(NamedTuple): lora: str | None = None minimum_steps: int = 4 hidden: bool = False + cache_threshold: float = 0 class SamplerPresets: diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index 6d6abb87a..8f88c50e8 100644 --- a/ai_diffusion/workflow.py +++ 
b/ai_diffusion/workflow.py @@ -45,6 +45,7 @@ def _sampling_from_style(style: Style, strength: float, is_live: bool): scheduler=preset.scheduler, cfg_scale=cfg or preset.cfg, total_steps=max_steps, + cache_threshold=preset.cache_threshold or None ) if strength < 1.0: result.total_steps, result.start_step = apply_strength(strength, max_steps, min_steps) @@ -87,7 +88,7 @@ def _sampler_params(sampling: SamplingInput, strength: float | None = None): return params -def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, models: ClientModels): +def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, sampling: SamplingInput, models: ClientModels): arch = checkpoint.version model_info = models.checkpoints.get(checkpoint.checkpoint) if model_info is None: @@ -130,7 +131,7 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod vae = w.load_vae(models.for_arch(arch).vae) if checkpoint.dynamic_caching and (arch in [Arch.flux, Arch.sd3] or arch.is_sdxl_like): - model = w.apply_first_block_cache(model, arch) + model = w.apply_first_block_cache(model, arch, sampling.cache_threshold) for lora in checkpoint.loras: model, clip = w.load_lora(model, clip, lora.name, lora.strength, lora.strength) @@ -750,7 +751,7 @@ def generate( misc: MiscParams, models: ModelDict, ): - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = apply_ip_adapter(w, model, cond.control, models) model_orig = copy(model) model, regions = apply_attention_mask(w, model, cond, clip, extent.initial) @@ -862,7 +863,7 @@ def inpaint( checkpoint.dynamic_caching = False # doesn't seem to work with Flux fill model sampling.cfg_scale = 30 # set Flux guidance to 30 (typical values don't work well) - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, 
models.all) model = w.differential_diffusion(model) model_orig = copy(model) @@ -991,7 +992,7 @@ def refine( misc: MiscParams, models: ModelDict, ): - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = apply_ip_adapter(w, model, cond.control, models) model, regions = apply_attention_mask(w, model, cond, clip, extent.initial) model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models) @@ -1028,7 +1029,7 @@ def refine_region( ): extent = ScaledExtent.from_input(images.extent) - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = w.differential_diffusion(model) model = apply_ip_adapter(w, model, cond.control, models) model_orig = copy(model) @@ -1179,7 +1180,7 @@ def upscale_tiled( extent.initial, extent.desired.width, sampling.denoise_strength ) - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = apply_ip_adapter(w, model, cond.control, models) in_image = w.load_image(image) @@ -1298,7 +1299,7 @@ def get_param(node: ComfyNode, expected_type: type | tuple[type, type] | None = is_live = node.input("sampler_preset", "auto") == "live" checkpoint_input = style.get_models(models.checkpoints.keys()) sampling = _sampling_from_style(style, 1.0, is_live) - model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, sampling, models) outputs[node.output(0)] = model outputs[node.output(1)] = clip.model outputs[node.output(2)] = vae From 4aeaadd91fd4993f6ae21e9a01b8543edd9293e7 Mon Sep 17 00:00:00 2001 From: Robey Holderith Date: Sat, 26 Jul 2025 09:01:03 -0700 Subject: [PATCH 2/2] ruff formatting --- ai_diffusion/workflow.py | 6 ++++-- 1 file 
changed, 4 insertions(+), 2 deletions(-) diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index 8f88c50e8..cdc02c740 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -45,7 +45,7 @@ def _sampling_from_style(style: Style, strength: float, is_live: bool): scheduler=preset.scheduler, cfg_scale=cfg or preset.cfg, total_steps=max_steps, - cache_threshold=preset.cache_threshold or None + cache_threshold=preset.cache_threshold, ) if strength < 1.0: result.total_steps, result.start_step = apply_strength(strength, max_steps, min_steps) @@ -88,7 +88,9 @@ def _sampler_params(sampling: SamplingInput, strength: float | None = None): return params -def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, sampling: SamplingInput, models: ClientModels): +def load_checkpoint_with_lora( + w: ComfyWorkflow, checkpoint: CheckpointInput, sampling: SamplingInput, models: ClientModels +): arch = checkpoint.version model_info = models.checkpoints.get(checkpoint.checkpoint) if model_info is None: