@@ -281,6 +281,16 @@ def do_classifier_free_guidance(self):
281
281
def num_timesteps(self):
    # Read-only accessor for the number of denoising steps recorded by the
    # most recent `__call__` (assigned from `len(timesteps)` inside the
    # denoising loop).
    return self._num_timesteps
284
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
    """Convert a discrete scheduler timestep into the continuous [0, 1]-style
    ratio conditioning the model expects.

    Inverts a cosine noise schedule with offset ``s = 0.008``: the cumulative
    alpha at ``t`` is clamped to [0, 1] and mapped back through ``acos`` to
    the corresponding schedule position.

    Args:
        t: Integer (or integer tensor) timestep used to index `alphas_cumprod`.
        alphas_cumprod: 1-D tensor of cumulative products of alphas.

    Returns:
        Tensor ratio conditioning value for timestep ``t``.
    """
    s = torch.tensor([0.008])
    # Variance floor of the cosine schedule at t = 0.
    min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
    # Clamp keeps the sqrt/acos arguments inside their valid domain.
    var = alphas_cumprod[t].clamp(0, 1)
    s = s.to(var.device)
    min_var = min_var.to(var.device)
    ratio = ((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)
    return ratio * (1 + s) - s
284
294
@torch .no_grad ()
285
295
@replace_example_docstring (EXAMPLE_DOC_STRING )
286
296
def __call__ (
@@ -434,10 +444,30 @@ def __call__(
434
444
batch_size , image_embeddings , num_images_per_prompt , dtype , device , generator , latents , self .scheduler
435
445
)
436
446
447
+ if isinstance (self .scheduler , DDPMWuerstchenScheduler ):
448
+ timesteps = timesteps [:- 1 ]
449
+ else :
450
+ if hasattr (self .scheduler .config , "clip_sample" ) and self .scheduler .config .clip_sample :
451
+ self .scheduler .config .clip_sample = False # disable sample clipping
452
+ logger .warning (" set `clip_sample` to be False" )
453
+
437
454
# 6. Run denoising loop
438
- self ._num_timesteps = len (timesteps [:- 1 ])
439
- for i , t in enumerate (self .progress_bar (timesteps [:- 1 ])):
440
- timestep_ratio = t .expand (latents .size (0 )).to (dtype )
455
+ if hasattr (self .scheduler , "betas" ):
456
+ alphas = 1.0 - self .scheduler .betas
457
+ alphas_cumprod = torch .cumprod (alphas , dim = 0 )
458
+ else :
459
+ alphas_cumprod = []
460
+
461
+ self ._num_timesteps = len (timesteps )
462
+ for i , t in enumerate (self .progress_bar (timesteps )):
463
+ if not isinstance (self .scheduler , DDPMWuerstchenScheduler ):
464
+ if len (alphas_cumprod ) > 0 :
465
+ timestep_ratio = self .get_timestep_ratio_conditioning (t .long ().cpu (), alphas_cumprod )
466
+ timestep_ratio = timestep_ratio .expand (latents .size (0 )).to (dtype ).to (device )
467
+ else :
468
+ timestep_ratio = t .float ().div (self .scheduler .timesteps [- 1 ]).expand (latents .size (0 )).to (dtype )
469
+ else :
470
+ timestep_ratio = t .expand (latents .size (0 )).to (dtype )
441
471
442
472
# 7. Denoise latents
443
473
predicted_latents = self .decoder (
@@ -454,6 +484,8 @@ def __call__(
454
484
predicted_latents = torch .lerp (predicted_latents_uncond , predicted_latents_text , self .guidance_scale )
455
485
456
486
# 9. Renoise latents to next timestep
487
+ if not isinstance (self .scheduler , DDPMWuerstchenScheduler ):
488
+ timestep_ratio = t
457
489
latents = self .scheduler .step (
458
490
model_output = predicted_latents ,
459
491
timestep = timestep_ratio ,
0 commit comments