
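You can try:

```python
from pytorch_lightning import Trainer

factor = 2  # amount by which accumulate_grad_batches grows every N epochs (example value)
N = 3  # increase accumulate_grad_batches every N epochs (example value)
max_epochs = 10  # example value
init_acc_grad_batches = 1  # accumulate_grad_batches for the first N epochs (example value)
# Keys are the (zero-indexed) epochs where the value changes; here {0: 1, 3: 3, 6: 5, 9: 7}
acc_grad_batches = {ep: init_acc_grad_batches + i * factor for i, ep in enumerate(range(0, max_epochs, N))}
trainer = Trainer(accumulate_grad_batches=acc_grad_batches, max_epochs=max_epochs)  # plus your other Trainer arguments
```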
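To check which schedule the dict comprehension produces without starting a training run, here is a plain-Python sketch. `build_schedule` is a hypothetical helper name used only for illustration; it is not part of Lightning, and it assumes the keys are zero-indexed epochs.

```python
def build_schedule(init_acc_grad_batches: int, factor: int, n: int, max_epochs: int) -> dict[int, int]:
    """Return {start_epoch: accumulate_grad_batches}, growing by `factor` every `n` epochs.

    Hypothetical helper mirroring the dict comprehension; keys are zero-indexed epochs.
    """
    return {
        ep: init_acc_grad_batches + i * factor
        for i, ep in enumerate(range(0, max_epochs, n))
    }


schedule = build_schedule(init_acc_grad_batches=1, factor=2, n=3, max_epochs=10)
print(schedule)  # {0: 1, 3: 3, 6: 5, 9: 7}
```

Printing the dict before passing it to `Trainer` is a cheap way to confirm that the epochs and values are the ones you intended.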

`accumulate_grad_batches` can also take a dict: each key is the epoch at which the value changes, and each value is the `accumulate_grad_batches` to use from that epoch onward.

So if it's {0: 1, 3: 2, 5: 4}, then (counting epochs from 1):
epoch [1-4) -> 1
epoch [4-6) -> 2
epoch [6-max_epochs] -> 4
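A minimal sketch of how such a schedule can be resolved for a given epoch, assuming the active value is the one for the largest key not exceeding the current zero-indexed epoch. `accumulation_for_epoch` is a hypothetical helper written to illustrate the semantics, not Lightning's own implementation.

```python
from bisect import bisect_right


def accumulation_for_epoch(schedule: dict[int, int], epoch: int) -> int:
    """Return the accumulate_grad_batches value in effect at a zero-indexed `epoch`.

    Picks the value of the largest key <= epoch (hypothetical helper for illustration).
    """
    keys = sorted(schedule)
    idx = bisect_right(keys, epoch) - 1  # index of the last key <= epoch
    if idx < 0:
        raise ValueError(f"no schedule entry covers epoch {epoch}")
    return schedule[keys[idx]]


schedule = {0: 1, 3: 2, 5: 4}
for epoch in range(8):
    # Print 1-indexed epoch numbers to match the ranges in the text
    print(f"epoch {epoch + 1}: {accumulation_for_epoch(schedule, epoch)}")
```

With `{0: 1, 3: 2, 5: 4}`, this prints 1 for epochs 1-3, 2 for epochs 4-5, and 4 from epoch 6 onward.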

It looks like the epoch keys here are zero-indexed (key 0 is the first epoch), which might be a small bug.

Answer selected by tchaton