diff --git a/ai_diffusion/api.py b/ai_diffusion/api.py index d4c40cb9b..b9f930ff2 100644 --- a/ai_diffusion/api.py +++ b/ai_diffusion/api.py @@ -72,6 +72,7 @@ class SamplingInput: scheduler: str cfg_scale: float total_steps: int + cache_threshold: float = 0 start_step: int = 0 seed: int = 0 diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index a1091e849..eef7e9104 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -1075,13 +1075,13 @@ def estimate_pose(self, image: Output, resolution: int): mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx" return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls) - def apply_first_block_cache(self, model: Output, arch: Arch): + def apply_first_block_cache(self, model: Output, arch: Arch, threshold: float): return self.add( "ApplyFBCacheOnModel", 1, model=model, object_to_patch="diffusion_model", - residual_diff_threshold=0.2 if arch.is_sdxl_like else 0.12, + residual_diff_threshold=threshold or (0.2 if arch.is_sdxl_like else 0.12), start=0.0, end=1.0, max_consecutive_cache_hits=-1, diff --git a/ai_diffusion/style.py b/ai_diffusion/style.py index 955f94a72..a4bb1186d 100644 --- a/ai_diffusion/style.py +++ b/ai_diffusion/style.py @@ -321,6 +321,7 @@ class SamplerPreset(NamedTuple): lora: str | None = None minimum_steps: int = 4 hidden: bool = False + cache_threshold: float = 0 class SamplerPresets: diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index e6462649f..5be39bf08 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -45,6 +45,7 @@ def _sampling_from_style(style: Style, strength: float, is_live: bool): scheduler=preset.scheduler, cfg_scale=cfg or preset.cfg, total_steps=max_steps, + cache_threshold=preset.cache_threshold, ) if strength < 1.0: result.total_steps, result.start_step = apply_strength(strength, max_steps, min_steps) @@ -87,7 +88,9 @@ def _sampler_params(sampling: SamplingInput, 
strength: float | None = None): return params -def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, models: ClientModels): +def load_checkpoint_with_lora( + w: ComfyWorkflow, checkpoint: CheckpointInput, sampling: SamplingInput, models: ClientModels +): arch = checkpoint.version model_info = models.checkpoints.get(checkpoint.checkpoint) if model_info is None: @@ -133,7 +136,7 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod vae = w.load_vae(models.for_arch(arch).vae) if checkpoint.dynamic_caching and (arch in [Arch.flux, Arch.sd3] or arch.is_sdxl_like): - model = w.apply_first_block_cache(model, arch) + model = w.apply_first_block_cache(model, arch, sampling.cache_threshold) for lora in checkpoint.loras: model, clip = w.load_lora(model, clip, lora.name, lora.strength, lora.strength) @@ -753,7 +756,7 @@ def generate( misc: MiscParams, models: ModelDict, ): - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = apply_ip_adapter(w, model, cond.control, models) model_orig = copy(model) model, regions = apply_attention_mask(w, model, cond, clip, extent.initial) @@ -865,7 +868,7 @@ def inpaint( checkpoint.dynamic_caching = False # doesn't seem to work with Flux fill model sampling.cfg_scale = 30 # set Flux guidance to 30 (typical values don't work well) - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = w.differential_diffusion(model) model_orig = copy(model) @@ -994,7 +997,7 @@ def refine( misc: MiscParams, models: ModelDict, ): - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = apply_ip_adapter(w, model, cond.control, models) model, regions = apply_attention_mask(w, model, 
cond, clip, extent.initial) model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models) @@ -1031,7 +1034,7 @@ def refine_region( ): extent = ScaledExtent.from_input(images.extent) - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = w.differential_diffusion(model) model = apply_ip_adapter(w, model, cond.control, models) model_orig = copy(model) @@ -1182,7 +1185,7 @@ def upscale_tiled( extent.initial, extent.desired.width, sampling.denoise_strength ) - model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint, sampling, models.all) model = apply_ip_adapter(w, model, cond.control, models) in_image = w.load_image(image) @@ -1301,7 +1304,7 @@ def get_param(node: ComfyNode, expected_type: type | tuple[type, type] | None = is_live = node.input("sampler_preset", "auto") == "live" checkpoint_input = style.get_models(models.checkpoints.keys()) sampling = _sampling_from_style(style, 1.0, is_live) - model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, sampling, models) outputs[node.output(0)] = model outputs[node.output(1)] = clip.model outputs[node.output(2)] = vae