
Commit 4309419

1M+ context length (context parallel integration) (#2668)
1 parent 86f148b commit 4309419

5 files changed: +215 -32 lines

recipes/full_finetune_distributed.py

Lines changed: 24 additions & 11 deletions
@@ -158,6 +158,7 @@ def __init__(self, cfg: DictConfig) -> None:
             raise ValueError(
                 "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
             )
+        self.cp_degree = cfg.get("context_parallel_dim", 1)
         data_shard = cfg.get("data_parallel_shard_dim", -1)  # -1 means to infer
         data_replicate = cfg.get("data_parallel_replicate_dim", 1)

@@ -166,6 +167,7 @@ def __init__(self, cfg: DictConfig) -> None:
             dp_replicate=data_replicate,
             dp_shard=data_shard,
             tp=self.tp_degree,
+            cp=self.cp_degree,
             world_size=self.world_size,
         )
         self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type)
@@ -603,6 +605,10 @@ def _setup_model(
                     "FP8 training does not support tensor parallelism yet. "
                     "This will be enabled in the near future."
                 )
+            if self.cp_degree > 1:
+                raise ValueError(
+                    "Context Parallel for fp8 training is not currently supported"
+                )
             model = convert_to_float8_training(model, self._fp8_recipe_name)

         # Apply tensor parallelism to the model
@@ -665,6 +671,13 @@ def _setup_model(
                 dp_mesh=self.world_mesh[dp_mesh_dim_names],
             )

+        # Define context manager for context parallelism
+        self.context_parallel_manager = training.get_context_parallel_manager(
+            enabled=self.cp_degree > 1,
+            world_mesh=self.world_mesh,
+            model=model,
+        )
+
         with training.set_default_dtype(self._dtype), self._device:
             for m in model.modules():
                 # RoPE is not covered in state dict
@@ -797,7 +810,7 @@ def _setup_data(
                 collate_fn,
                 padding_idx=self._tokenizer.pad_id,
                 ignore_idx=self._loss_fn.ignore_index,
-                pad_to_multiple_of=self.tp_degree,
+                pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor,
             )
             if not packed
             else padded_collate_packed
@@ -920,17 +933,17 @@ def train(self) -> None:

                 # Loss is normalized by default so we multiply by the number of tokens
                 # This way we can normalize by the total number of tokens if we're accumulating gradients
-                current_loss = self._loss_step(batch) * current_num_tokens
-                running_loss += current_loss
-
-                # For optimizer in backward, we need to normalize before calling backward
-                # This case and gradient accumulation are mutually exclusive
-                if self._optimizer_in_bwd:
-                    torch.distributed.all_reduce(num_tokens)
-                    torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss * (self.dp_degree / num_tokens)
+                with self.context_parallel_manager(list(batch.values())):
+                    current_loss = self._loss_step(batch) * current_num_tokens
+                    running_loss += current_loss
+                    # For optimizer in backward, we need to normalize before calling backward
+                    # This case and gradient accumulation are mutually exclusive
+                    if self._optimizer_in_bwd:
+                        torch.distributed.all_reduce(num_tokens)
+                        torch.distributed.all_reduce(running_loss)
+                        current_loss = current_loss * (self.dp_degree / num_tokens)
+                    current_loss.backward()

-                current_loss.backward()
                 # Optimizer step (if not fused in backward call)
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     if not self._optimizer_in_bwd:
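
The recipe-side integration boils down to three steps: read a new `context_parallel_dim` config knob (defaulting to 1, i.e. disabled), pass it to `ParallelDims` as `cp` so the device mesh gains a context-parallel axis, and run the loss and backward of every step inside the manager returned by `training.get_context_parallel_manager`. The sketch below condenses those call sites into two standalone helpers; it is illustrative only and not part of the commit, and it assumes the caller already has the recipe's `world_mesh`, `model`, batch, and loss-step function in hand.

import torch
from omegaconf import DictConfig
from torchtune import training


def build_cp_manager(cfg: DictConfig, world_mesh, model: torch.nn.Module):
    # Mirrors the _setup_model change: the manager is a no-op when
    # context_parallel_dim is left at its default of 1.
    cp_degree = cfg.get("context_parallel_dim", 1)
    return training.get_context_parallel_manager(
        enabled=cp_degree > 1,
        world_mesh=world_mesh,  # mesh built by ParallelDims with cp=cp_degree
        model=model,
    )


def cp_loss_step(context_parallel_manager, loss_step, batch, current_num_tokens):
    # Mirrors the train() change: every tensor in the batch is handed to the
    # manager so it can be sharded along the sequence dimension, and the
    # forward, loss, and backward all run inside the context.
    with context_parallel_manager(list(batch.values())):
        current_loss = loss_step(batch) * current_num_tokens
        current_loss.backward()
    return current_loss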

recipes/qat_distributed.py

Lines changed: 21 additions & 10 deletions
@@ -168,6 +168,7 @@ def __init__(self, cfg: DictConfig) -> None:
             raise ValueError(
                 "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
             )
+        self.cp_degree = cfg.get("context_parallel_dim", 1)
         data_shard = cfg.get("data_parallel_shard_dim", -1)  # -1 means to infer
         data_replicate = cfg.get("data_parallel_replicate_dim", 1)

@@ -176,6 +177,7 @@ def __init__(self, cfg: DictConfig) -> None:
             dp_replicate=data_replicate,
             dp_shard=data_shard,
             tp=self.tp_degree,
+            cp=self.cp_degree,
             world_size=self.world_size,
         )
         self.world_mesh = self.parallel_dims.build_mesh(device_type=device_type)
@@ -670,6 +672,13 @@ def _setup_model(
                 dp_mesh=self.world_mesh[dp_mesh_dim_names],
             )

+        # Define context manager for context parallelism
+        self.context_parallel_manager = training.get_context_parallel_manager(
+            enabled=self.cp_degree > 1,
+            world_mesh=self.world_mesh,
+            model=model,
+        )
+
         with training.set_default_dtype(self._dtype), self._device:
             for m in model.modules():
                 # RoPE is not covered in state dict
@@ -802,7 +811,7 @@ def _setup_data(
                 collate_fn,
                 padding_idx=self._tokenizer.pad_id,
                 ignore_idx=self._loss_fn.ignore_index,
-                pad_to_multiple_of=self.tp_degree,
+                pad_to_multiple_of=self.parallel_dims.min_seq_len_divisor,
             )
             if not packed
             else padded_collate_packed
@@ -925,17 +934,19 @@ def train(self) -> None:

                 # Loss is normalized by default so we multiply by the number of tokens
                 # This way we can normalize by the total number of tokens if we're accumulating gradients
-                current_loss = self._loss_step(batch) * current_num_tokens
-                running_loss += current_loss
+                with self.context_parallel_manager(list(batch.values())):
+                    current_loss = self._loss_step(batch) * current_num_tokens
+                    running_loss += current_loss
+
+                    # For optimizer in backward, we need to normalize before calling backward
+                    # This case and gradient accumulation are mutually exclusive
+                    if self._optimizer_in_bwd:
+                        torch.distributed.all_reduce(num_tokens)
+                        torch.distributed.all_reduce(running_loss)
+                        current_loss = current_loss * (self.dp_degree / num_tokens)

-                # For optimizer in backward, we need to normalize before calling backward
-                # This case and gradient accumulation are mutually exclusive
-                if self._optimizer_in_bwd:
-                    torch.distributed.all_reduce(num_tokens)
-                    torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss * (self.dp_degree / num_tokens)
+                    current_loss.backward()

-                current_loss.backward()
                 # Optimizer step (if not fused in backward call)
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     if not self._optimizer_in_bwd:
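
The QAT recipe picks up exactly the same wiring as the full-finetune recipe, so enabling the feature is purely a config change. The snippet below is a hypothetical illustration of how the parallelism degrees are expected to relate to the world size; the `tensor_parallel_dim` key and the product-equals-world-size check are assumptions about the surrounding recipe and `ParallelDims` validation, not something shown in this diff.

import math

from omegaconf import OmegaConf

# Hypothetical parallelism section of a recipe config; only context_parallel_dim
# and the two data_parallel_* keys appear in this diff, tensor_parallel_dim is assumed.
cfg = OmegaConf.create(
    {
        "data_parallel_replicate_dim": 1,
        "data_parallel_shard_dim": 4,  # -1 would mean "infer from world size"
        "tensor_parallel_dim": 2,
        "context_parallel_dim": 4,
    }
)

world_size = 32  # e.g. 4 nodes x 8 GPUs
degrees = [
    cfg.data_parallel_replicate_dim,
    cfg.data_parallel_shard_dim,
    cfg.tensor_parallel_dim,
    cfg.context_parallel_dim,
]
# Assumption: ParallelDims validates that the degrees factor the world size.
assert math.prod(degrees) == world_size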

torchtune/models/llama3_2/_model_builders.py

Lines changed: 12 additions & 3 deletions
@@ -15,12 +15,14 @@
 the llama3_2_1b model builder uses the llama3_2 component builder to create the
 Llama3.2 1B model.
 """
+
+
 def llama3_2_1b(
     tie_word_embeddings: bool = True,
 ) -> TransformerDecoder:
     """
     Builder for creating a Llama3.2 model initialized w/ the default 1b parameter values.
-
+
     Args:
         tie_word_embeddings (bool): whether the model's input and output word embeddings should be tied.

@@ -41,6 +43,8 @@ def llama3_2_1b(
         scale_factor=32,
         tie_word_embeddings=tie_word_embeddings,
     )
+
+
 def llama3_2_3b(
     tie_word_embeddings: bool = True,
 ) -> TransformerDecoder:
@@ -67,6 +71,8 @@ def llama3_2_3b(
         scale_factor=32,
         tie_word_embeddings=tie_word_embeddings,
     )
+
+
 def lora_llama3_2_1b(
     lora_attn_modules: list[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
@@ -83,7 +89,7 @@ def lora_llama3_2_1b(
     The Llama3.2 defaults are the same as in :func:`~torchtune.models.llama3_2.llama3_2_1b`,
     while LoRA default params are based on
     https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.
-
+
     Args:
         lora_attn_modules (list[LORA_ATTN_MODULES]): list of which linear layers
             LoRA should be applied to in each self-attention block. Options are
@@ -125,6 +131,8 @@ def lora_llama3_2_1b(
         quantize_base=quantize_base,
         tie_word_embeddings=tie_word_embeddings,
     )
+
+
 def lora_llama3_2_3b(
     lora_attn_modules: list[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
@@ -161,7 +169,6 @@ def lora_llama3_2_3b(
     Returns:
         TransformerDecoder: Instantiation of Llama3.2 3B model with LoRA applied
     """
-
     return lora_llama3_2(
         lora_attn_modules=lora_attn_modules,
         apply_lora_to_mlp=apply_lora_to_mlp,
@@ -184,6 +191,8 @@ def lora_llama3_2_3b(
         quantize_base=quantize_base,
         tie_word_embeddings=tie_word_embeddings,
     )
+
+
 qlora_llama3_2_1b = partial(lora_llama3_2_1b, quantize_base=True)
 qlora_llama3_2_1b.__doc__ = """
 Builder for creating a Llama3.2 1B model with QLoRA enabled. Base model weights in linear layers

torchtune/training/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from torchtune.training._compile import compile_loss, compile_model
 from torchtune.training._distributed import (
     gather_cpu_state_dict,
+    get_context_parallel_manager,
     get_distributed_backend,
     get_full_optimizer_state_dict,
     get_shard_conditions,
@@ -145,4 +146,5 @@
     "get_distributed_backend",
     "disable_dropout",
     "DATALOADER_KEY",
+    "get_context_parallel_manager",
 ]
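
`get_context_parallel_manager` is now re-exported from `torchtune.training`. Its implementation lives in `torchtune/training/_distributed.py`, one of the five changed files that is not shown in this excerpt, so the sketch below is only a guess at the general shape such a helper can take on top of PyTorch's experimental `torch.distributed.tensor.experimental.context_parallel` API; everything beyond the keyword arguments visible at the recipe call sites is an assumption.

# Hypothetical sketch, NOT the commit's implementation.
import contextlib

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental import context_parallel


def get_context_parallel_manager(
    *, enabled: bool, world_mesh: DeviceMesh, model: torch.nn.Module
):
    # `model` is presumably used by the real helper (e.g. to include model-owned
    # buffers such as RoPE caches); it is ignored in this toy version.
    def manager(buffers: list[torch.Tensor]):
        if not enabled:
            return contextlib.nullcontext()
        # Shard each buffer along its sequence dimension across the "cp" sub-mesh;
        # dim 1 as the sequence dimension for all batch tensors is an assumption.
        return context_parallel(
            world_mesh["cp"],
            buffers=buffers,
            buffer_seq_dims=[1] * len(buffers),
            no_restore_buffers=set(buffers),
        )

    return manager

A helper shaped like this is what lets the recipes write `with self.context_parallel_manager(list(batch.values())):` and have it be a cheap no-op when `context_parallel_dim` is 1 and a sequence-sharding context otherwise.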
