@@ -1471,6 +1471,14 @@ def denoising_value_valid(dnv):
1471
1471
generator ,
1472
1472
self .do_classifier_free_guidance ,
1473
1473
)
1474
+ if self .do_perturbed_attention_guidance :
1475
+ if self .do_classifier_free_guidance :
1476
+ mask , _ = mask .chunk (2 )
1477
+ masked_image_latents , _ = masked_image_latents .chunk (2 )
1478
+ mask = self ._prepare_perturbed_attention_guidance (mask , mask , self .do_classifier_free_guidance )
1479
+ masked_image_latents = self ._prepare_perturbed_attention_guidance (
1480
+ masked_image_latents , masked_image_latents , self .do_classifier_free_guidance
1481
+ )
1474
1482
1475
1483
# 8. Check that sizes of mask, masked image and latents match
1476
1484
if num_channels_unet == 9 :
@@ -1659,10 +1667,10 @@ def denoising_value_valid(dnv):
1659
1667
1660
1668
if num_channels_unet == 4 :
1661
1669
init_latents_proper = image_latents
1662
- if self .do_classifier_free_guidance :
1663
- init_mask , _ = mask .chunk (2 )
1670
+ if self .do_perturbed_attention_guidance :
1671
+ init_mask , * _ = mask . chunk ( 3 ) if self . do_classifier_free_guidance else mask .chunk (2 )
1664
1672
else :
1665
- init_mask = mask
1673
+ init_mask , * _ = mask . chunk ( 2 ) if self . do_classifier_free_guidance else mask
1666
1674
1667
1675
if i < len (timesteps ) - 1 :
1668
1676
noise_timestep = timesteps [i + 1 ]
0 commit comments