Skip to content

Commit 21682ba

Browse files
authored
Custom sampler support for Stable Cascade Decoder (#9132)
Custom sampler support for Stable Cascade Decoder
1 parent 214990e commit 21682ba

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,16 @@ def do_classifier_free_guidance(self):
281281
def num_timesteps(self):
282282
return self._num_timesteps
283283

284+
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
    """Map a discrete scheduler timestep to the continuous ratio conditioning
    expected by the Stable Cascade decoder.

    Inverts a squared-cosine noise schedule (offset ``s = 0.008``, as in the
    improved-DDPM cosine schedule) so that the cumulative alpha at step ``t``
    is converted back into a normalized timestep ratio in roughly [0, 1].

    Args:
        t: integer (or integer tensor) index into ``alphas_cumprod``.
        alphas_cumprod: 1-D tensor of cumulative alpha products per timestep.

    Returns:
        Tensor holding the timestep ratio for conditioning the decoder.
    """
    # Cosine-schedule offset; kept on CPU until the device of `var` is known.
    offset = torch.tensor([0.008])
    # Smallest variance the cosine schedule can produce (at ratio == 1).
    variance_floor = torch.cos(offset / (1 + offset) * torch.pi * 0.5) ** 2
    # Clamp the cumulative alpha into the valid [0, 1] range before inverting.
    variance = alphas_cumprod[t].clamp(0, 1)
    offset = offset.to(variance.device)
    variance_floor = variance_floor.to(variance.device)
    # Invert var = cos((r + s)/(1 + s) * pi/2)^2 / min_var to recover r.
    inverted = ((variance * variance_floor) ** 0.5).acos() / (torch.pi * 0.5)
    return inverted * (1 + offset) - offset
293+
284294
@torch.no_grad()
285295
@replace_example_docstring(EXAMPLE_DOC_STRING)
286296
def __call__(
@@ -434,10 +444,30 @@ def __call__(
434444
batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
435445
)
436446

447+
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
448+
timesteps = timesteps[:-1]
449+
else:
450+
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
451+
self.scheduler.config.clip_sample = False # disample sample clipping
452+
logger.warning(" set `clip_sample` to be False")
453+
437454
# 6. Run denoising loop
438-
self._num_timesteps = len(timesteps[:-1])
439-
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
440-
timestep_ratio = t.expand(latents.size(0)).to(dtype)
455+
if hasattr(self.scheduler, "betas"):
456+
alphas = 1.0 - self.scheduler.betas
457+
alphas_cumprod = torch.cumprod(alphas, dim=0)
458+
else:
459+
alphas_cumprod = []
460+
461+
self._num_timesteps = len(timesteps)
462+
for i, t in enumerate(self.progress_bar(timesteps)):
463+
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
464+
if len(alphas_cumprod) > 0:
465+
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
466+
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
467+
else:
468+
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
469+
else:
470+
timestep_ratio = t.expand(latents.size(0)).to(dtype)
441471

442472
# 7. Denoise latents
443473
predicted_latents = self.decoder(
@@ -454,6 +484,8 @@ def __call__(
454484
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
455485

456486
# 9. Renoise latents to next timestep
487+
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
488+
timestep_ratio = t
457489
latents = self.scheduler.step(
458490
model_output=predicted_latents,
459491
timestep=timestep_ratio,

src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def num_timesteps(self):
353353
return self._num_timesteps
354354

355355
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
356-
s = torch.tensor([0.003])
356+
s = torch.tensor([0.008])
357357
clamp_range = [0, 1]
358358
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
359359
var = alphas_cumprod[t]
@@ -557,7 +557,7 @@ def __call__(
557557
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
558558
timesteps = timesteps[:-1]
559559
else:
560-
if self.scheduler.config.clip_sample:
560+
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
561561
self.scheduler.config.clip_sample = False # disample sample clipping
562562
logger.warning(" set `clip_sample` to be False")
563563
# 6. Run denoising loop

0 commit comments

Comments
 (0)