diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 746063f9d619..592e5d77669e 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -1330,7 +1330,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # controlnet(s) inference controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) controlnet_image = vae.encode(controlnet_image).latent_dist.sample() - controlnet_image = controlnet_image * vae.config.scaling_factor + controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor control_block_res_samples = controlnet( hidden_states=noisy_model_input,