
Commit 1217c06

added no_sync block
1 parent ab8860f commit 1217c06

1 file changed: +11 -4 lines changed


notes/ddp_template.py

Lines changed: 11 additions & 4 deletions
@@ -14,6 +14,8 @@
 local_rank = -1
 global_rank = -1
 num_epochs = 100
+step_number = 0
+last_step = False
 
 class MyModel:
     pass
@@ -40,10 +42,15 @@ def train():
 
     for epoch in range(num_epochs):
         for data, labels in data_loader:
-            loss = loss_fn(model(data), labels)  # Forward step
-            loss.backward()  # Backward step + gradient synchronization
-            optimizer.step()  # Update weights
-            optimizer.zero_grad()  # Reset gradients to zero
+            if (step_number + 1) % 100 != 0 and not last_step:  # Accumulate gradients for 100 steps
+                with model.no_sync():  # Disable gradient synchronization
+                    loss = loss_fn(model(data), labels)  # Forward step
+                    loss.backward()  # Backward step + gradient ACCUMULATION
+            else:
+                loss = loss_fn(model(data), labels)  # Forward step
+                loss.backward()  # Backward step + gradient SYNCHRONIZATION
+                optimizer.step()  # Update weights
+                optimizer.zero_grad()  # Reset gradients to zero
 
         if global_rank == 0:
             collect_statistics()  # W&B, etc.
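
For context, here is a self-contained sketch of the same no_sync() gradient-accumulation pattern the diff introduces, runnable with torchrun. The placeholder model, synthetic data, backend choice, and script name are illustrative assumptions and are not part of notes/ddp_template.py; only the accumulate-then-synchronize logic mirrors the commit.

# Sketch only: model, data, and process-group setup are illustrative assumptions.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train_sketch(accumulation_steps=100, num_epochs=1):
    global_rank = int(os.environ.get("RANK", 0))
    dist.init_process_group(backend="gloo")  # use "nccl" on GPU nodes

    model = DDP(torch.nn.Linear(10, 2))  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()
    # Synthetic stand-in for a real DataLoader
    data_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(250)]

    step_number = 0
    for epoch in range(num_epochs):
        for i, (data, labels) in enumerate(data_loader):
            last_step = i == len(data_loader) - 1
            if (step_number + 1) % accumulation_steps != 0 and not last_step:
                # Accumulate locally; no_sync() skips the inter-rank all-reduce.
                with model.no_sync():
                    loss = loss_fn(model(data), labels)
                    loss.backward()
            else:
                # Every accumulation_steps-th step (and the final step), backward()
                # synchronizes gradients across ranks before the optimizer update.
                loss = loss_fn(model(data), labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            step_number += 1

        if global_rank == 0:
            print(f"epoch {epoch} finished")  # stand-in for collect_statistics()

    dist.destroy_process_group()

if __name__ == "__main__":
    # Launch with, e.g.: torchrun --nproc_per_node=2 this_script.py
    train_sketch()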
