
Bug: Incorrect gradient accumulation steps calculation in multi-node training due to missing world size information #1927

@pratyushmaini

Bug description

There's a bug in the gradient accumulation calculation that affects multi-node training scenarios. The current implementation uses local device count instead of world size (total devices across all nodes), leading to incorrect gradient accumulation steps and training behavior.

Current Behavior
The gradient_accumulation_iters function derives the number of accumulation steps from a per-rank batch size that is computed using only the local device count:

def gradient_accumulation_iters(self, devices: int) -> int:
    """Number of iterations between gradient synchronizations"""
    gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size
    return gradient_accumulation_iters

def batch_size(self, devices: int) -> int:
    """Number of samples between optimizer steps per data-parallel rank"""
    batch_size = self.global_batch_size // devices  # devices is local count only
    return batch_size

The devices parameter comes from torch.cuda.device_count(), which only returns local GPU count (e.g., 8) rather than total GPUs across all nodes (e.g., 8 * num_nodes = 128).
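
For illustration, the arithmetic below contrasts the two divisors. All numbers (node count, batch sizes) are made up for the example and are not taken from the actual run:

# Standalone sketch with illustrative numbers: how gradient accumulation
# diverges when the divisor is the local GPU count instead of the total
# number of data-parallel processes.
num_nodes = 16                         # assumed multi-node setup
local_gpus = 8                         # what torch.cuda.device_count() returns per node
world_size = num_nodes * local_gpus    # 128 processes in total

global_batch_size = 512                # illustrative
micro_batch_size = 4                   # illustrative

# Current behavior: divide by the local device count only.
accum_local = (global_batch_size // local_gpus) // micro_batch_size    # 16

# Expected behavior: divide by the total world size.
accum_world = (global_batch_size // world_size) // micro_batch_size    # 1

print(accum_local, accum_world)  # 16 vs. 1 -> off by a factor of num_nodes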

Impact
This causes:

  1. Incorrect gradient accumulation frequency (off by a factor of num_nodes)
  2. Mismatch between steps and iterations in training logs
  3. Example from logs:
Epoch 1 | iter 109472 step 13684 | loss train: 3.099, val: 3.075

Note the 8x difference between iterations and steps (109472/13684 ≈ 8)

Steps to Reproduce

  1. Configure multi-node training (e.g., 16 nodes, 8 GPUs each)
  2. Set global batch size and micro batch size
  3. Observe the gradient accumulation steps and the training logs (a quick sanity check is sketched below)
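
One quick way to confirm the mismatch before a full run, assuming a torchrun-style launcher that sets the WORLD_SIZE environment variable (a standard torch.distributed convention, not anything litgpt-specific):

import os
import torch

# Compare the local device count that feeds the current calculation with the
# world size reported by the launcher (defaults to the local count if unset).
local_devices = torch.cuda.device_count()
world_size = int(os.environ.get("WORLD_SIZE", local_devices))

print(f"local devices: {local_devices}, world size: {world_size}")
if world_size > local_devices:
    print(f"gradient accumulation divisor is too small by a factor of "
          f"{world_size // local_devices} (= number of nodes)")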

Expected Behavior
The calculation should use total world size (all devices across all nodes) instead of local device count:

def batch_size(self, devices: int) -> int:
    """Number of samples between optimizer steps per data-parallel rank"""
    batch_size = self.global_batch_size // fabric.world_size  # Use world_size instead of devices
    return batch_size

Proposed Solution
Use fabric.world_size instead of the local device count so that all processes across all nodes are accounted for.
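
A minimal sketch of the shape of that change, using a simplified stand-in for litgpt's TrainArgs rather than the actual code path; the only real Fabric API assumed is the fabric.world_size property:

from dataclasses import dataclass

@dataclass
class TrainArgs:  # simplified stand-in, not the real litgpt class
    global_batch_size: int
    micro_batch_size: int

    def batch_size(self, world_size: int) -> int:
        """Number of samples between optimizer steps per data-parallel rank."""
        return self.global_batch_size // world_size

    def gradient_accumulation_iters(self, world_size: int) -> int:
        """Number of iterations between gradient synchronizations."""
        return self.batch_size(world_size) // self.micro_batch_size

# At the call site the divisor would come from the Fabric instance, e.g.
# train.gradient_accumulation_iters(fabric.world_size), so every process
# across all nodes is counted rather than only the local GPUs.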

Additional Context

  • This issue likely wasn't caught earlier because the tutorials primarily use single-node setups
  • The bug becomes apparent only in multi-node training scenarios
  • This affects training convergence and effective learning rate in multi-node setups


What operating system are you using?

Linux

LitGPT Version



