fix: overfit_batches uses same batch for train and val #20731

Open · wants to merge 14 commits into base: master
9 changes: 8 additions & 1 deletion docs/source-pytorch/common/trainer.rst
@@ -759,6 +759,9 @@ overfit_batches
Uses this much data of the training & validation set.
If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it.

* When set to exactly 1, the same batch is used for both the training and validation steps, which is useful for debugging the model implementation
* For other values, sequential sampling (no shuffling) is used, as illustrated in the sketch below
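
As a standalone illustration of the sampling behavior described above, the following sketch (plain PyTorch rather than Lightning internals; the toy dataset and batch size are made up for the example) shows what replacing a shuffled sampler with a SequentialSampler amounts to: every pass over the loader starts from the same leading samples.

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

# Toy dataset: 64 scalar samples 0..63 (placeholder, not a Lightning dataset).
dataset = TensorDataset(torch.arange(64, dtype=torch.float32).unsqueeze(1))

# A user-supplied loader with shuffling enabled.
shuffled_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Roughly what "turning off shuffling" means: same dataset and batch size,
# but indices are drawn in order, so the first batch is always samples 0..7.
sequential_loader = DataLoader(dataset, batch_size=8, sampler=SequentialSampler(dataset))

first_batch, = next(iter(sequential_loader))
print(first_batch.flatten())  # tensor([0., 1., 2., 3., 4., 5., 6., 7.])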

Comment from @adosar (Contributor), Apr 28, 2025:
I still insist that the case overfit_batches=1 doesn't need any special handling. Model debugging, overfitting and sanity checks can all be achieved with the current implementation, i.e. training and validation use different batches even for overfit_batches=1.

Useful for quickly debugging or trying to overfit on purpose.

.. testcode::
@@ -769,9 +772,13 @@ Useful for quickly debugging or trying to overfit on purpose.
# use only 1% of the train & val set
trainer = Trainer(overfit_batches=0.01)

# overfit on 10 of the same batches
# overfit on 10 (same) train batches & 10 (same) val batches
trainer = Trainer(overfit_batches=10)

# debug by training and validating on exactly the same single batch
# (useful for verifying model implementation)
trainer = Trainer(overfit_batches=1)
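
A related point worth keeping in mind when reading these examples (a hedged note, but consistent with the 0.01 and 10 cases above): an integer value selects a number of batches, while a float selects a fraction of the dataset, so 1 and 1.0 are not equivalent.

# integer: exactly one batch, reused for training and validation (the case this PR changes)
trainer = Trainer(overfit_batches=1)

# float: a fraction of the training data; 1.0 means the full set, not a single batch
trainer = Trainer(overfit_batches=1.0)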


plugins
^^^^^^^

53 changes: 49 additions & 4 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -244,19 +244,64 @@ def _get_distributed_sampler(


def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
"""Resolve overfit batches by ensuring the same batch is used for both training and validation."""
all_have_sequential_sampler = all(
isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
)
if all_have_sequential_sampler:
return

rank_zero_warn(
f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
)
updated = [
_update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
for dl in combined_loader.flattened
]

# Get the first batch from the training dataloader
first_batch = None
if mode == RunningStage.TRAINING:
for dl in combined_loader.flattened:
if hasattr(dl, "dataset"):
first_batch = next(iter(dl))
break

# Create new dataloaders with SequentialSampler
updated = []
for dl in combined_loader.flattened:
if hasattr(dl, "dataset"):
if mode == RunningStage.VALIDATING and first_batch is not None:
# For validation, create a custom sampler that always returns the first batch
class SingleBatchSampler(Sampler):
def __init__(self, batch):
self.batch = batch

def __iter__(self):
yield self.batch

def __len__(self):
return 1

sampler = SingleBatchSampler(first_batch)
else:
sampler = SequentialSampler(dl.dataset)

# Create a new dataloader with the new sampler
dl = DataLoader(
dataset=dl.dataset,
batch_size=dl.batch_size,
sampler=sampler,
num_workers=dl.num_workers,
collate_fn=dl.collate_fn,
pin_memory=dl.pin_memory,
drop_last=dl.drop_last,
timeout=dl.timeout,
worker_init_fn=dl.worker_init_fn,
multiprocessing_context=dl.multiprocessing_context,
generator=dl.generator,
prefetch_factor=dl.prefetch_factor,
persistent_workers=dl.persistent_workers,
)
updated.append(dl)

combined_loader.flattened = updated
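
For context on the custom sampler above: in plain PyTorch, a Sampler yields dataset indices, and the DataLoader fetches and collates those indices into batches. A minimal standalone sketch of pinning a loader to one fixed batch along those lines (the toy dataset and batch size are placeholders, and note it yields indices rather than a pre-collated batch, unlike the SingleBatchSampler in this diff) could look like this:

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class FirstBatchOnlySampler(Sampler):
    """Yield only the indices of the first batch, so every epoch sees the same samples."""

    def __init__(self, dataset, batch_size):
        self.indices = list(range(min(batch_size, len(dataset))))

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

dataset = TensorDataset(torch.arange(64, dtype=torch.float32).unsqueeze(1))
loader = DataLoader(dataset, batch_size=8, sampler=FirstBatchOnlySampler(dataset, batch_size=8))

for batch, in loader:
    print(batch.flatten())  # tensor([0., 1., 2., 3., 4., 5., 6., 7.]) on every pass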


6 changes: 6 additions & 0 deletions tests/tests_pytorch/conftest.py
@@ -95,6 +95,12 @@ def restore_env_variables():
"TF_GRPC_DEFAULT_OPTIONS",
"XLA_FLAGS",
"TORCHINDUCTOR_CACHE_DIR", # leaked by torch.compile
# TensorFlow and TPU related variables
"TF2_BEHAVIOR",
"TPU_ML_PLATFORM",
"TPU_ML_PLATFORM_VERSION",
"LD_LIBRARY_PATH",
"ENABLE_RUNTIME_UPTIME_TELEMETRY",
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
41 changes: 41 additions & 0 deletions tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -170,3 +170,44 @@ def test_distributed_sampler_with_overfit_batches():
train_sampler = trainer.train_dataloader.sampler
assert isinstance(train_sampler, DistributedSampler)
assert train_sampler.shuffle is False


def test_overfit_batches_same_batch_for_train_and_val(tmp_path):
"""Test that when overfit_batches=1, the same batch is used for both training and validation."""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.train_batches = []
self.val_batches = []

def training_step(self, batch, batch_idx):
self.train_batches.append(batch)
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self.val_batches.append(batch)
return super().validation_step(batch, batch_idx)

model = TestModel()
trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=2,
overfit_batches=1,
check_val_every_n_epoch=1,
enable_model_summary=False,
)
trainer.fit(model)

# Verify that the same batch was used for both training and validation
assert len(model.train_batches) > 0
assert len(model.val_batches) > 0

# Compare the actual batch contents
train_batch = model.train_batches[0]
val_batch = model.val_batches[0]

# Check if the batches are identical
assert torch.equal(train_batch, val_batch), (
"Training and validation batches should be identical when overfit_batches=1"
)
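
To run just the new test locally, a programmatic invocation along these lines should work (equivalent to the usual pytest command line; assumes the working directory is the repository root with the test dependencies installed):

import pytest

pytest.main([
    "tests/tests_pytorch/trainer/flags/test_overfit_batches.py::test_overfit_batches_same_batch_for_train_and_val",
    "-v",
])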