Commit a488dd4

WIP-DEBUG-PROFILE torch.compile
ghstack-source-id: 60b1db8
Pull Request resolved: #2644
1 parent 7ee3c7a commit a488dd4

6 files changed: +355 -26 lines

recipes/configs/llama4/scout_17B_16E_full.yaml

Lines changed: 15 additions & 4 deletions

@@ -18,7 +18,7 @@ output_dir: /tmp/torchtune/llama4_17Bx16E/full
 model:
   _component_: torchtune.models.llama4.llama4_scout_17b_16e
 
-tensor_parallel_dim: 2 # For multi-node training we recommend tensor_parallel_dim: 8
+tensor_parallel_dim: 1 # For multi-node training we recommend tensor_parallel_dim: 8
 tensor_parallel_plan:
   _component_: torchtune.models.llama4.decoder_only_tp_plan
 data_parallel_shard_dim: -1 # Will infer based on TP dim, effectively controls FSDP
@@ -73,10 +73,10 @@ fsdp_cpu_offload: True
 # compile False means no torch.compile
 # compile Dictionary with keys: "model", "loss", "optimizer_step"
 # enables torch.compile only for specified components.
-compile: False
+compile: True
 #   model: True
 #   loss: True
-#   optimizer_step: False
+#   optimizer_step: True
 #   scale_grads: True
 
 # Reduced precision
@@ -92,4 +92,15 @@ log_peak_memory_stats: True
 # Useful for understanding how to optimize memory and performance
 profiler:
   _component_: torchtune.training.setup_torch_profiler
-  enabled: False
+  enabled: True
+  output_dir: ${output_dir}/profiling_outputs
+  cpu: True
+  cuda: True
+  profile_memory: True
+  with_stack: True
+  record_shapes: True
+  with_flops: False
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 1
+  num_cycles: 1
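
Note: the profiler keys above mirror the options of the public torch.profiler API, with wait_steps/warmup_steps/active_steps/num_cycles forming the profiling schedule. The sketch below is a rough, hedged illustration of that mapping only; it is not torchtune's setup_torch_profiler, and train_step is a hypothetical placeholder for one training iteration.

# Illustration only (assumption: not the actual recipe wiring).
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler


def train_step() -> None:
    # stand-in for forward/backward/optimizer work
    torch.randn(1024, 1024) @ torch.randn(1024, 1024)


prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # cpu: True, cuda: True
    schedule=schedule(wait=5, warmup=3, active=1, repeat=1),    # wait/warmup/active_steps, num_cycles
    on_trace_ready=tensorboard_trace_handler("./profiling_outputs"),  # ~ ${output_dir}/profiling_outputs
    profile_memory=True,
    with_stack=True,
    record_shapes=True,
    with_flops=False,
)

with prof:
    for _ in range(12):  # enough steps to cover wait + warmup + active
        train_step()
        prof.step()      # advance the profiler schedule once per training step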

torchtune/models/llama4/_model_builders.py

Lines changed: 2 additions & 2 deletions

@@ -78,7 +78,7 @@ def llama4_scout_17b_16e(
         norm_eps=1e-5,
         num_experts=16,
         use_shared_expert=True,
-        skip_rope_interval=4,
+        skip_rope_interval=None,
         attention_chunk_size=8192,
         use_scaled_rope=True,
     )
@@ -149,7 +149,7 @@ def llama4_maverick_17b_128e(
         use_qk_norm=False,
         moe_every_n_layers=2,
         mlp_hidden_dim=16384,
-        skip_rope_interval=4,
+        skip_rope_interval=None,
         attention_chunk_size=8192,
     )
     return EarlyFusionModel(

torchtune/modules/moe/experts.py

Lines changed: 50 additions & 19 deletions

@@ -50,6 +50,7 @@ def reset_parameters(self) -> None:
     # TODO: force no inference mode as a hack to get around
     # "Cannot set version_counter for inference tensor"
     @torch.inference_mode(mode=False)
+    @torch._dynamo.disable(recursive=False)
     def forward(
         self,
         x: torch.Tensor,
@@ -64,27 +65,57 @@ def forward(
         Returns:
             torch.Tensor: tensor with shape ``(bsz * seq_len * experts_per_token, dim)``
         """
+        self.use_grouped_mm = True
+        if not self.use_grouped_mm:
+            # a tuple of tensors indexed by experts
+            # each with shape (tokens_per_expert(varying), dim)
+            x = torch.split(
+                x,
+                split_size_or_sections=num_tokens_per_expert.tolist(),
+                dim=0,
+            )
+            out_experts_splits = []
+            for expert_idx, x_expert in enumerate(x):
+                w1, w2, w3 = (
+                    self.gate_proj[expert_idx],
+                    self.down_proj[expert_idx],
+                    self.up_proj[expert_idx],
+                )
+                h = self.act_fn(torch.matmul(x_expert, w1))
+                h = h * torch.matmul(x_expert, w3)
+                h = torch.matmul(h, w2)
+                # h shape (tokens_per_expert(varying), dim)
+                out_experts_splits.append(h)
+            out = torch.cat(out_experts_splits, dim=0)
 
-        # a tuple of tensors indexed by experts
-        # each with shape (tokens_per_expert(varying), dim)
-        x = torch.split(
-            x,
-            split_size_or_sections=num_tokens_per_expert.tolist(),
-            dim=0,
-        )
-        out_experts_splits = []
-        for expert_idx, x_expert in enumerate(x):
-            w1, w2, w3 = (
-                self.gate_proj[expert_idx],
-                self.down_proj[expert_idx],
-                self.up_proj[expert_idx],
+            return out
+
+        # grouped mm implementation
+        if num_tokens_per_expert is not None:
+            # https://github.com/pytorch/pytorch/pull/150374
+            # NOTE: torch._grouped_mm requires bf16 dtypes
+            # and shapes to be multiple of 8
+            offsets = torch.cumsum(
+                num_tokens_per_expert, dim=0, dtype=torch.int32
             )
-            h = self.act_fn(torch.matmul(x_expert, w1))
-            h = h * torch.matmul(x_expert, w3)
-            h = torch.matmul(h, w2)
-            # h shape (tokens_per_expert(varying), dim)
-            out_experts_splits.append(h)
-        out = torch.cat(out_experts_splits, dim=0)
+            # grouped mm between a 2D tensor and a 3D tensor
+            assert x.dim() == 2
+        else:
+            offsets = None
+            # fall back to regular bmm between 3D tensors
+            assert x.dim() == 3
+
+        w1, w2, w3 = (
+            self.gate_proj,
+            self.down_proj,
+            self.up_proj,
+        )
+        assert (
+            x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16
+        ), "torch._grouped_mm only supports bf16 dtypes"
+        h = F.silu(torch._grouped_mm(x, w1, offs=offsets))
+        h = h * torch._grouped_mm(x, w3, offs=offsets)
+        out = torch._grouped_mm(h, w2, offs=offsets)
 
         return out
 
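
Note: torch._grouped_mm is a private PyTorch op (see the linked PR), and the @torch._dynamo.disable(recursive=False) decorator keeps torch.compile from tracing this forward frame while still allowing compilation of the functions it calls. The sketch below is a plain-PyTorch reference (an illustration, not torchtune code and not the fused kernel) for what the 2D grouped path above computes: each contiguous row segment of x, delimited by the cumulative offsets, is multiplied by its own expert's weight matrix.

# Reference semantics sketch; runnable on any device/dtype.
import torch


def grouped_mm_reference(
    x: torch.Tensor,        # (total_tokens, dim), rows already sorted by expert
    w: torch.Tensor,        # (num_experts, dim, hidden_dim)
    offsets: torch.Tensor,  # (num_experts,) cumulative row counts per expert
) -> torch.Tensor:
    out = x.new_empty(x.shape[0], w.shape[-1])
    start = 0
    for expert_idx in range(w.shape[0]):
        end = int(offsets[expert_idx])
        # rows [start, end) belong to this expert
        out[start:end] = x[start:end] @ w[expert_idx]
        start = end
    return out


num_tokens_per_expert = torch.tensor([3, 0, 5, 2])
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
x = torch.randn(int(num_tokens_per_expert.sum()), 16)
w = torch.randn(4, 16, 32)
print(grouped_mm_reference(x, w, offsets).shape)  # torch.Size([10, 32])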

torchtune/modules/moe/indices.py

Lines changed: 254 additions & 0 deletions (new file)

@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import triton
import triton.language as tl


__all__ = ["generate_permute_indices"]


# parallelized kernel
@triton.jit
def _fill_indices_kernel(
    tokens_per_expert_group_ptr,
    start_index_values_ptr,
    write_offsets_ptr,
    output_ptr,
    experts_per_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # Number of threads per block
):
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)

    # map programs (blocks) to the experts and loop (grid stride) if needed
    for expert_id in range(pid, experts_per_rank, num_programs):
        # read this expert's write offset
        write_offset = tl.load(write_offsets_ptr + expert_id)

        # loop over all ranks
        for r in range(num_ranks):
            # index into tokens_per_expert_group array
            i = r * experts_per_rank + expert_id

            # load start index and number of tokens for this expert-rank pair
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)

            # each thread in block processes tokens in parallel
            offsets = tl.arange(0, BLOCK_SIZE)

            # tokens are processed in chunks of BLOCK_SIZE
            for chunk_start in range(0, length, BLOCK_SIZE):
                chunk_offsets = chunk_start + offsets

                # mask valid indices
                mask = chunk_offsets < length

                values = start_index + chunk_offsets

                # destination
                dest_indices = write_offset + chunk_offsets

                # store
                tl.store(output_ptr + dest_indices, values, mask=mask)

            # update write offset for next rank
            write_offset += length


# ==============
# wrapper
# ==============


def fill_indices_wrapper(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,  # cap on total number of blocks to launch
):
    # preallocate output
    permuted_indices = torch.full(
        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
    )

    # write offsets is per local expert...
    num_blocks = min(experts_per_rank, max_blocks)
    # grid = one block per expert unless capped and then we loop...
    grid = (num_blocks,)

    # launch kernel
    _fill_indices_kernel[grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        permuted_indices,
        experts_per_rank,
        num_ranks,
        BLOCK_SIZE=block_size,
    )
    return permuted_indices


# reference
def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    # We need to preallocate the output - we ignore device and force it on cpu
    # device = tokens_per_expert_group.device
    permuted_indices = torch.full(
        (max_len,),
        -1,
        dtype=torch.int32,
    )  # device=device)
    # Fill the permuted indices
    # For each local expert
    for e in range(experts_per_rank):
        write_start = write_offsets[e].item()
        # For each remote rank
        for r in range(num_ranks):
            i = r * experts_per_rank + e
            start_index = start_index_values[i].item()
            length = tokens_per_expert_group[i].item()
            # Fill in the indices
            if length > 0:
                end_idx = min(write_start + length, max_len)
                permuted_indices[write_start:end_idx] = torch.arange(
                    start_index,
                    start_index + (end_idx - write_start),
                    dtype=torch.int32,
                    # device=device,
                )
            write_start += length
    return permuted_indices


def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    """
    Prepare permutation indices and the number of tokens for each expert.

    Args:
        tokens_per_expert_group: number of tokens for each expert from all ranks.
        experts_per_rank: number of experts per rank.
        num_ranks: number of ranks.
        max_len: maximum length of the output index vector.
        alignment: alignment for each returned element in `m_sizes`.
        use_cpu: whether to use the CPU implementation.

    Returns:
        permuted_indices: tensor of indices that map original token order to the expert-grouped order.
        m_sizes: aligned number of tokens for each expert (padded to the alignment boundary).
        m_offsets: cumulative sum of m_sizes; the exclusive ending position for each expert's tokens.

    Explanatory details:
        `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example:
        From: | rank 0            | rank 1            |
        To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
              |  4 |  2 |  1 |  3 |  1 |  2 |  3 |  4 |
    """
    # prefix sum to get start index of each expert (parallel scan kernel in future?)
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )

    # chunk sizes for each expert
    chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)

    # align the chunk sizes (cdiv)
    m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )

    # additional prefix sum to get write offset of each expert in permuted_indices
    # write offsets is per local expert, not global
    m_offsets = torch.cumsum(m_sizes, 0)
    write_offsets = m_offsets - m_sizes

    # Select the implementation to use
    if use_cpu:
        permuted_indices = fill_indices_cpu(
            tokens_per_expert_group,
            start_index_values,
            write_offsets,
            experts_per_rank,
            num_ranks,
            max_len,
        )
    else:
        permuted_indices = fill_indices_wrapper(
            tokens_per_expert_group,
            start_index_values,
            write_offsets,
            experts_per_rank,
            num_ranks,
            max_len,
        )

    return permuted_indices, m_sizes, m_offsets.to(torch.int32)


# Below is for testing only


def simple_test():
    device = torch.device("cuda", 0)
    experts_per_rank = 4
    num_ranks = 4
    tokens_per_expert_group = torch.full(
        (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
    )
    max_len = 128
    alignment = 32
    # Use the GPU kernel
    permuted_indices_gpu, m_sizes, _ = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
    )
    # Use the CPU method
    permuted_indices_cpu, m_sizes, _ = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank,
        num_ranks,
        max_len,
        alignment,
        use_cpu=True,
    )
    # Check that the results are the same
    assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
    assert torch.equal(
        torch.remainder(m_sizes, alignment),
        torch.zeros(experts_per_rank, device=device),
    )
    # Print the results
    print(f"{permuted_indices_gpu=}, \n{permuted_indices_cpu=}")
    print(f"{m_sizes=}")
    print("Success")
    return True  # assert would have failed, meaning getting here is success


if __name__ == "__main__":
    simple_test()
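
Note: simple_test() above exercises the Triton kernel against the CPU reference. As a CPU-only usage sketch that follows the docstring's example layout (2 ranks, 4 local experts), and assuming the module path added in this commit, generate_permute_indices could be called roughly like this; the printed values in the comments follow from the alignment math in the function.

# Usage sketch (assumption: illustrative values, not from the commit's tests).
import torch

from torchtune.modules.moe.indices import generate_permute_indices

# rank 0 sends [4, 2, 1, 3] tokens to experts E0..E3, rank 1 sends [1, 2, 3, 4]
tokens_per_expert_group = torch.tensor([4, 2, 1, 3, 1, 2, 3, 4], dtype=torch.int32)
permuted_indices, m_sizes, m_offsets = generate_permute_indices(
    tokens_per_expert_group,
    experts_per_rank=4,
    num_ranks=2,
    max_len=64,   # upper bound on total (padded) token count
    alignment=8,  # pad each expert's row count to a multiple of 8
    use_cpu=True, # keeps the example runnable without a GPU kernel launch
)
# Expert 0 receives 4 tokens from rank 0 and 1 from rank 1 -> 5 tokens, padded to 8.
print(m_sizes)               # -> [8, 8, 8, 8]
print(m_offsets)             # -> [8, 16, 24, 32]
print(permuted_indices[:8])  # indices of expert 0's tokens, -1 where padded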
