From 4d0d60742413222048666af1ebcf2007c6937b0d Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:50:01 +0800 Subject: [PATCH 1/8] Update pipeline_flux.py have flux pipeline work with unipc/dpm schedulers --- src/diffusers/pipelines/flux/pipeline_flux.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 073d94750a02..132eac00e8e8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -29,7 +29,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, @@ -848,13 +848,20 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - sigmas=sigmas, - mu=mu, - ) + if isinstance(self.scheduler, (UniPCMultistepScheduler, DPMSolverMultistepScheduler)): + self.scheduler.config.use_flow_sigmas = True + self.scheduler.config.prediction_type = "flow_prediction" + self.scheduler.config.flow_shift = np.exp(mu) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) From 7d136ba60cbf45579c2158fc0835948fed9a8117 Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:18:25 +0800 Subject: [PATCH 2/8] clean code --- src/diffusers/pipelines/flux/pipeline_flux.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 132eac00e8e8..cc88ccc067b0 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -840,6 +840,8 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if isinstance(self.scheduler, (UniPCMultistepScheduler, DPMSolverMultistepScheduler)): + sigmas = None image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, @@ -848,20 +850,13 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) - if isinstance(self.scheduler, (UniPCMultistepScheduler, DPMSolverMultistepScheduler)): - self.scheduler.config.use_flow_sigmas = True - self.scheduler.config.prediction_type = "flow_prediction" - self.scheduler.config.flow_shift = np.exp(mu) - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - else: - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - sigmas=sigmas, - mu=mu, - ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) From d5885d4f407cfebc0855aede5c49b7c0600e5e8d Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:21:56 +0800 Subject: [PATCH 3/8] Update scheduling_dpmsolver_multistep.py --- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 1a648af5a008..19147b76e92c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -230,6 +230,8 @@ def __init__( timestep_spacing: str = "linspace", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -330,6 +332,7 @@ def set_timesteps( self, num_inference_steps: int = None, device: Union[str, torch.device] = None, + mu: Optional[float] = None, timesteps: Optional[List[int]] = None, ): """ @@ -345,6 +348,9 @@ def set_timesteps( based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`, and `timestep_spacing` attribute will be ignored. """ + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential' + self.config.flow_shift = np.exp(mu) if num_inference_steps is None and timesteps is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") if num_inference_steps is not None and timesteps is not None: From 28851b6d7eacac41f6ec60baccba3429bfd2d871 Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:26:52 +0800 Subject: [PATCH 4/8] Update scheduling_unipc_multistep.py --- src/diffusers/schedulers/scheduling_unipc_multistep.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 8b1f699b101a..0125d256ee58 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -212,6 +212,8 @@ def __init__( steps_offset: int = 0, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" rescale_betas_zero_snr: bool = False, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -298,7 +300,7 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -309,6 +311,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential' + self.config.flow_shift = np.exp(mu) if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) From 074f0fa1ccde86883fac1567822393280d4f7bf2 Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Sat, 12 Jul 2025 09:25:36 +0800 Subject: [PATCH 5/8] Update pipeline_flux.py --- src/diffusers/pipelines/flux/pipeline_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index cc88ccc067b0..6e6e5a4c7fbd 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -29,7 +29,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, FluxTransformer2DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler +from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, @@ -840,7 +840,7 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - if isinstance(self.scheduler, (UniPCMultistepScheduler, DPMSolverMultistepScheduler)): + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: sigmas = None image_seq_len = latents.shape[1] mu = calculate_shift( From 19d9c70a45da331cc66b081f10bc3028e815c1b4 Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Sat, 12 Jul 2025 09:30:57 +0800 Subject: [PATCH 6/8] Update scheduling_deis_multistep.py --- src/diffusers/schedulers/scheduling_deis_multistep.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 748a7e39c0b8..b223d1e1e30f 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -153,6 +153,8 @@ def __init__( flow_shift: Optional[float] = 1.0, timestep_spacing: str = "linspace", steps_offset: int = 0, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -232,7 +234,7 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -242,6 +244,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential' + self.config.flow_shift = np.exp(mu) # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( From d97d5a57fd8ba2c1b11f05fb2854cd71bcc43de5 Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Sat, 12 Jul 2025 09:51:32 +0800 Subject: [PATCH 7/8] Update scheduling_dpmsolver_singlestep.py --- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 9e3e830039bb..5604bde9c4af 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -169,6 +169,8 @@ def __init__( final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -301,6 +303,7 @@ def set_timesteps( self, num_inference_steps: int = None, device: Union[str, torch.device] = None, + mu: Optional[float] = None, timesteps: Optional[List[int]] = None, ): """ @@ -316,6 +319,9 @@ def set_timesteps( timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`. """ + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential' + self.config.flow_shift = np.exp(mu) if num_inference_steps is None and timesteps is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") if num_inference_steps is not None and timesteps is not None: From 2a6729dec922d94dbcb2259647e2811671a6112a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 14 Jul 2025 17:47:34 +0000 Subject: [PATCH 8/8] Apply style fixes --- src/diffusers/schedulers/scheduling_deis_multistep.py | 6 ++++-- src/diffusers/schedulers/scheduling_dpmsolver_multistep.py | 2 +- src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py | 2 +- src/diffusers/schedulers/scheduling_unipc_multistep.py | 6 ++++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index b223d1e1e30f..7d8685ba10c3 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -234,7 +234,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None): + def set_timesteps( + self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -245,7 +247,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ if mu is not None: - assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential' + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" self.config.flow_shift = np.exp(mu) # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 if self.config.timestep_spacing == "linspace": diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 19147b76e92c..d07ff8b2007b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -349,7 +349,7 @@ def set_timesteps( must be `None`, and `timestep_spacing` attribute will be ignored. """ if mu is not None: - assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential' + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" self.config.flow_shift = np.exp(mu) if num_inference_steps is None and timesteps is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 5604bde9c4af..8663210a6244 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -320,7 +320,7 @@ def set_timesteps( passed, `num_inference_steps` must be `None`. """ if mu is not None: - assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential' + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" self.config.flow_shift = np.exp(mu) if num_inference_steps is None and timesteps is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 0125d256ee58..1d2378fd4f53 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -300,7 +300,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None): + def set_timesteps( + self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -312,7 +314,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic """ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 if mu is not None: - assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential' + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" self.config.flow_shift = np.exp(mu) if self.config.timestep_spacing == "linspace": timesteps = (