Skip to content

Commit 16a3dad

Browse files
authored
Fix StableDiffusionXLPAGInpaintPipeline (#9128)
1 parent 21682ba commit 16a3dad

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
955955
if "enable_pag" in kwargs:
956956
enable_pag = kwargs.pop("enable_pag")
957957
if enable_pag:
958-
orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline")
958+
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
959+
orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace)
959960

960961
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
961962

src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,14 @@ def denoising_value_valid(dnv):
14711471
generator,
14721472
self.do_classifier_free_guidance,
14731473
)
1474+
if self.do_perturbed_attention_guidance:
1475+
if self.do_classifier_free_guidance:
1476+
mask, _ = mask.chunk(2)
1477+
masked_image_latents, _ = masked_image_latents.chunk(2)
1478+
mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance)
1479+
masked_image_latents = self._prepare_perturbed_attention_guidance(
1480+
masked_image_latents, masked_image_latents, self.do_classifier_free_guidance
1481+
)
14741482

14751483
# 8. Check that sizes of mask, masked image and latents match
14761484
if num_channels_unet == 9:
@@ -1659,10 +1667,10 @@ def denoising_value_valid(dnv):
16591667

16601668
if num_channels_unet == 4:
16611669
init_latents_proper = image_latents
1662-
if self.do_classifier_free_guidance:
1663-
init_mask, _ = mask.chunk(2)
1670+
if self.do_perturbed_attention_guidance:
1671+
init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2)
16641672
else:
1665-
init_mask = mask
1673+
init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask
16661674

16671675
if i < len(timesteps) - 1:
16681676
noise_timestep = timesteps[i + 1]

0 commit comments

Comments
 (0)