140 changes: 140 additions & 0 deletions examples/fsdp2/fsdp2_tp.py
@@ -0,0 +1,140 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example of training with Tensor Parallelism (TP) composed with FSDP2 via Accelerate.
This example demonstrates how to use Accelerate's TorchTensorParallelPlugin, optionally combined with FSDP2, for efficient training of large models.
"""

import argparse
import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin, TorchTensorParallelPlugin, set_seed
from utils import PerformanceTracker, create_collate_fn, get_dataset, setup_tokenizer


MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--apply-fsdp", action="store_true")
parser.add_argument("--tp-size", type=int, default=2)
parser.add_argument("--sequence-length", type=int, default=4096)
return parser.parse_args()


def main():
"""
Main function to train the model.
"""
args = parse_args()

set_seed(42)

model_kwargs = {}
accelerator_kwargs = {}

if args.apply_fsdp:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
cpu_ram_efficient_loading=False,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
)
accelerator_kwargs["fsdp_plugin"] = fsdp2_plugin

    if args.tp_size > 1 and not args.apply_fsdp:
        # The process group is not initialized until `Accelerator` is created, so read the world
        # size from the environment set by the launcher (e.g. `accelerate launch` / torchrun)
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        if args.tp_size != world_size:
            raise ValueError(
                f"TP size {args.tp_size} does not match world size {world_size}. Either set TP size to {world_size} or apply FSDP2."
            )

if args.tp_size > 1:
accelerator_kwargs["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=args.tp_size)

accelerator = Accelerator(
log_with=["wandb"],
mixed_precision="bf16",
**accelerator_kwargs,
)
accelerator.init_trackers(
project_name="fsdp2-tp",
config={"apply_fsdp": args.apply_fsdp, "tp_size": args.tp_size},
)

model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
use_cache=False,
device_map="auto" if args.tp_size > 1 else None,
device_mesh=accelerator.torch_device_mesh if args.tp_size > 1 else None,
**model_kwargs,
)

tokenizer = setup_tokenizer(MODEL_ID)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model, optimizer = accelerator.prepare(model, optimizer)

dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)

model.train()

total_num_steps = min(1000, len(dataloader))
performance_tracker = PerformanceTracker(warmup_steps=10)

for step, batch in enumerate(dataloader):
if step >= total_num_steps:
break

outputs = model(**batch)
loss = outputs.loss

accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()

dist.all_reduce(loss, op=dist.ReduceOp.AVG)

batch_tokens = batch["input_ids"].shape[1]
metrics = performance_tracker.step(batch_tokens)

print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
log_metrics = {"loss": loss.item()}

if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
elif metrics:
print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}"

if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)

accelerator.log(log_metrics)

accelerator.wait_for_everyone()
accelerator.end_training()
accelerator.print("Training completed!")


if __name__ == "__main__":
main()
173 changes: 173 additions & 0 deletions examples/fsdp2/utils.py
@@ -0,0 +1,173 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Common utilities for FSDP2 examples.
"""

import time

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import Accelerator


def get_dataset(accelerator: Accelerator, tokenizer: AutoTokenizer, seq_len: int) -> Dataset:
"""
Load and prepare TinyStories dataset.

Args:
accelerator (Accelerator): Accelerate accelerator instance
tokenizer (AutoTokenizer): Hugging Face tokenizer
seq_len (int): Sequence length for the dataset

Returns:
Dataset: Packed dataset
"""
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")

def tokenize_function(examples):
tokenized_batch = tokenizer(
examples["text"],
padding=False,
truncation=True,
max_length=seq_len,
return_tensors=None,
)
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch

with accelerator.main_process_first():
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    def create_packed_sequences(examples):
        # Concatenate all tokenized stories into one long token stream
        all_tokens = []
        for input_ids in examples["input_ids"]:
            all_tokens.extend(input_ids)

        # Each packed example consumes seq_len + 1 tokens so that inputs and labels can be
        # produced by shifting the same window by one position
        num_sequences = len(all_tokens) // (seq_len + 1)
        packed_input_ids = []
        packed_labels = []

        for i in range(num_sequences):
            start_idx = i * (seq_len + 1)
            end_idx = start_idx + (seq_len + 1)
            full_sequence = all_tokens[start_idx:end_idx]
            packed_input_ids.append(full_sequence[:-1])
            packed_labels.append(full_sequence[1:])

        return {"input_ids": packed_input_ids, "labels": packed_labels}

with accelerator.main_process_first():
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=1000,
)

return packed_dataset.shuffle(seed=42)


def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
"""
Get the number of flops per token for the model.

Args:
model (AutoModelForCausalLM): Model to get the flops for
seq_len (int): Sequence length
"""
cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads

# MLP: 3 matmuls
mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size

# Attn (w/o dotproduct)
attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)

# attn (dotproduct) - this scales quadratically with sequence length
attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * seq_len

# we also ignore embeddings and layernorms, etc
return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers
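
# A minimal usage sketch (an assumption, not part of the original utilities): combining the value
# above with `tokens_per_second` from PerformanceTracker gives an MFU estimate, where `peak_flops`
# is your accelerator's advertised dense bf16 throughput (the H100 SXM figure below is approximate):
#
#   flops_per_token = get_model_flops_per_token(model, seq_len)
#   mfu = flops_per_token * tokens_per_second / peak_flops  # e.g. peak_flops ~ 989e12 for H100 SXM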


def create_collate_fn():
"""Create a collate function for batching."""

def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}

return collate_fn


class PerformanceTracker:
"""Track training performance metrics."""

def __init__(self, warmup_steps: int = 10):
self.warmup_steps = warmup_steps
self.reset()

def reset(self):
"""Reset all tracking variables."""
self.start_time = None
self.num_tokens = 0
self.is_in_warmup = True
self.step_count = 0

def step(self, batch_tokens: int) -> dict:
"""
Update performance tracking with a new step.

Args:
batch_tokens (int): Number of tokens in current batch

Returns:
dict: Performance metrics if past warmup, empty dict otherwise
"""
self.step_count += 1

if self.step_count == self.warmup_steps:
self.start_time = time.perf_counter()
self.num_tokens = 0
self.is_in_warmup = False
return {"warmup_completed": True}

if not self.is_in_warmup and self.start_time is not None:
self.num_tokens += batch_tokens
total_time = time.perf_counter() - self.start_time
steps_from_warmup = self.step_count - self.warmup_steps

if total_time > 0 and steps_from_warmup > 0:
return {
"tokens_per_second": self.num_tokens / total_time,
"steps_per_second": steps_from_warmup / total_time,
"total_tokens": self.num_tokens,
"total_time": total_time,
}

return {}


def setup_tokenizer(model_id: str) -> AutoTokenizer:
"""Setup tokenizer with proper padding token."""
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer