Commit 7b7c708

FSDP example (#1019)
* Adding FSDP example
* Adding SLURM cluster setup instructions
* Adding setup_model func
* Added missing features
* summarization_dataset
* Updates training and removes unnecessary imports
* Updating the wrapping policy
* Added Zero2 sharding
* Updates from testing on clean machine
* Updates from clean machine, add requirements.txt
* Updates from clean machine
* Added SentencePiece
* Removed activation checkpointing and added check for bf16
* Clean up
* Removing cluster setup
* Fix progress bars, update readme
* Update progress bars, readme
* Correct ordering for curr_val_loss evaluation and model save
* Clean up the dataset links
* Fixing the dataset links
* Updates from clean machine
* Reverting latest unnecessary changes
* Moving to a new folder
* Adding FSDP to dist folder
* Updates to address comments
* Adding utils and configs to make the code modular
* Clean up

---------

Co-authored-by: lessw2020 <[email protected]>
1 parent 79ef786 commit 7b7c708

18 files changed: +949 −0 lines

distributed/FSDP/.gitignore

Lines changed: 3 additions & 0 deletions
```
__pycache__/
*.pt
*.csv
```

distributed/FSDP/README.md

Lines changed: 24 additions & 0 deletions
## FSDP T5

To run the T5 example with FSDP for text summarization:

## Get the WikiHow dataset

```bash
sh download_dataset.sh
```

## Install the requirements

~~~
pip install -r requirements.txt
~~~

## Ensure you are running a recent version of PyTorch

See https://pytorch.org to install at least 1.12, and ideally a current nightly build.

Start the training with torchrun (adjust `--nproc_per_node` to your GPU count):

```
torchrun --nnodes 1 --nproc_per_node 4 T5_training.py
```
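For reference (not part of this commit), a quick way to confirm that the installed PyTorch meets the ≥ 1.12 requirement and that CUDA devices are visible before launching the example:

```python
# Quick environment check; not part of the example itself.
import torch

print(torch.__version__)           # should report 1.12+ (or a nightly build)
print(torch.cuda.is_available())   # FSDP training in this example assumes CUDA GPUs
print(torch.cuda.device_count())   # informs the --nproc_per_node value for torchrun
```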

distributed/FSDP/T5_training.py

Lines changed: 215 additions & 0 deletions
```python
import os
import argparse
import functools
import time
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Type

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
import tqdm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)

from summarization_dataset import *
import policies
import model_checkpointing
from configs import fsdp_config, train_config
from utils import (bfloat_support, setup,
                   cleanup, get_date_of_run,
                   format_metrics_to_gb,
                   train, validation, setup_model)


def get_policies(cfg, rank):
    """Establish current policies for mixed precision and FSDP wrapping."""

    mixed_precision_policy = None
    wrapping_policy = None

    # Mixed precision
    if cfg.mixed_precision:
        bfloat_available = bfloat_support()
        if bfloat_available and not cfg.use_fp16:
            mixed_precision_policy = policies.bfSixteen
            if rank == 0:
                print("bFloat16 enabled for mixed precision - using bfSixteen policy")
        elif cfg.use_fp16:
            mixed_precision_policy = policies.fpSixteen
            if rank == 0:
                print("FP16 enabled.")
        else:
            print("bFloat16 support not present. Will use FP32, and not mixed precision")

    # Wrapping policy: wrap each T5Block as its own FSDP unit
    wrapping_policy = policies.get_t5_wrapper()

    return mixed_precision_policy, wrapping_policy


def fsdp_main(args):

    model, tokenizer = setup_model(train_config.model_name)

    # torchrun exports these environment variables for every worker
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of validation dataset: ", dataset['validation'].shape)

    # wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                   'pin_memory': True,
                   'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    torch.cuda.set_device(local_rank)

    # Set up FSDP parameters
    mixed_precision_policy, t5_auto_wrap_policy = get_policies(train_config, rank)

    # Apply FSDP wrapping to the model
    model = FSDP(model,
                 auto_wrap_policy=t5_auto_wrap_policy,
                 mixed_precision=mixed_precision_policy,
                 sharding_strategy=fsdp_config.sharding_strategy,
                 device_id=torch.cuda.current_device(),
                 limit_all_gathers=fsdp_config.limit_all_gathers)

    if fsdp_config.fsdp_activation_checkpointing:
        policies.apply_fsdp_checkpointing(model)

    # Set up optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )

        # Save a checkpoint whenever validation loss improves
        if train_config.save_model and curr_val_loss < best_val_loss:

            if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                model_checkpointing.save_model_checkpoint(
                    model, optimizer, rank, fsdp_config, epoch=1
                )
            elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                model_checkpointing.save_model_and_optimizer_sharded(model, rank, fsdp_config)
                if fsdp_config.save_optimizer:
                    model_checkpointing.save_model_and_optimizer_sharded(model, rank, fsdp_config, optim=optimizer)

            if fsdp_config.save_optimizer:
                model_checkpointing.save_optimizer_checkpoint(
                    model, optimizer, rank, fsdp_config, epoch=1
                )

        if curr_val_loss < best_val_loss:
            best_val_loss = curr_val_loss
            if rank == 0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()


if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 4)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the GPU memory (pass this flag to disable)')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='run validation (pass this flag to disable)')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)
```
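The `setup()`, `cleanup()`, and `setup_model()` helpers are imported from `utils`, which is part of this commit but not reproduced in this excerpt. A minimal sketch of what a torchrun-compatible version of those helpers typically looks like (hypothetical, for orientation only, not the actual `utils.py`):

```python
# Hypothetical sketch only -- the actual implementations live in utils.py.
import torch.distributed as dist
from transformers import T5ForConditionalGeneration, T5Tokenizer


def setup():
    # torchrun already exports RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT,
    # so the default env:// init method needs no explicit arguments.
    dist.init_process_group("nccl")


def cleanup():
    dist.destroy_process_group()


def setup_model(model_name):
    # Load the pretrained T5 model and its matching tokenizer by name, e.g. "t5-base".
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer
```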

distributed/FSDP/configs/__init__.py

Lines changed: 2 additions & 0 deletions
```python
from .fsdp import fsdp_config
from .training import train_config
```

distributed/FSDP/configs/fsdp.py

Lines changed: 19 additions & 0 deletions
```python
from dataclasses import dataclass, field
from typing import ClassVar
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


@dataclass
class fsdp_config:
    mixed_precision: bool = True
    use_fp16: bool = False
    seed: int = 42
    fsdp_activation_checkpointing: bool = True
    limit_all_gathers: bool = True
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD  # alternatives: HYBRID_SHARD, SHARD_GRAD_OP
    checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT  # alternatively use SHARDED_STATE_DICT to avoid OOMs
    save_optimizer: bool = False
```
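`T5_training.py` reads these fields straight off the class (e.g. `fsdp_config.sharding_strategy`), so the simplest way to change behavior is to edit the defaults above. A minimal sketch, assuming you instead want to instantiate the dataclass with an override to try the Zero2-style strategy mentioned in the comment:

```python
# Sketch only: SHARD_GRAD_OP shards gradients and optimizer state (Zero2-style)
# while keeping full parameters on every rank.
from torch.distributed.fsdp import ShardingStrategy
from configs import fsdp_config

cfg = fsdp_config(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
print(cfg.sharding_strategy, cfg.checkpoint_type)
```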

distributed/FSDP/configs/training.py

Lines changed: 19 additions & 0 deletions
```python
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class train_config:
    model_name: str = "t5-base"
    run_validation: bool = True
    batch_size_training: int = 4
    num_workers_dataloader: int = 2
    lr: float = 0.002
    weight_decay: float = 0.0
    gamma: float = 0.85
    use_fp16: bool = False
    mixed_precision: bool = True
    save_model: bool = False
```

distributed/FSDP/download_dataset.sh

Lines changed: 8 additions & 0 deletions
```bash
#!/bin/bash

# Create the "data" folder if it doesn't exist
mkdir -p data

# Download the files into the "data" folder
wget -P data https://public-nlp-datasets.s3.us-west-2.amazonaws.com/wikihowAll.csv
wget -P data https://public-nlp-datasets.s3.us-west-2.amazonaws.com/wikihowSep.csv
```
distributed/FSDP/model_checkpointing/__init__.py

Lines changed: 10 additions & 0 deletions

```python
from .checkpoint_handler import (
    load_model_checkpoint,
    save_model_checkpoint,
    save_distributed_model_checkpoint,
    load_distributed_model_checkpoint,
    load_optimizer_checkpoint,
    save_optimizer_checkpoint,
    save_model_and_optimizer_sharded,
    load_model_sharded,
)
```
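`checkpoint_handler.py` itself is not shown in this excerpt. As a rough sketch of what a FULL_STATE_DICT save typically looks like with FSDP, matching the call site in `T5_training.py` but not necessarily the exact implementation used here:

```python
# Rough sketch, not the actual handler code: gather a full, CPU-resident
# state dict on rank 0 and write it to disk.
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)


def save_model_checkpoint(model, optimizer, rank, cfg, epoch=1):
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state = model.state_dict()  # collective call; every rank must enter
    if rank == 0:
        torch.save(cpu_state, f"model_checkpoint_epoch_{epoch}.pt")
```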
