
Commit c169bcd: compile utils and version-gating (#1512)
Parent: 2544104

10 files changed (+182, -108 lines)

recipes/dev/lora_finetune_fsdp2.py (8 additions & 20 deletions)

@@ -211,7 +211,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._metric_logger.log_config(cfg)

         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
-        self._model_compile = cfg.get("compile", False)
+        self._compile = cfg.get("compile", False)

         self._model = self._setup_model(
             cfg_model=cfg.model,
@@ -237,22 +237,14 @@ def setup(self, cfg: DictConfig) -> None:

         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
-        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+
+        if self._compile:
+            training.compile_loss(self.loss_fn, verbose=self._is_rank_zero)
+
         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                # For CEWithChunkedOutputLoss, if we compile the entire class
-                # we lose the benefits from the chunked loss.
-                # Therefore, we only compile the cross entropy function + upcasting
-                self._loss_fn.compute_cross_entropy = torch.compile(
-                    self._loss_fn.compute_cross_entropy, backend=backend
-                )
-        else:
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                self._loss_fn = torch.compile(self._loss_fn, backend=backend)
+
         log.info("Loss is initialized.")

         # sampler and dataloader depend on the tokenizer and loss_fn and should be
@@ -328,12 +320,8 @@ def _setup_model(
         self.adapter_params = get_adapter_params(model)
         set_trainable_params(model, self.adapter_params)

-        if self._model_compile:
-            log.info("Compiling model layers with torch.compile...")
-            backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
-            for m in reversed(list(model.modules())):
-                if isinstance(m, modules.TransformerSelfAttentionLayer):
-                    m.compile(backend=backend)
+        if self._compile:
+            training.compile_model(self._model, verbose=self._is_rank_zero)

         if enable_activation_checkpointing:
             training.set_activation_checkpointing(
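
The inline per-layer compilation that these recipes used to carry is moved behind `training.compile_model` (and `training.compile_loss`, discussed after the next file). The utilities themselves are not among the hunks shown here, so the following is only a minimal sketch of what `compile_model` plausibly wraps, reconstructed from the removed recipe code; the logger import and the exact keyword handling are assumptions.

```python
# Hypothetical sketch of a compile_model helper, inferred from the inline
# recipe code removed above; not the actual torchtune implementation.
import os

import torch
from torchtune import modules, utils

log = utils.get_logger("INFO")


def compile_model(model: torch.nn.Module, verbose: bool = True) -> None:
    """Compile each TransformerSelfAttentionLayer in place with torch.compile."""
    if verbose:
        log.info("Compiling model layers with torch.compile...")
    # Tests/CI can swap in a cheaper backend such as "aot_eager".
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    # Compile leaf-most layers first, mirroring the removed recipe loop.
    for m in reversed(list(model.modules())):
        if isinstance(m, modules.TransformerSelfAttentionLayer):
            m.compile(backend=backend)
```

In the distributed recipes, the `verbose=self._is_rank_zero` argument presumably just gates the log line so only rank zero prints it.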

recipes/full_finetune_distributed.py (6 additions & 15 deletions)

@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
 import sys
 import time

@@ -204,7 +203,7 @@ def setup(self, cfg: DictConfig) -> None:

         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

-        self._model_compile = cfg.get("compile", False)
+        self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
@@ -226,22 +225,14 @@ def setup(self, cfg: DictConfig) -> None:

         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
-        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+
+        if self._compile:
+            training.compile_loss(self.loss_fn, verbose=self._is_rank_zero)
+
         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                # For CEWithChunkedOutputLoss, if we compile the entire class
-                # we lose the benefits from the chunked loss.
-                # Therefore, we only compile the cross entropy function + upcasting
-                self._loss_fn.compute_cross_entropy = torch.compile(
-                    self._loss_fn.compute_cross_entropy, backend=backend
-                )
-        else:
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                self._loss_fn = torch.compile(self._loss_fn, backend=backend)
+
         log.info("Loss is initialized.")

         # sampler and dataloader depend on the tokenizer and loss_fn and should be
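
The loss-side special case, previously explained in the removed comments (compiling the whole `CEWithChunkedOutputLoss` would lose the benefit of chunking, so only the cross-entropy + upcast step is compiled), is likewise centralized in `training.compile_loss`. A rough sketch under the same caveats as the `compile_model` sketch above:

```python
# Hypothetical sketch of a compile_loss helper, inferred from the removed
# recipe branches; not the actual torchtune implementation.
import os

import torch
from torch import nn


def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module:
    if verbose:
        print("Compiling loss with torch.compile...")
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    if loss.__class__.__name__ == "CEWithChunkedOutputLoss":
        # Compiling the whole module would trace across the chunk loop and
        # defeat the point of chunking, so compile only the inner
        # cross-entropy + upcast function.
        loss.compute_cross_entropy = torch.compile(
            loss.compute_cross_entropy, backend=backend
        )
    else:
        loss = torch.compile(loss, backend=backend)
    return loss
```

If the helper returns the (possibly wrapped) module, that would explain why `lora_finetune_single_device.py` below reassigns `self._loss_fn = training.compile_loss(...)`, while the in-place chunked case works either way.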

recipes/full_finetune_single_device.py (9 additions & 22 deletions)

@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
 import sys
 import time
 from functools import partial
@@ -207,11 +206,11 @@ def setup(self, cfg: DictConfig) -> None:
         # ``_setup_model`` handles initialization and loading the state dict. This method
         # should be called before ``_setup_optimizer`` since transforming the optimizer
         # state dict requires the model
-        self._model_compile = cfg.compile
+        self._compile = cfg.compile
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
-            compile_model=self._model_compile,
+            compile_model=self._compile,
             model_state_dict=ckpt_dict[training.MODEL_KEY],
         )
         self._tokenizer = config.instantiate(cfg.tokenizer)
@@ -229,22 +228,14 @@ def setup(self, cfg: DictConfig) -> None:

         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
-        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+
+        if self._compile:
+            training.compile_loss(self._loss_fn)
+
         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                # For CEWithChunkedOutputLoss, if we compile the entire class
-                # we lose the benefits from the chunked loss.
-                # Therefore, we only compile the cross entropy function + upcasting
-                self._loss_fn.compute_cross_entropy = torch.compile(
-                    self._loss_fn.compute_cross_entropy, backend=backend
-                )
-        else:
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                self._loss_fn = torch.compile(self._loss_fn, backend=backend)
+
         log.info("Loss is initialized.")

         # sampler and dataloader depend on the tokenizer and loss_fn and should be
@@ -362,11 +353,7 @@ def _setup_model(
         model = config.instantiate(cfg_model)

         if compile_model:
-            log.info("Compiling model layers with torch.compile...")
-            backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
-            for m in reversed(list(model.modules())):
-                if isinstance(m, modules.transformer.TransformerSelfAttentionLayer):
-                    m.compile(backend=backend)
+            training.compile_model(model)

         if enable_activation_checkpointing:
             training.set_activation_checkpointing(
@@ -537,7 +524,7 @@ def train(self) -> None:
         The core training loop. Supports training on subsets of the dataset using the
         ``max_steps_per_epoch``.
         """
-        if self._model_compile:
+        if self._compile:
             log.info(
                 "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
             )
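
The retained warning about a slow first iteration reflects how `torch.compile` behaves in general: compilation is deferred until the first forward call with real inputs, and later calls reuse the cached compiled graph. A standalone illustration, unrelated to the recipe code itself:

```python
# Timing the lazy-compilation effect: step 0 pays the compile cost,
# subsequent steps run the already-compiled graph.
import time

import torch
from torch import nn

model = torch.compile(nn.Sequential(nn.Linear(256, 256), nn.ReLU()))
x = torch.randn(8, 256)

for step in range(3):
    start = time.perf_counter()
    model(x)
    print(f"step {step}: {time.perf_counter() - start:.3f}s")
```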

recipes/lora_finetune_single_device.py (7 additions & 21 deletions)

@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
 import sys
 import time

@@ -211,7 +210,7 @@ def setup(self, cfg: DictConfig) -> None:
         # log config with parameter override
         self._metric_logger.log_config(cfg)

-        self._model_compile = cfg.compile
+        self._compile = cfg.compile
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

         # set up model
@@ -241,22 +240,13 @@ def setup(self, cfg: DictConfig) -> None:

         # initialize loss
         self._loss_fn = config.instantiate(cfg.loss)
-        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+        if self._compile:
+            self._loss_fn = training.compile_loss(self._loss_fn)
+
         if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
             # set num_output_chunks for model
             self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                # For CEWithChunkedOutputLoss, if we compile the entire class
-                # we lose the benefits from the chunked loss.
-                # Therefore, we only compile the cross entropy function + upcasting
-                self._loss_fn.compute_cross_entropy = torch.compile(
-                    self._loss_fn.compute_cross_entropy, backend=backend
-                )
-        else:
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                self._loss_fn = torch.compile(self._loss_fn, backend=backend)
+
         log.info("Loss is initialized.")

         # Dataloader depends on the tokenizer and loss_fn and should be
@@ -389,11 +379,7 @@ def _setup_model(
         set_trainable_params(model, self.adapter_params)

         if compile_model:
-            log.info("Compiling model layers with torch.compile...")
-            backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
-            for m in reversed(list(model.modules())):
-                if isinstance(m, modules.transformer.TransformerSelfAttentionLayer):
-                    m.compile(backend=backend)
+            training.compile_model(model)

         if enable_activation_checkpointing:
             training.set_activation_checkpointing(
@@ -607,7 +593,7 @@ def train(self) -> None:
         The core training loop.
         """

-        if self._model_compile:
+        if self._compile:
             log.info(
                 "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
             )

tests/recipes/test_lora_finetune_single_device.py (0 additions & 1 deletion)

@@ -78,7 +78,6 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch)
         # To workaround https://github.com/pytorch/torchtune/issues/676
         if compile:
             os.environ["TORCH_COMPILE_BACKEND"] = "aot_eager"
-
         cmd = f"""
         tune run lora_finetune_single_device \
             --config {config} \
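
The test continues to force `TORCH_COMPILE_BACKEND=aot_eager` before launching the recipe; the removed recipe code read this variable, and it is assumed here that the new compile utilities still do. The `aot_eager` backend skips Inductor code generation, which keeps compiled tests fast. A generic illustration of the override (not torchtune code):

```python
# Picking the torch.compile backend from the environment, with "aot_eager"
# as a lightweight choice for tests.
import os

import torch
from torch import nn

os.environ.setdefault("TORCH_COMPILE_BACKEND", "aot_eager")
backend = os.environ["TORCH_COMPILE_BACKEND"]

layer = torch.compile(nn.Linear(4, 4), backend=backend)
print(layer(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```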

torchtune/models/gemma/_component_builders.py (0 additions & 7 deletions)

@@ -289,7 +289,6 @@ def lora_gemma_self_attention(
             alpha=lora_alpha,
             dropout=lora_dropout,
             quantize_base=quantize_base,
-            use_dora=use_dora,
         )
         if "q_proj" in lora_modules
         else (
@@ -306,7 +305,6 @@ def lora_gemma_self_attention(
             alpha=lora_alpha,
             dropout=lora_dropout,
             quantize_base=quantize_base,
-            use_dora=use_dora,
         )
         if "k_proj" in lora_modules
         else (
@@ -323,7 +321,6 @@ def lora_gemma_self_attention(
             alpha=lora_alpha,
             dropout=lora_dropout,
             quantize_base=quantize_base,
-            use_dora=use_dora,
         )
         if "v_proj" in lora_modules
         else (
@@ -340,7 +337,6 @@ def lora_gemma_self_attention(
             alpha=lora_alpha,
             dropout=lora_dropout,
             quantize_base=quantize_base,
-            use_dora=use_dora,
         )
         if "output_proj" in lora_modules
         else (
@@ -385,7 +381,6 @@ def lora_gemma_mlp(
         alpha=lora_alpha,
         dropout=lora_dropout,
         quantize_base=quantize_base,
-        use_dora=use_dora,
     )
     down_proj = adapter_cls(
         in_dim=hidden_dim,
@@ -394,7 +389,6 @@ def lora_gemma_mlp(
         alpha=lora_alpha,
         dropout=lora_dropout,
         quantize_base=quantize_base,
-        use_dora=use_dora,
     )
     up_proj = adapter_cls(
         in_dim=dim,
@@ -403,7 +397,6 @@ def lora_gemma_mlp(
         alpha=lora_alpha,
         dropout=lora_dropout,
         quantize_base=quantize_base,
-        use_dora=use_dora,
     )
     activation = nn.GELU(approximate="tanh")

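
The diff only shows the `use_dora=use_dora` keyword being dropped from the Gemma adapter constructor calls, not the reason. A plausible reading, given that these call sites already go through an `adapter_cls(...)` indirection, is that the DoRA choice is encoded by selecting the adapter class up front, making the extra keyword redundant (or unsupported by the constructors). A sketch of that assumed pattern:

```python
# Assumed pattern only: the hunks above show adapter_cls(...) call sites,
# so the DoRA/LoRA choice is presumably made when picking the class.
from torchtune.modules.peft import DoRALinear, LoRALinear


def build_adapter(in_dim: int, out_dim: int, rank: int, alpha: float, use_dora: bool):
    adapter_cls = DoRALinear if use_dora else LoRALinear
    # No use_dora kwarg needs to reach the constructor.
    return adapter_cls(in_dim=in_dim, out_dim=out_dim, rank=rank, alpha=alpha)
```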

torchtune/models/gemma/transformer.py (24 additions & 8 deletions)

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.nn as nn
@@ -98,6 +98,28 @@ def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None:
             torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool)
         )

+    @torch.compiler.disable
+    def chunked_output(self, last_hidden_state: torch.Tensor) -> List[torch.Tensor]:
+        """
+        Apply output projection in chunks. This should be applied in conjunction with
+        :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss` as upcasting to fp32 is done there.
+
+        To use this method, you should first call
+        :func:`~torchtune.models.gemma.GemmaTransformerDecoder.set_num_output_chunks`.
+
+        Args:
+            last_hidden_state (torch.Tensor): last hidden state of the decoder, having shape
+                [b, seq_len, embed_dim].
+
+        Returns:
+            List[torch.Tensor]: List of num_chunks output tensors, each with shape
+                [b, seq_len/num_chunks, out_dim], where out_dim is usually the vocab size.
+        """
+        return [
+            F.linear(chunk, self.tok_embeddings.weight)
+            for chunk in last_hidden_state.chunk(self.num_output_chunks, dim=1)
+        ]
+
     def forward(
         self,
         tokens: torch.Tensor,
@@ -168,13 +190,7 @@ def forward(
         h = self.norm(h)

         if self.num_output_chunks > 0:
-            # shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size
-            # Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe,
-            # before calling forward. Upcasting it done inside of the loss function.
-            output = [
-                F.linear(chunk, self.tok_embeddings.weight)
-                for chunk in h.chunk(self.num_output_chunks, dim=1)
-            ]
+            output = self.chunked_output(h)
         else:
             # shape: [b, seq_len, out_dim]
             output = F.linear(h, self.tok_embeddings.weight).float()
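
The new `chunked_output` method is marked `@torch.compiler.disable`, presumably so that when the rest of the model is compiled the chunked projection loop stays in eager mode and keeps its memory benefit rather than being traced and fused by the compiler. It is meant to be paired with `CEWithChunkedOutputLoss`, which consumes a list of logit chunks and performs the fp32 upcast itself. A standalone sketch of that pairing (shapes, vocab size, and the tied-embedding stand-in are illustrative, not Gemma's actual configuration):

```python
# Standalone sketch of chunked output projection feeding the chunked CE loss.
import torch
import torch.nn.functional as F
from torchtune.modules.loss import CEWithChunkedOutputLoss

num_chunks, vocab, embed, bsz, seq = 8, 1000, 64, 2, 128
tok_embeddings = torch.nn.Embedding(vocab, embed)  # stands in for tied output weights
loss_fn = CEWithChunkedOutputLoss(num_output_chunks=num_chunks)

hidden = torch.randn(bsz, seq, embed)  # stand-in for the decoder's last hidden state
logits = [
    F.linear(chunk, tok_embeddings.weight)  # same projection as chunked_output()
    for chunk in hidden.chunk(num_chunks, dim=1)
]
labels = torch.randint(0, vocab, (bsz, seq))
loss = loss_fn(logits, labels)  # upcast to fp32 happens inside the loss
print(loss)
```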
