
Commit ab8860f

added template to transform any project into ddp

1 parent ad1d4e6

1 file changed: +66, -0 lines

notes/ddp_template.py

## How to convert any PyTorch project into a DistributedDataParallel project

# Distributed training
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import init_process_group, destroy_process_group

import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split

# Dummy variables to make Pylance happy :D
train_dataset = None
local_rank = -1
global_rank = -1
num_epochs = 100

class MyModel(torch.nn.Module): # Placeholder: a real model must be an nn.Module
    pass

def initialize_services():
    pass

def collect_statistics():
    pass

def train():
    if global_rank == 0:
        initialize_services() # W&B, etc.

    # Shuffle in the DistributedSampler (which splits the data across ranks), not in the DataLoader
    data_loader = DataLoader(train_dataset, shuffle=False, sampler=DistributedSampler(train_dataset, shuffle=True))
    model = MyModel().to(local_rank) # Move the model to this rank's GPU before wrapping it
    if os.path.exists('latest_checkpoint.pth'): # Load latest checkpoint
        # Also load optimizer state and other variables needed to restore the training state
        model.load_state_dict(torch.load('latest_checkpoint.pth', map_location=f'cuda:{local_rank}'))

    model = DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=10e-4, eps=1e-9)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        data_loader.sampler.set_epoch(epoch) # Reshuffle differently at every epoch
        for data, labels in data_loader:
            data, labels = data.to(local_rank), labels.to(local_rank) # Move the batch to this rank's GPU
            loss = loss_fn(model(data), labels) # Forward step
            loss.backward() # Backward step + gradient synchronization
            optimizer.step() # Update weights
            optimizer.zero_grad() # Reset gradients to zero

            if global_rank == 0:
                collect_statistics() # W&B, etc.

        if global_rank == 0: # Only save on rank 0
            # Also save the optimizer state and other variables needed to restore the training state
            torch.save(model.module.state_dict(), 'latest_checkpoint.pth') # .module unwraps the DDP wrapper


if __name__ == '__main__':
    local_rank = int(os.environ['LOCAL_RANK'])
    global_rank = int(os.environ['RANK'])

    init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank) # Set the device to the local rank

    train()

    destroy_process_group()
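The template reads LOCAL_RANK and RANK from the environment, so it is meant to be launched with torchrun, which spawns one process per GPU and sets those variables. A minimal launch sketch (the GPU counts, node addresses, and port below are assumptions, not part of the commit):

    # Single node, 4 GPUs (assumed count)
    torchrun --standalone --nproc_per_node=4 notes/ddp_template.py

    # Two nodes, 4 GPUs each (assumed rendezvous address/port; use --node_rank=1 on the second machine)
    torchrun --nnodes=2 --node_rank=0 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=10.0.0.1:29500 notes/ddp_template.py

The comments in the template mention saving and restoring the full training state, not just the model weights. One common way to do that, sketched here with illustrative key names that are not part of the commit, is to bundle everything into a single dict:

    # Hedged sketch: the keys 'model', 'optimizer', 'epoch' are assumptions
    checkpoint = {
        'model': model.module.state_dict(), # unwrap DDP before saving
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, 'latest_checkpoint.pth')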
