
Commit 168042e (1 parent: 0991f97)

WIP-DEBUG-PROFILE torch.compile

ghstack-source-id: b8d6007
Pull Request resolved: #2644

File tree: 11 files changed (+507, -38 lines)


recipes/configs/llama4/scout_17B_16E_full.yaml

Lines changed: 20 additions & 7 deletions

@@ -18,7 +18,7 @@ output_dir: /tmp/torchtune/llama4_17Bx16E/full
 model:
   _component_: torchtune.models.llama4.llama4_scout_17b_16e

-tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8
+tensor_parallel_dim: 1 # For multi-node training we recommend tensor_parallel_dim: 8
 tensor_parallel_plan:
   _component_: torchtune.models.llama4.decoder_only_tp_plan
 data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
@@ -73,11 +73,11 @@ fsdp_cpu_offload: True
 # compile False means no torch.compile
 # compile Dictionary with keys: "model", "loss", "optimizer_step"
 # enables torch.compile only for specified components.
-compile: False
-# model: True
-# loss: True
-# optimizer_step: False
-# scale_grads: True
+compile: True
+# model: True
+# loss: True
+# optimizer_step: True
+# scale_grads: True

 # Reduced precision
 dtype: bf16
@@ -93,4 +93,17 @@ log_level: INFO # DEBUG, WARN, etc.
 # Useful for understanding how to optimize memory and performance
 profiler:
   _component_: torchtune.training.setup_torch_profiler
-  enabled: False
+  enabled: True
+  output_dir: ${output_dir}/profiling_outputs
+  cpu: True
+  cuda: True
+  profile_memory: True
+  with_stack: True
+  record_shapes: True
+  with_flops: False
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 1
+  num_cycles: 1
+
+# enable_fp8_training: True
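
For reference, the profiler block enabled above corresponds roughly to the following raw torch.profiler setup. This is a minimal sketch assuming a generic training loop; the actual wiring in torchtune lives in torchtune.training.setup_torch_profiler, and the output path shown is just the interpolated ${output_dir} value.

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# cpu/cuda -> activities; wait/warmup/active_steps/num_cycles -> schedule
prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=5, warmup=3, active=1, repeat=1),
    profile_memory=True,
    with_stack=True,
    record_shapes=True,
    with_flops=False,
    on_trace_ready=tensorboard_trace_handler(
        "/tmp/torchtune/llama4_17Bx16E/full/profiling_outputs"
    ),
)

with prof:
    for _ in range(9):   # wait + warmup + active steps of one cycle
        # forward / backward / optimizer step would run here
        prof.step()      # advance the profiler schedule once per training step

With this schedule the profiler idles for 5 steps, warms up for 3, and records a single active step per cycle, which keeps trace size manageable even with memory and stack recording turned on.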

recipes/full_finetune_distributed.py

Lines changed: 32 additions & 4 deletions

@@ -16,6 +16,7 @@
 from omegaconf import DictConfig, ListConfig

 from torch import nn
+import torch.distributed as dist
 from torch.distributed import destroy_process_group, init_process_group
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
@@ -147,7 +148,10 @@ def __init__(self, cfg: DictConfig) -> None:
             offload_ops_to_cpu=self.fsdp_cpu_offload
             or self._enable_async_checkpointing,
         )
-        init_process_group(self.distributed_backend)
+        group_name = "torchtune-finetune"
+        init_process_group(self.distributed_backend, group_name=group_name)
+        pg = dist.distributed_c10d._get_default_group()
+        torch._C._distributed_c10d._register_process_group(group_name, pg)

         # Initialize distributed variables
         self.world_size, self.rank = utils.get_world_size_and_rank()
@@ -328,6 +332,8 @@ def setup(self, cfg: DictConfig) -> None:
         compile = cfg.get("compile")
         compile_bool = bool(compile)
         self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
+        self._compile_mode = None  # "max-autotune-no-cudagraphs"
+        torch._inductor.config.cpp_wrapper = True

         self._compile_model = compile_bool
         self._compile_loss = compile_bool
@@ -343,7 +349,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._grad_scaler = training.scale_grads_
         if self._compile_scale_grads:
             self._grad_scaler = torch.compile(
-                self._grad_scaler, backend=self._compile_backend
+                self._grad_scaler, backend=self._compile_backend, mode=self._compile_mode
             )

         self._model = self._setup_model(
@@ -380,6 +386,7 @@ def setup(self, cfg: DictConfig) -> None:
             self._optimizer.step = torch.compile(
                 self._optimizer.step,
                 backend=self._compile_backend,
+                mode=self._compile_mode
             )

         if self._resume_from_checkpoint:
@@ -413,7 +420,7 @@ def setup(self, cfg: DictConfig) -> None:
             self._loss_fn.set_model_output(self._model)

         if self._compile_loss:
-            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+            training.compile_loss(self._loss_fn, mode=self._compile_mode, verbose=self._is_rank_zero)

         utils.log_rank_zero(self._logger, "Loss is initialized.")

@@ -586,7 +593,7 @@ def _setup_model(
             model = config.instantiate(cfg_model)

         if self._compile_model:
-            training.compile_model(model, verbose=self._is_rank_zero)
+            training.compile_model(model, mode=self._compile_mode, verbose=self._is_rank_zero)

         if self._enable_fp8_training:
             # Requires https://github.com/pytorch/pytorch/pull/148922
@@ -1068,6 +1075,26 @@ def cleanup(self) -> None:
         self._metric_logger.close()
         destroy_process_group()

+# from torch.utils._python_dispatch import TorchDispatchMode
+# import torch.utils._pytree as pytree
+# from torch._higher_order_ops.flex_attention import flex_attention
+#
+#
+#
+# class Mode(TorchDispatchMode):
+#     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+#         r = torch.distributed.get_rank()
+#         print(f"XXX RANK[{r}] MODE._torch_dispatch_ {func} {types}")
+#         for a in pytree.tree_leaves(args):
+#             if issubclass(type(a), torch.Tensor):
+#                 print(f"XXX RANK[{r}] {a.dtype} {a.shape}")
+#             else:
+#                 print(f"XXX RANK[{r}] {a}")
+#         return func(*args, **kwargs)
+#
+# def flex_attention_mode_call(mode, *args, **kwargs):
+#     return flex_attention(*args, **kwargs)
+# flex_attention.py_impl(Mode)(flex_attention_mode_call)

 @config.parse
 def recipe_main(cfg: DictConfig) -> None:
@@ -1081,6 +1108,7 @@ def recipe_main(cfg: DictConfig) -> None:
     config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)
     recipe = FullFinetuneRecipeDistributed(cfg=cfg)
     recipe.setup(cfg=cfg)
+    # with Mode():
     recipe.train()
     recipe.cleanup()
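
Two threads run through this file: every torch.compile call now also receives mode=self._compile_mode (left at None here, i.e. torch.compile's default mode, with "max-autotune-no-cudagraphs" noted as the alternative), and the default process group is registered under an explicit name. The snippet below is a standalone sketch of that registration pattern, not torchtune code; the single-process gloo world and the address are purely illustrative, and the two underscore-prefixed calls are the same private APIs used in the diff above.

import torch
import torch.distributed as dist

group_name = "torchtune-finetune"

# Illustrative single-process "world"; a real run would use the launcher's
# rendezvous (env://) and the NCCL backend on GPUs.
dist.init_process_group(
    "gloo",
    init_method="tcp://127.0.0.1:29501",
    rank=0,
    world_size=1,
    group_name=group_name,
)

# Fetch the default group and register it under the chosen name, so that
# code which looks process groups up by name (e.g. traced collectives)
# can resolve it later.
pg = dist.distributed_c10d._get_default_group()
torch._C._distributed_c10d._register_process_group(group_name, pg)

dist.destroy_process_group()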

torchtune/modules/attention_utils.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def compile_flex_attention():
     # when compiled. To insulate it from the compiler, we wrap it with
     # compiler.disable so that it can be used regardless of whether the model
     # is compiled or not, and flex attention always remains compiled.
-    @torch.compiler.disable(recursive=False)
+    # @torch.compiler.disable(recursive=False)
     def compile_friendly_flex_attention(
         q: torch.Tensor,
         k: torch.Tensor,
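
Commenting out the decorator lets Dynamo trace through compile_friendly_flex_attention instead of treating it as a graph break. For context, here is a toy sketch (not torchtune code) of what torch.compiler.disable(recursive=False) does when it is present:

import torch

def inner(x):
    return torch.sin(x) + 1.0

@torch.compiler.disable(recursive=False)
def outer(x):
    # This frame is skipped by Dynamo (a graph break), but `inner` can
    # still be compiled, because recursive=False does not propagate the
    # disable into callees.
    return inner(x) * 2.0

compiled = torch.compile(outer)
print(compiled(torch.randn(4)))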

torchtune/modules/moe/experts.py

Lines changed: 55 additions & 19 deletions

@@ -11,8 +11,13 @@
 from torch import nn
 from torch.nn import functional as F
 from torchtune.modules.peft import AdapterModule
+from torchtune.modules.moe.moe import USE_GROUPED_MM


+@torch._dynamo.allow_in_graph
+def _grouped_mm(x, w, offs):
+    return torch._grouped_mm(x, w, offs=offs)
+
 class GroupedExperts(nn.Module):
     """This class implements the grouped experts layer used in Mixture of Experts. Each expert
     is a variant of the Gated Linear Units network. See more details in https://arxiv.org/pdf/2002.05202.
@@ -50,6 +55,7 @@ def reset_parameters(self) -> None:
     # TODO: force no inference mode as a hack to get around
     # "Cannot set version_counter for inference tensor"
     @torch.inference_mode(mode=False)
+    # @torch._dynamo.disable(recursive=False)
     def forward(
         self,
         x: torch.Tensor,
@@ -64,27 +70,57 @@ def forward(
         Returns:
             torch.Tensor: tensor with shape ``(bsz * seq_len * experts_per_token, dim)``
         """
+        self.use_grouped_mm = USE_GROUPED_MM
+        if not self.use_grouped_mm:
+            # a tuple of tensors indexed by experts
+            # each with shape (tokens_per_expert(varying), dim)
+            x = torch.split(
+                x,
+                split_size_or_sections=num_tokens_per_expert.tolist(),
+                dim=0,
+            )
+            out_experts_splits = []
+            for expert_idx, x_expert in enumerate(x):
+                w1, w2, w3 = (
+                    self.gate_proj[expert_idx],
+                    self.down_proj[expert_idx],
+                    self.up_proj[expert_idx],
+                )
+                h = self.act_fn(torch.matmul(x_expert, w1))
+                h = h * torch.matmul(x_expert, w3)
+                h = torch.matmul(h, w2)
+                # h shape (tokens_per_expert(varying), dim)
+                out_experts_splits.append(h)
+            out = torch.cat(out_experts_splits, dim=0)

-        # a tuple of tensors indexed by experts
-        # each with shape (tokens_per_expert(varying), dim)
-        x = torch.split(
-            x,
-            split_size_or_sections=num_tokens_per_expert.tolist(),
-            dim=0,
-        )
-        out_experts_splits = []
-        for expert_idx, x_expert in enumerate(x):
-            w1, w2, w3 = (
-                self.gate_proj[expert_idx],
-                self.down_proj[expert_idx],
-                self.up_proj[expert_idx],
+            return out
+
+        # grouped mm implementation
+        if num_tokens_per_expert is not None:
+            # https://github.com/pytorch/pytorch/pull/150374
+            # NOTE: torch._grouped_mm requires bf16 dtypes
+            # and shapes to be multiple of 8
+            offsets = torch.cumsum(
+                num_tokens_per_expert, dim=0, dtype=torch.int32
             )
-            h = self.act_fn(torch.matmul(x_expert, w1))
-            h = h * torch.matmul(x_expert, w3)
-            h = torch.matmul(h, w2)
-            # h shape (tokens_per_expert(varying), dim)
-            out_experts_splits.append(h)
-        out = torch.cat(out_experts_splits, dim=0)
+            # grouped mm between a 2D tensor and a 3D tensor
+            assert x.dim() == 2
+        else:
+            offsets = None
+            # fall back to regular bmm between 3D tensors
+            assert x.dim() == 3
+
+        w1, w2, w3 = (
+            self.gate_proj,
+            self.down_proj,
+            self.up_proj,
+        )
+        assert (
+            x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16
+        ), "torch._grouped_mm only supports bf16 dtypes"
+        h = F.silu(_grouped_mm(x, w1, offs=offsets))
+        h = h * _grouped_mm(x, w3, offs=offsets)
+        out = _grouped_mm(h, w2, offs=offsets)

         return out
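
The @torch._dynamo.allow_in_graph wrapper tells Dynamo to record the _grouped_mm call as a single graph node rather than tracing into it. As for the contract behind torch._grouped_mm with an offs argument: offs holds the cumulative end index of each expert's token slice, and the weight tensor stacks one matrix per expert. Below is a loop-based reference of that contract, for illustration only (torch._grouped_mm itself additionally requires bf16 inputs and supported CUDA hardware, as the asserts above note):

import torch

def grouped_mm_reference(x, w, offs):
    # x:    (total_tokens, dim)        2D activations, all experts concatenated
    # w:    (num_experts, dim, hidden) one weight matrix per expert
    # offs: (num_experts,)             cumulative token counts, int32
    outs, start = [], 0
    for expert_idx, end in enumerate(offs.tolist()):
        outs.append(x[start:end] @ w[expert_idx])
        start = end
    return torch.cat(outs, dim=0)

num_tokens_per_expert = torch.tensor([3, 1, 4])   # illustrative routing counts
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
x = torch.randn(8, 16)                            # 3 + 1 + 4 tokens, dim=16
w = torch.randn(3, 16, 32)                        # 3 experts, hidden=32
print(grouped_mm_reference(x, w, offsets).shape)  # torch.Size([8, 32])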
