
Commit ed4f1d5

WIP-DEBUG-PROFILE torch.compile
ghstack-source-id: 0881d81
Pull Request resolved: #2644
1 parent 0991f97 commit ed4f1d5

File tree: 13 files changed (+537, -39 lines)

recipes/configs/llama4/scout_17B_16E_full.yaml

Lines changed: 19 additions & 6 deletions
@@ -18,7 +18,7 @@ output_dir: /tmp/torchtune/llama4_17Bx16E/full
 model:
   _component_: torchtune.models.llama4.llama4_scout_17b_16e
 
-tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8
+tensor_parallel_dim: 1 # For multi-node training we recommend tensor_parallel_dim: 8
 tensor_parallel_plan:
   _component_: torchtune.models.llama4.decoder_only_tp_plan
 data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
@@ -74,10 +74,10 @@ fsdp_cpu_offload: True
 # compile Dictionary with keys: "model", "loss", "optimizer_step"
 # enables torch.compile only for specified components.
 compile: False
-# model: True
-# loss: True
-# optimizer_step: False
-# scale_grads: True
+# model: True
+# loss: True
+# optimizer_step: True
+# scale_grads: True
 
 # Reduced precision
 dtype: bf16
@@ -93,4 +93,17 @@ log_level: INFO # DEBUG, WARN, etc.
 # Useful for understanding how to optimize memory and performance
 profiler:
   _component_: torchtune.training.setup_torch_profiler
-  enabled: False
+  enabled: True
+  output_dir: ${output_dir}/profiling_outputs
+  cpu: True
+  cuda: True
+  profile_memory: True
+  with_stack: True
+  record_shapes: True
+  with_flops: False
+  wait_steps: 3
+  warmup_steps: 3
+  active_steps: 1
+  num_cycles: 1
+
+# enable_fp8_training: True
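Note on the profiler block enabled above: the wait_steps/warmup_steps/active_steps/num_cycles knobs correspond to a torch.profiler schedule, and the remaining flags map onto torch.profiler.profile arguments. A minimal standalone sketch of a roughly equivalent setup, assuming the YAML values are forwarded more or less directly (the actual wiring lives in torchtune.training.setup_torch_profiler):

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# Values mirror the YAML above; the real recipe builds this via setup_torch_profiler.
prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # cpu: True, cuda: True
    schedule=schedule(wait=3, warmup=3, active=1, repeat=1),    # wait/warmup/active steps, num_cycles
    profile_memory=True,
    with_stack=True,
    record_shapes=True,
    with_flops=False,
    on_trace_ready=tensorboard_trace_handler("/tmp/profiling_outputs"),
)

with prof:
    for step in range(10):
        # one training step would go here
        prof.step()  # advances the wait -> warmup -> active schedule

With this schedule, a trace is only recorded for the single active step after 3 wait and 3 warmup steps, which keeps the profiling overhead confined to one cycle.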

recipes/full_finetune_distributed.py

Lines changed: 40 additions & 3 deletions
@@ -16,6 +16,7 @@
 from omegaconf import DictConfig, ListConfig
 
 from torch import nn
+import torch.distributed as dist
 from torch.distributed import destroy_process_group, init_process_group
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
@@ -147,6 +148,10 @@ def __init__(self, cfg: DictConfig) -> None:
             offload_ops_to_cpu=self.fsdp_cpu_offload
             or self._enable_async_checkpointing,
         )
+        # group_name = "torchtune-finetune"
+        # pg = dist.distributed_c10d._get_default_group()
+        # torch._C._distributed_c10d._register_process_group(group_name, pg)
+        # init_process_group(self.distributed_backend, group_name=group_name)
         init_process_group(self.distributed_backend)
 
         # Initialize distributed variables
@@ -328,6 +333,9 @@ def setup(self, cfg: DictConfig) -> None:
        compile = cfg.get("compile")
        compile_bool = bool(compile)
        self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+       self._compile_mode = None  # "max-autotune-no-cudagraphs"
+       # torch._inductor.config.cpp_wrapper = True
+       # torch._dynamo.config.capture_scalar_outputs = True
 
        self._compile_model = compile_bool
        self._compile_loss = compile_bool
@@ -343,7 +351,7 @@
            self._grad_scaler = training.scale_grads_
            if self._compile_scale_grads:
                self._grad_scaler = torch.compile(
-                   self._grad_scaler, backend=self._compile_backend
+                   self._grad_scaler, backend=self._compile_backend, mode=self._compile_mode
                )
 
        self._model = self._setup_model(
@@ -380,6 +388,7 @@
            self._optimizer.step = torch.compile(
                self._optimizer.step,
                backend=self._compile_backend,
+               mode=self._compile_mode
            )
 
        if self._resume_from_checkpoint:
@@ -413,7 +422,7 @@
        self._loss_fn.set_model_output(self._model)
 
        if self._compile_loss:
-           training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+           training.compile_loss(self._loss_fn, mode=self._compile_mode, verbose=self._is_rank_zero)
 
        utils.log_rank_zero(self._logger, "Loss is initialized.")
 
@@ -586,7 +595,7 @@ def _setup_model(
        model = config.instantiate(cfg_model)
 
        if self._compile_model:
-           training.compile_model(model, verbose=self._is_rank_zero)
+           training.compile_model(model, mode=self._compile_mode, verbose=self._is_rank_zero)
 
        if self._enable_fp8_training:
            # Requires https://github.com/pytorch/pytorch/pull/148922
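The hunks above thread a single `self._compile_mode` through every `torch.compile` call site in the recipe (grad scaler, optimizer step, loss, model). A hedged, self-contained sketch of the same pattern, compiling components independently with a shared backend and mode; the names here are illustrative, not the recipe's actual helpers:

import os
import torch
import torch.nn as nn

compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
compile_mode = None  # e.g. "max-autotune-no-cudagraphs"; None means the default mode

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Compile each component separately so they can be toggled independently,
# mirroring the compile: {model, loss, optimizer_step, scale_grads} config keys.
model = torch.compile(model, backend=compile_backend, mode=compile_mode)
loss_fn = torch.compile(loss_fn, backend=compile_backend, mode=compile_mode)
optimizer.step = torch.compile(optimizer.step, backend=compile_backend, mode=compile_mode)

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()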
@@ -810,6 +819,7 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
 
        with self.activations_handling_ctx:
            outputs = self._model(**batch)
+       # print(f"XXX {dist.get_rank()} OUTPUTS:{outputs.shape} {outputs.dtype}")
 
        # post process for third party loss functions
        if not isinstance(self._loss_fn, SFTLoss):
@@ -820,6 +830,7 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
 
        # Compute loss
        loss = self._loss_fn(outputs, labels)
+       # print(f"XXX {dist.get_rank()} LOSS:{loss}")
 
        # free logits otherwise it peaks backward memory
        del outputs
@@ -895,6 +906,9 @@ def train(self) -> None:
            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
            self._dataloader.sampler.set_epoch(curr_epoch)
            for idx, batch in enumerate(self._dataloader):
+               b_tokens = batch["tokens"]
+               b_labels = batch["labels"]
+               # print(f"XXX R:{dist.get_rank()} BATCH:{idx} b_labels:{b_labels.shape} b_tokens:{b_tokens.shape}")
                # Start tracking CUDA memory for active steps for just the first epoch
                if (
                    self._is_rank_zero
@@ -916,7 +930,9 @@
 
                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
+               # print(f"XXX R:{dist.get_rank()} BATCH:{idx} current_num_tokens:{current_num_tokens}")
                current_loss = self._loss_step(batch) * current_num_tokens
+               # print(f"XXX R:{dist.get_rank()} BATCH:{idx} current_loss:{current_loss}")
                running_loss += current_loss
 
                # For optimizer in backward, we need to normalize before calling backward
@@ -1068,6 +1084,26 @@ def cleanup(self) -> None:
        self._metric_logger.close()
        destroy_process_group()
 
+# from torch.utils._python_dispatch import TorchDispatchMode
+# import torch.utils._pytree as pytree
+# from torch._higher_order_ops.flex_attention import flex_attention
+#
+#
+#
+# class Mode(TorchDispatchMode):
+#     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+#         r = torch.distributed.get_rank()
+#         print(f"XXX RANK[{r}] MODE._torch_dispatch_ {func} {types}")
+#         for a in pytree.tree_leaves(args):
+#             if issubclass(type(a), torch.Tensor):
+#                 print(f"XXX RANK[{r}] {a.dtype} {a.shape}")
+#             else:
+#                 print(f"XXX RANK[{r}] {a}")
+#         return func(*args, **kwargs)
+#
+# def flex_attention_mode_call(mode, *args, **kwargs):
+#     return flex_attention(*args, **kwargs)
+# flex_attention.py_impl(Mode)(flex_attention_mode_call)
 
@config.parse
def recipe_main(cfg: DictConfig) -> None:
@@ -1081,6 +1117,7 @@ def recipe_main(cfg: DictConfig) -> None:
    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)
    recipe = FullFinetuneRecipeDistributed(cfg=cfg)
    recipe.setup(cfg=cfg)
+   # with Mode():
    recipe.train()
    recipe.cleanup()
 
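The commented-out `Mode` class in this file is a debugging aid built on `TorchDispatchMode`, which intercepts every dispatched ATen op so the op name and the dtypes/shapes of its tensor arguments can be printed. A trimmed, runnable sketch of the same idea, standalone and without the flex_attention `py_impl` registration:

import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode


class OpLoggingMode(TorchDispatchMode):
    """Print every dispatched op and the dtype/shape of its tensor args."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"op: {func}")
        for a in pytree.tree_leaves(args):
            if isinstance(a, torch.Tensor):
                print(f"  tensor arg: {a.dtype} {tuple(a.shape)}")
        return func(*args, **kwargs)


with OpLoggingMode():
    x = torch.randn(4, 4)
    y = (x @ x).relu()

Wrapping `recipe.train()` in such a mode (as the commented `# with Mode():` hints) logs every op executed during training, which is useful for spotting unexpected dtypes or shapes when debugging torch.compile behavior.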

torchtune/models/llama4/_component_builders.py

Lines changed: 10 additions & 2 deletions
@@ -38,6 +38,7 @@
     TokenChoiceTopKRouter,
 )
 from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
+from torchtune.utils._device import has_cuda_capability
 
 """
 Component builders for the Llama4 model.
@@ -180,6 +181,7 @@ def llama4_decoder(
     num_experts: int = 16,
     experts_per_token: int = 1,
     use_shared_expert: bool = True,
+    use_grouped_mm: bool = True,
     use_qk_norm: bool = True,
     moe_every_n_layers: Optional[int] = None,
     mlp_hidden_dim: Optional[int] = None,
@@ -244,6 +246,11 @@
         raise ValueError(
             "Must pass local_chunk_size when enabling local chunked attention"
         )
+    if use_grouped_mm and not has_cuda_capability(9, 0):
+        torchtune.utils.get_logger("WARNING")(
+            "Failed to use grouped mm, which is only supported on SM90 or later",
+        )
+        use_grouped_mm = False
     head_dim = embed_dim // num_heads
     num_kv_heads = num_kv_heads if num_kv_heads else num_heads
 
@@ -263,7 +270,6 @@
     )
     layers = []
     for i in range(num_layers):
-
         mask_mod = None
         if skip_rope_interval is not None and (i + 1) % skip_rope_interval != 0:
             mask_mod = partial(
@@ -300,6 +306,7 @@
                 num_experts=num_experts,
                 experts_per_token=experts_per_token,
                 use_shared_expert=use_shared_expert,
+                use_grouped_mm=use_grouped_mm,
             )
         else:
             mlp_layer = llama4_mlp(dim=embed_dim, hidden_dim=mlp_hidden_dim)
@@ -355,6 +362,7 @@ def llama4_moe(
     num_experts: int = 8,
     experts_per_token: int = 1,
     use_shared_expert: bool = True,
+    use_grouped_mm: bool = True,
 ) -> MoE:
     """
     Build the MoE layer associated with the Llama model.
@@ -631,6 +639,7 @@ def lora_llama4_decoder(
         raise ValueError(
             "Must pass local_chunk_size when enabling local chunked attention"
         )
+
     head_dim = embed_dim // num_heads
     num_kv_heads = num_kv_heads if num_kv_heads else num_heads
     if use_scaled_rope:
@@ -649,7 +658,6 @@
     )
     layers = []
     for i in range(num_layers):
-
         mask_mod = None
         if skip_rope_interval is not None and (i + 1) % skip_rope_interval != 0:
             mask_mod = partial(
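On the new capability check: grouped mm needs Hopper-class hardware, so the builder downgrades `use_grouped_mm` when SM90 isn't available. A small sketch of the same gating using only public PyTorch calls; `has_cuda_capability` is presumably torchtune's thin wrapper around this check, and `resolve_use_grouped_mm` is a hypothetical helper name:

import logging
import torch

logger = logging.getLogger(__name__)


def resolve_use_grouped_mm(use_grouped_mm: bool) -> bool:
    # SM90 means compute capability (9, 0), i.e. H100-class GPUs.
    has_sm90 = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (9, 0)
    )
    if use_grouped_mm and not has_sm90:
        logger.warning("Grouped mm requires SM90 or later; falling back to the per-expert loop.")
        return False
    return use_grouped_mm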

torchtune/modules/attention_utils.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def compile_flex_attention():
     # when compiled. To insulate it from the compiler, we wrap it with
     # compiler.disable so that it can be used regardless of whether the model
     # is compiled or not, and flex attention always remains compiled.
-    @torch.compiler.disable(recursive=False)
+    # @torch.compiler.disable(recursive=False)
     def compile_friendly_flex_attention(
         q: torch.Tensor,
         k: torch.Tensor,
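Commenting out `@torch.compiler.disable(recursive=False)` changes how this wrapper interacts with an outer `torch.compile`: with the decorator, dynamo skips the wrapper's own frame (a graph break) while code it calls can still be compiled; without it, the wrapper is traced into the surrounding graph. A toy sketch of the decorator's effect, unrelated to flex attention itself:

import torch


@torch.compiler.disable(recursive=False)
def skipped_wrapper(x: torch.Tensor) -> torch.Tensor:
    # dynamo will not trace this frame, but functions it calls
    # may still be compiled because recursive=False.
    return x.sin() + 1


@torch.compile
def outer(x: torch.Tensor) -> torch.Tensor:
    # Calling skipped_wrapper from compiled code introduces a graph break here.
    return skipped_wrapper(x) * 2


print(outer(torch.randn(4)))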

torchtune/modules/moe/experts.py

Lines changed: 60 additions & 20 deletions
@@ -13,6 +13,11 @@
 from torchtune.modules.peft import AdapterModule
 
 
+@torch._dynamo.allow_in_graph
+def _grouped_mm(x, w, offs):
+    return torch._grouped_mm(x, w, offs=offs)
+
+
 class GroupedExperts(nn.Module):
     """This class implements the grouped experts layer used in Mixture of Experts. Each expert
     is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202.
@@ -31,6 +36,7 @@ def __init__(
         hidden_dim: int,
         num_experts: int = 1,
         activation: Callable = F.silu,
+        use_grouped_mm: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -39,6 +45,8 @@
         self.down_proj = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
         self.up_proj = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
         self.act_fn = activation
+        self.rank = torch.distributed.get_rank()
+        self.use_grouped_mm = use_grouped_mm
 
     def reset_parameters(self) -> None:
         # Default initialization used by torch.nn.Linear
@@ -50,6 +58,7 @@ def reset_parameters(self) -> None:
     # TODO: force no inference mode as a hack to get around
     # "Cannot set version_counter for inference tensor"
     @torch.inference_mode(mode=False)
+    @torch._dynamo.disable(recursive=False)
     def forward(
         self,
         x: torch.Tensor,
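The new `_grouped_mm` wrapper is decorated with `torch._dynamo.allow_in_graph`, which tells dynamo to record the call as a single node in the captured graph instead of tracing (or breaking on) its internals; this is a common way to expose an opaque or private op like `torch._grouped_mm` to `torch.compile`. A generic sketch of the mechanism with a stand-in function, since the real wrapper targets the private grouped-mm op:

import torch


@torch._dynamo.allow_in_graph
def opaque_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Treated as a single node in the dynamo graph; its body is not traced.
    return x * factor


@torch.compile
def fn(x: torch.Tensor) -> torch.Tensor:
    return opaque_scale(x, 2.0) + 1


print(fn(torch.randn(4)))

The `@torch._dynamo.disable(recursive=False)` added on `forward`, by contrast, keeps dynamo out of the forward frame itself while still letting callees such as `_grouped_mm` participate in compilation.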
@@ -64,28 +73,59 @@ def forward(
         Returns:
             torch.Tensor: tensor with shape ``(bsz * seq_len * experts_per_token, dim)``
         """
-
-        # a tuple of tensors indexed by experts
-        # each with shape (tokens_per_expert(varying), dim)
-        x = torch.split(
-            x,
-            split_size_or_sections=num_tokens_per_expert.tolist(),
-            dim=0,
-        )
-        out_experts_splits = []
-        for expert_idx, x_expert in enumerate(x):
-            w1, w2, w3 = (
-                self.gate_proj[expert_idx],
-                self.down_proj[expert_idx],
-                self.up_proj[expert_idx],
+        if not self.use_grouped_mm:
+            # a tuple of tensors indexed by experts
+            # each with shape (tokens_per_expert(varying), dim)
+            num_tokens_per_expert_list = num_tokens_per_expert.tolist()
+            if torch.compiler.is_compiling():
+                for n in num_tokens_per_expert_list:
+                    torch._check_is_size(n)
+            x = torch.split(
+                x,
+                split_size_or_sections=num_tokens_per_expert_list,
+                dim=0,
             )
-            h = self.act_fn(torch.matmul(x_expert, w1))
-            h = h * torch.matmul(x_expert, w3)
-            h = torch.matmul(h, w2)
-            # h shape (tokens_per_expert(varying), dim)
-            out_experts_splits.append(h)
-        out = torch.cat(out_experts_splits, dim=0)
+            out_experts_splits = []
+            for expert_idx, x_expert in enumerate(x):
+                w1, w2, w3 = (
+                    self.gate_proj[expert_idx],
+                    self.down_proj[expert_idx],
+                    self.up_proj[expert_idx],
+                )
+                h = self.act_fn(torch.matmul(x_expert, w1))
+                h = h * torch.matmul(x_expert, w3)
+                h = torch.matmul(h, w2)
+                # h shape (tokens_per_expert(varying), dim)
+                out_experts_splits.append(h)
+            out = torch.cat(out_experts_splits, dim=0)
+
+            return out
 
+        # grouped mm implementation
+        if num_tokens_per_expert is not None:
+            # https://github.com/pytorch/pytorch/pull/150374
+            # NOTE: torch._gouped_mm requires bf16 dtypes
+            # and shapes to be multiple of 8
+            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+            # grouped mm between a 2D tensor and a 3D tensor
+            assert x.dim() == 2
+        else:
+            offsets = None
+            # fall back to regular bmm between 3D tensors
+            assert x.dim() == 3
+
+        w1, w2, w3 = (
+            self.gate_proj,
+            self.down_proj,
+            self.up_proj,
+        )
+        assert (
+            x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16
+        ), "torch._grouped_mm only supports bf16 dtypes"
+        h = F.silu(_grouped_mm(x, w1, offs=offsets))
+        h = h * _grouped_mm(x, w3, offs=offsets)
+        out = _grouped_mm(h, w2, offs=offsets)
+        out[offsets[-1] :].zero_()
         return out
 
 
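On the grouped-mm path added above: `offsets` is the inclusive cumulative sum of tokens per expert, so expert `i` owns rows `offsets[i-1]:offsets[i]` of `x`, and rows past `offsets[-1]` are padding that the new code zeroes out. A small reference sketch of what those offsets mean, using an explicit loop in place of `torch._grouped_mm`; it is illustrative only and ignores the bf16/alignment constraints of the real kernel:

import torch


def grouped_mm_reference(x: torch.Tensor, w: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    """x: (total_tokens, dim), w: (num_experts, dim, hidden), offs: (num_experts,) cumulative counts."""
    out = x.new_zeros(x.shape[0], w.shape[-1])
    start = 0
    for expert_idx in range(w.shape[0]):
        end = int(offs[expert_idx])
        # rows [start, end) belong to expert `expert_idx`
        out[start:end] = x[start:end] @ w[expert_idx]
        start = end
    return out  # rows past offs[-1] stay zero, matching out[offsets[-1]:].zero_()


num_tokens_per_expert = torch.tensor([3, 0, 5])
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
x = torch.randn(10, 16)            # 8 routed tokens plus 2 rows of padding
w = torch.randn(3, 16, 32)
out = grouped_mm_reference(x, w, offsets)
print(out.shape, out[8:].abs().sum())  # padding rows remain zero

The earlier `torch._check_is_size` calls on the non-grouped path serve a different purpose: under torch.compile with dynamic shapes, they assert that the per-expert token counts coming out of `.tolist()` are valid sizes so `torch.split` can be traced without guard failures.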