Commit 8a52d19

feedback
1 parent d3ccda8 commit 8a52d19

1 file changed: docs/source/en/optimization/fp16.md (8 additions, 20 deletions)
@@ -175,15 +175,7 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a
 ### Regional compilation
 
 [Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
-For many diffusion architectures, this delivers the same runtime speed-ups as full-graph compilation and reduces compile time by 8–10x.
-
-There are two implementations of regional compilation.
-
-- The Diffusers version, [`~ModelMixin.compile_repeated_blocks`], is more explicit and is easier to customize.
-- The Accelerate version, [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78), automatically selects which regions to compile and is less customizable. It is ideal for fast experiments.
-
-<hfoptions id="regional-compilation">
-<hfoption id="compile_repeated_blocks">
+For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8–10x.
 
 Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.
 
@@ -192,13 +184,13 @@ Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `tor
 import torch
 from diffusers import StableDiffusionXLPipeline
 
-pipe = StableDiffusionXLPipeline.from_pretrained(
+pipeline = StableDiffusionXLPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
     torch_dtype=torch.float16,
 ).to("cuda")
 
-# Compile only the repeated Transformer layers inside the UNet
-pipe.unet.compile_repeated_blocks(fullgraph=True)
+# compile only the repeated transformer layers inside the UNet
+pipeline.unet.compile_repeated_blocks(fullgraph=True)
 ```
 
 To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.
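
A minimal sketch of that pattern, under the assumption that the repeated block is a `Transformer2DModel` (the block name is illustrative, not part of this commit):

```py
from diffusers import ModelMixin

class MyUNet(ModelMixin):
    # class names (as strings) of the repeated blocks to compile once and reuse
    _repeated_blocks = ("Transformer2DModel",)
```

With this attribute in place, `compile_repeated_blocks(fullgraph=True)` compiles the first occurrence of each listed block and reuses the compiled artifact for every subsequent occurrence.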
@@ -209,10 +201,7 @@ class MyUNet(ModelMixin):
 ```
 
 > [!TIP]
-> For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
-
-</hfoption>
-<hfoption id="compile_regions">
+> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
 
 There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.
 
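A hedged sketch of the Accelerate path, assuming `compile_regions` is importable from `accelerate.utils` (the linked source lives in `src/accelerate/utils/other.py`) and reusing the pipeline from the example above:

```py
import torch
from accelerate.utils import compile_regions
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# Accelerate selects the candidate blocks to compile automatically;
# the remaining graph is compiled separately
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```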
@@ -230,9 +219,6 @@ pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph
 
 [`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
 
-</hfoption>
-</hfoptions>
-
 ### Graph breaks
 
 It is important to specify `fullgraph=True` in torch.compile to ensure there are no graph breaks in the underlying model. This allows you to take advantage of torch.compile without any performance degradation. For the UNet and VAE, this changes how you access the return variables.
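
As an illustration of that return-variable change, a sketch with a tiny, randomly initialized UNet (the configuration values are placeholders chosen only to make the snippet self-contained):

```py
import torch
from diffusers import UNet2DConditionModel

# tiny illustrative config; real pipelines load pretrained weights instead
unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
)

latents = torch.randn(1, 4, 8, 8)
prompt_embeds = torch.randn(1, 77, 32)

# return_dict=False makes the UNet return a plain tuple instead of a
# ModelOutput, so index the result rather than accessing .sample
latents = unet(latents, timestep=0, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
```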
@@ -310,4 +296,6 @@ pipeline.fuse_qkv_projections()
 
 ## Resources
 
-Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup.
+- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
+
+  These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
