1 file changed: +11 −4 lines changed
@@ -14,6 +14,8 @@
 local_rank = -1
 global_rank = -1
 num_epochs = 100
+step_number = 0
+last_step = False
 
 class MyModel:
     pass
@@ -40,10 +42,15 @@ def train():
 
     for epoch in range(num_epochs):
        for data, labels in data_loader:
-            loss = loss_fn(model(data), labels)  # Forward step
-            loss.backward()  # Backward step + gradient synchronization
-            optimizer.step()  # Update weights
-            optimizer.zero_grad()  # Reset gradients to zero
+            if (step_number + 1) % 100 != 0 and not last_step:  # Accumulate gradients for 100 steps
+                with model.no_sync():  # Disable gradient synchronization
+                    loss = loss_fn(model(data), labels)  # Forward step
+                    loss.backward()  # Backward step + gradient ACCUMULATION
+            else:
+                loss = loss_fn(model(data), labels)  # Forward step
+                loss.backward()  # Backward step + gradient SYNCHRONIZATION
+                optimizer.step()  # Update weights
+                optimizer.zero_grad()  # Reset gradients to zero
 
         if global_rank == 0:
             collect_statistics()  # W&B, etc.
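Below is a minimal, self-contained sketch of what the new loop does, assuming `model` is wrapped in torch.nn.parallel.DistributedDataParallel (whose real no_sync() context manager is what suppresses the gradient all-reduce) and that `data_loader`, `loss_fn`, and `optimizer` exist as elsewhere in the file. The diff introduces `step_number` and `last_step` but this hunk does not show where they are updated, so deriving them from enumerate() and len(data_loader) here is an assumption, not the author's code.

def train_one_epoch(model, data_loader, loss_fn, optimizer, accumulation_steps=100):
    # model is assumed to be a torch.nn.parallel.DistributedDataParallel instance.
    num_batches = len(data_loader)
    for step_number, (data, labels) in enumerate(data_loader):
        last_step = (step_number == num_batches - 1)  # assumption: sync on the final batch
        if (step_number + 1) % accumulation_steps != 0 and not last_step:
            # Local accumulation only: no_sync() skips the inter-rank all-reduce.
            with model.no_sync():
                loss = loss_fn(model(data), labels)   # Forward step
                loss.backward()                       # Backward step + gradient ACCUMULATION
        else:
            loss = loss_fn(model(data), labels)       # Forward step
            loss.backward()                           # Backward step + gradient SYNCHRONIZATION
            optimizer.step()                          # Update weights
            optimizer.zero_grad()                     # Reset gradients to zero

In practice the loss is often also divided by the number of accumulation steps so the synchronized gradient is an average over the accumulated mini-batches; the diff as written keeps the raw sum.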