Skip to content

Commit 19fad2b

Browse files
gameofdimensiongithub-actions[bot]yiyixuxuasomoza
authored andcommitted
enable flux pipeline compatible with unipc and dpm-solver (huggingface#11908)
* Update pipeline_flux.py have flux pipeline work with unipc/dpm schedulers * clean code * Update scheduling_dpmsolver_multistep.py * Update scheduling_unipc_multistep.py * Update pipeline_flux.py * Update scheduling_deis_multistep.py * Update scheduling_dpmsolver_singlestep.py * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Álvaro Somoza <[email protected]>
1 parent 0e8dcb3 commit 19fad2b

File tree

5 files changed

+30
-2
lines changed

5 files changed

+30
-2
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,8 @@ def __call__(
840840

841841
# 5. Prepare timesteps
842842
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
843+
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
844+
sigmas = None
843845
image_seq_len = latents.shape[1]
844846
mu = calculate_shift(
845847
image_seq_len,

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ def __init__(
153153
flow_shift: Optional[float] = 1.0,
154154
timestep_spacing: str = "linspace",
155155
steps_offset: int = 0,
156+
use_dynamic_shifting: bool = False,
157+
time_shift_type: str = "exponential",
156158
):
157159
if self.config.use_beta_sigmas and not is_scipy_available():
158160
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -232,7 +234,9 @@ def set_begin_index(self, begin_index: int = 0):
232234
"""
233235
self._begin_index = begin_index
234236

235-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
237+
def set_timesteps(
238+
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
239+
):
236240
"""
237241
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
238242
@@ -242,6 +246,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
242246
device (`str` or `torch.device`, *optional*):
243247
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
244248
"""
249+
if mu is not None:
250+
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
251+
self.config.flow_shift = np.exp(mu)
245252
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
246253
if self.config.timestep_spacing == "linspace":
247254
timesteps = (

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def __init__(
230230
timestep_spacing: str = "linspace",
231231
steps_offset: int = 0,
232232
rescale_betas_zero_snr: bool = False,
233+
use_dynamic_shifting: bool = False,
234+
time_shift_type: str = "exponential",
233235
):
234236
if self.config.use_beta_sigmas and not is_scipy_available():
235237
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ def set_timesteps(
330332
self,
331333
num_inference_steps: int = None,
332334
device: Union[str, torch.device] = None,
335+
mu: Optional[float] = None,
333336
timesteps: Optional[List[int]] = None,
334337
):
335338
"""
@@ -345,6 +348,9 @@ def set_timesteps(
345348
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
346349
must be `None`, and `timestep_spacing` attribute will be ignored.
347350
"""
351+
if mu is not None:
352+
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
353+
self.config.flow_shift = np.exp(mu)
348354
if num_inference_steps is None and timesteps is None:
349355
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
350356
if num_inference_steps is not None and timesteps is not None:

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def __init__(
169169
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
170170
lambda_min_clipped: float = -float("inf"),
171171
variance_type: Optional[str] = None,
172+
use_dynamic_shifting: bool = False,
173+
time_shift_type: str = "exponential",
172174
):
173175
if self.config.use_beta_sigmas and not is_scipy_available():
174176
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -301,6 +303,7 @@ def set_timesteps(
301303
self,
302304
num_inference_steps: int = None,
303305
device: Union[str, torch.device] = None,
306+
mu: Optional[float] = None,
304307
timesteps: Optional[List[int]] = None,
305308
):
306309
"""
@@ -316,6 +319,9 @@ def set_timesteps(
316319
timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
317320
passed, `num_inference_steps` must be `None`.
318321
"""
322+
if mu is not None:
323+
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
324+
self.config.flow_shift = np.exp(mu)
319325
if num_inference_steps is None and timesteps is None:
320326
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
321327
if num_inference_steps is not None and timesteps is not None:

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ def __init__(
212212
steps_offset: int = 0,
213213
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
214214
rescale_betas_zero_snr: bool = False,
215+
use_dynamic_shifting: bool = False,
216+
time_shift_type: str = "exponential",
215217
):
216218
if self.config.use_beta_sigmas and not is_scipy_available():
217219
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -298,7 +300,9 @@ def set_begin_index(self, begin_index: int = 0):
298300
"""
299301
self._begin_index = begin_index
300302

301-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
303+
def set_timesteps(
304+
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
305+
):
302306
"""
303307
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
304308
@@ -309,6 +313,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
309313
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
310314
"""
311315
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
316+
if mu is not None:
317+
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
318+
self.config.flow_shift = np.exp(mu)
312319
if self.config.timestep_spacing == "linspace":
313320
timesteps = (
314321
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)

0 commit comments

Comments
 (0)