Official implementation of ICLR 2025 "Sampling Demon" (arXiv:2410.05760).
This repository contains the official implementation of Sampling Demon, an inference-time, backpropagation-free preference alignment method for diffusion models. By aligning the denoising process with user preferences via stochastic optimization, Sampling Demon enables the use of non-differentiable reward signals—such as those from Visual-Language Model (VLM) APIs and human judgements—without requiring retraining or fine-tuning of the underlying diffusion model.
- Overview
- Installation
- Usage
- Low-Level API
- Experiments and Results
- Credits and Acknowledgments
- Citation
- License
Diffusion models have revolutionized image generation; however, aligning these models with diverse user preferences remains a significant challenge. Traditional approaches rely either on costly retraining or require differentiable reward functions, limiting their scope when using non-differentiable sources such as VLM APIs and human feedback.
Sampling Demon overcomes these limitations by steering the denoising process via stochastic optimization at inference time. Inspired by Maxwell's Demon, our method evaluates multiple candidate noise perturbations and selectively combines those that yield higher rewards. Our main contributions:
- Backpropagation-Free Alignment: Incorporate non-differentiable reward signals directly into the inference process.
- Plug-and-Play Integration: Seamlessly integrate with existing diffusion models without additional training.
- Theoretical and Empirical Validation: We provide both theoretical insights and comprehensive experimental evidence showing significant improvements in aesthetic scores.
- Broad Applicability: Our approach supports reward signals from various sources, including VLM APIs and human judgements.
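To make the idea concrete, here is a minimal, schematic sketch of one reward-guided denoising step in the spirit of Sampling Demon. It is an illustration only: denoise_step, decode_to_image, and reward_fn are hypothetical placeholders rather than this repository's API, and the tanh weighting shown is one plausible choice of demon function; see the paper and the pipelines below for the actual method.
import torch

def demon_step(x_t, t, denoise_step, decode_to_image, reward_fn, K=16):
    """Try K candidate noises, score the resulting images, and step with a
    reward-weighted combination of the candidates (schematic, not the repo API)."""
    candidates, rewards = [], []
    for _ in range(K):
        eps = torch.randn_like(x_t)                          # candidate noise
        x_next = denoise_step(x_t, t, eps)                   # one stochastic denoising step
        rewards.append(reward_fn(decode_to_image(x_next)))   # non-differentiable reward (e.g. a VLM score)
        candidates.append(eps)
    r = torch.tensor(rewards)
    w = torch.tanh((r - r.mean()) / (r.std() + 1e-8))        # one plausible "tanh"-style weighting
    eps_star = sum(wi * ei for wi, ei in zip(w, candidates)) # combine candidates by their weights
    eps_star = eps_star * (eps_star.numel() ** 0.5) / eps_star.norm()  # rescale to unit-variance noise
    return denoise_step(x_t, t, eps_star)                    # take the actual step with the combined noise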
To install the required packages, run:
conda env create -f environment.yml # The build takes about 30 minutes on our machine :(
pip install -e .
Note:
If you experience issues with PyTorch versioning, try uninstalling the torch-related packages and reinstalling them with:
pip3 install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1
Alternatively, install the packages from requirements.txt one by one if needed.
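As an optional sanity check that the expected PyTorch build is active in the environment:
python3 -c "import torch; print(torch.__version__, torch.cuda.is_available())"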
Develop your own pipeline by subclassing the DemonGenerater abstract class. Override the rewards method to integrate your custom reward function (i.e., a mapping from a list of PIL images to reward scores).
from typing import List
from PIL.Image import Image

# DemonGenerater is provided by this repository (see the pipelines/ examples for the exact import path).
class YourRewardGenerator(DemonGenerater):
    def rewards(self, pils: List[Image]) -> List[float]:
        """
        Implement your custom reward function here.
        """
        return your_reward_function(pils)

...
generator = YourRewardGenerator(
    beta=0.1,
    tau="adaptive",
    K=16,
    T=64,
    demon_type="tanh",       # or "boltzmann", "optimal"
    r_of_c="consistency",    # or "baseline"
    # c_steps=20,            # Meaningful only when r_of_c="baseline"
    ode_after=0.11,          # Recommended value for Stable Diffusion
    cfg=2,                   # Recommended value in (0, 5]
    save_pils=True,
    experiment_directory="experiments/your_experiment",
)

text = "An astronaut riding a horse on Mars."
generator.generate(prompt=text)
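Roughly speaking, K is the number of candidate noises evaluated at each denoising step and T the number of denoising steps, so the reward function above is queried on the order of K × T times per generated image; choose both with the cost of your reward in mind (see the paper for configuration guidelines).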
See the examples in pipelines/vllm_generate.py and pipelines/choose_generate.py.
The repository includes several example pipelines that demonstrate Sampling Demon in action. These pipelines illustrate how to align diffusion models with user preferences using various reward functions.
This pipeline reproduces the results of the Aesthetics Animal Evaluation experiment in the paper (please refer to the paper for configuration guidelines):
python3 pipelines/aesthetic_animal_eval.py --r_of_c "consistency"
This pipeline leverages a Visual-Language Model (VLM) as the reward function to generate images:
python pipelines/vllm_generate.py --model "gemini" --K 16 --T 128 --beta 0.1
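As a rough, hypothetical sketch of how such a VLM score could be plugged into the high-level API above (the actual implementation lives in pipelines/vllm_generate.py; query_vlm_score is a placeholder for your own VLM client and is not part of this repository):
from typing import List
from PIL.Image import Image

def query_vlm_score(pil: Image, question: str) -> float:
    """Placeholder: call your VLM API here and parse a numeric score from its answer."""
    raise NotImplementedError("plug in your own VLM client")

# DemonGenerater is provided by this repository.
class VLMRewardGenerator(DemonGenerater):
    def rewards(self, pils: List[Image]) -> List[float]:
        question = "Rate the aesthetic quality of this image from 1 to 10. Answer with a number."
        return [query_vlm_score(pil, question) for pil in pils]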
Interact with the algorithm via the manual selection pipeline, which provides a user interface for selecting preferred outcomes:
python pipelines/choose_generate.py --text "A boulder in elevator" --K 16 --T 128
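Purely as an illustration of the idea (the real interface is in pipelines/choose_generate.py), a human-in-the-loop reward can be as simple as rewarding the candidate the user picks:
from typing import List
from PIL.Image import Image

# DemonGenerater is provided by this repository.
class ManualChoiceGenerator(DemonGenerater):
    def rewards(self, pils: List[Image]) -> List[float]:
        for i, pil in enumerate(pils):
            pil.save(f"candidate_{i}.png")   # write candidates out so the user can inspect them
        choice = int(input(f"Preferred candidate [0-{len(pils) - 1}]: "))
        return [1.0 if i == choice else 0.0 for i in range(len(pils))]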
For advanced users who wish to modify Sampling Demon at a lower level, we provide a low-level API that was integral to our research. The following snippets demonstrate key functionalities:
# Text-to-image sampling: integrate the diffusion ODE from an initial noise latent.
condition = {
    "prompts": ["On Moon", "Astronaut", "Riding a donkey"],
    "cfgs": [3, 2, 4]
}
steps = 20
x = get_init_latent()               # sigma is 14.6488 for Stable Diffusion
x = odeint(x, condition, steps)     # deterministic (ODE) denoising under the condition
pil = from_latent_to_pil(x)         # decode the latent to a PIL image
# Inversion and reconstruction: encode an image, revert the ODE to noise, then resample.
condition = {
    "prompts": ["An astronaut riding a horse on Mars."],
    "cfgs": [5]
}
x = from_pil_to_latent(pil)         # encode an existing PIL image into latent space
x = oderevert(x, condition)         # run the ODE in reverse, back to the noise latent
x = odeint(x, condition, 20)        # integrate forward again to reconstruct the image
pil = from_latent_to_pil(x)
# Editing: invert with the original condition, then resample under a new condition via the SDE.
old_condition = {
    "prompts": ["An astronaut riding a horse on Mars."],
    "cfgs": [5]
}
new_condition = {
    "prompts": ["On Moon", "Astronaut", "Riding a donkey"],
    "cfgs": [3, 2, 4]
}
steps = 20
sigma = 14
beta = 0.125
x = from_pil_to_latent(pil)
x = oderevert(x, old_condition, start_t=sigma)            # invert up to noise level sigma
x = sdeint(x, new_condition, beta, steps, start_t=sigma)  # stochastic (SDE) sampling with the new prompts
pil = from_latent_to_pil(x)
- For SDv1.5:
  - Please switch to the mini branch for the SDv1.5-compatible version of the code. pipelines/ is compatible with SDv1.5.
- Running the tests:
  - pytest is used for testing. To run the tests, use the command pytest tests.
  - Specifically, the low-level API demonstration above is identical to tests/test_api.py.
- Aesthetic Model Checkpoint: Provided by DDPO.
- Safety Checker: Utilizes the Stable Diffusion Safety Checker from CompVis.
- Contributors: For questions or suggestions, please raise an issue or contact the author.
If you find this code useful in your research, please consider citing our paper:
@inproceedings{
yeh2025trainingfree,
title={Training-Free Diffusion Model Alignment with Sampling Demons},
author={Po-Hung Yeh and Kuang-Huei Lee and Jun-Cheng Chen},
booktitle={International Conference on Learning Representations},
year={2025},
}