diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 073d94750a02..6e6e5a4c7fbd 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -840,6 +840,8 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 748a7e39c0b8..7d8685ba10c3 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -153,6 +153,8 @@ def __init__( flow_shift: Optional[float] = 1.0, timestep_spacing: str = "linspace", steps_offset: int = 0, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -232,7 +234,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -242,6 +246,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" + self.config.flow_shift = np.exp(mu) # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 1a648af5a008..d07ff8b2007b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -230,6 +230,8 @@ def __init__( timestep_spacing: str = "linspace", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -330,6 +332,7 @@ def set_timesteps( self, num_inference_steps: int = None, device: Union[str, torch.device] = None, + mu: Optional[float] = None, timesteps: Optional[List[int]] = None, ): """ @@ -345,6 +348,9 @@ def set_timesteps( based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`, and `timestep_spacing` attribute will be ignored. """ + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" + self.config.flow_shift = np.exp(mu) if num_inference_steps is None and timesteps is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") if num_inference_steps is not None and timesteps is not None: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 9e3e830039bb..8663210a6244 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -169,6 +169,8 @@ def __init__( final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -301,6 +303,7 @@ def set_timesteps( self, num_inference_steps: int = None, device: Union[str, torch.device] = None, + mu: Optional[float] = None, timesteps: Optional[List[int]] = None, ): """ @@ -316,6 +319,9 @@ def set_timesteps( timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`. """ + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" + self.config.flow_shift = np.exp(mu) if num_inference_steps is None and timesteps is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") if num_inference_steps is not None and timesteps is not None: diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 8b1f699b101a..1d2378fd4f53 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -212,6 +212,8 @@ def __init__( steps_offset: int = 0, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" rescale_betas_zero_snr: bool = False, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -298,7 +300,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -309,6 +313,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 + if mu is not None: + assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" + self.config.flow_shift = np.exp(mu) if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)