fix: overfit_batches uses same batch for train and val #20731

Open · wants to merge 14 commits into base: master
11 changes: 7 additions & 4 deletions docs/source-pytorch/common/trainer.rst
@@ -759,6 +759,9 @@ overfit_batches
Uses this much data of the training & validation set.
If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it.

* When set to a value > 0, sequential sampling (no shuffling) is used
* Consistent batches are used for both training and validation across epochs, but training and validation use different sets of data

Useful for quickly debugging or trying to overfit on purpose.

.. testcode::
@@ -769,11 +772,11 @@
# use only 1% of the train & val set
trainer = Trainer(overfit_batches=0.01)

# overfit on 10 of the same batches
# overfit on 10 consistent train batches & 10 consistent val batches
trainer = Trainer(overfit_batches=10)

plugins
^^^^^^^
# debug using a single consistent train batch and a single consistent val batch


:ref:`Plugins` allow you to connect arbitrary backends, precision libraries, clusters etc. For example:

@@ -895,7 +898,7 @@ DataSource can be a ``LightningModule`` or a ``LightningDataModule``.

# if 0 (default)
train_loader = model.train_dataloader()
# or if using data module: datamodule.train_dataloader()
# or if using data module: datamodule.train_dataloaders()
for epoch in epochs:
for batch in train_loader:
...
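For context on the documented behaviour, here is a minimal standalone sketch (not part of this diff) of how a reviewer could confirm that ``overfit_batches`` swaps in sequential sampling, mirroring the check in the existing ``test_distributed_sampler_with_overfit_batches`` test; it assumes ``BoringModel`` from ``lightning.pytorch.demos.boring_classes``:

```python
# Hedged sketch (not part of the PR): after fitting with overfit_batches,
# the resolved train dataloader should expose a SequentialSampler.
from torch.utils.data import SequentialSampler

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(overfit_batches=2, max_epochs=1, enable_model_summary=False)
trainer.fit(model)

# Shuffling was disabled by the overfit_batches resolution step.
assert isinstance(trainer.train_dataloader.sampler, SequentialSampler)
```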
8 changes: 8 additions & 0 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -244,15 +244,23 @@ def _get_distributed_sampler(


def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
"""Resolve overfit batches by disabling shuffling.

When overfit_batches > 0, this function ensures that sequential sampling is used without shuffling for consistent
batches across epochs. Training and validation use different sets of data.

"""
all_have_sequential_sampler = all(
isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
)
if all_have_sequential_sampler:
return

rank_zero_warn(
f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
)

updated = [
_update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
for dl in combined_loader.flattened
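To make the connector change easier to review in isolation, below is a standalone approximation of the sampler swap using a plain PyTorch ``DataLoader``; the actual code rebuilds the loader through the internal ``_update_dataloader`` helper, so this is illustrative only:

```python
# Illustrative sketch of the idea behind _resolve_overfit_batches:
# replace a shuffling sampler with a SequentialSampler so each epoch
# yields the same batches in the same order.
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(16, dtype=torch.float32).unsqueeze(1))
shuffled = DataLoader(dataset, batch_size=4, shuffle=True)

if not isinstance(shuffled.sampler, SequentialSampler):
    # Rebuild the loader around a sequential sampler (no shuffling).
    loader = DataLoader(
        shuffled.dataset,
        batch_size=shuffled.batch_size,
        sampler=SequentialSampler(shuffled.dataset),
    )
else:
    loader = shuffled

first_epoch = [batch[0].tolist() for batch in loader]
second_epoch = [batch[0].tolist() for batch in loader]
assert first_epoch == second_epoch  # consistent batches across epochs
```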
6 changes: 6 additions & 0 deletions tests/tests_pytorch/conftest.py
@@ -95,6 +95,12 @@ def restore_env_variables():
"TF_GRPC_DEFAULT_OPTIONS",
"XLA_FLAGS",
"TORCHINDUCTOR_CACHE_DIR", # leaked by torch.compile
# TensorFlow and TPU related variables
"TF2_BEHAVIOR",
"TPU_ML_PLATFORM",
"TPU_ML_PLATFORM_VERSION",
"LD_LIBRARY_PATH",
"ENABLE_RUNTIME_UPTIME_TELEMETRY",
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
41 changes: 41 additions & 0 deletions tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -170,3 +170,44 @@ def test_distributed_sampler_with_overfit_batches():
train_sampler = trainer.train_dataloader.sampler
assert isinstance(train_sampler, DistributedSampler)
assert train_sampler.shuffle is False


def test_overfit_batches_same_batch_for_train_and_val(tmp_path):
"""Test that when overfit_batches=1, the same batch is used for both training and validation."""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.train_batches = []
self.val_batches = []

def training_step(self, batch, batch_idx):
self.train_batches.append(batch)
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
self.val_batches.append(batch)
return super().validation_step(batch, batch_idx)

model = TestModel()
trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=2,
overfit_batches=1,
check_val_every_n_epoch=1,
enable_model_summary=False,
)
trainer.fit(model)

# Verify that the same batch was used for both training and validation
assert len(model.train_batches) > 0
assert len(model.val_batches) > 0

# Compare the actual batch contents
train_batch = model.train_batches[0]
val_batch = model.val_batches[0]

# Check if the batches are identical
assert torch.equal(train_batch, val_batch), (
"Training and validation batches should be identical when overfit_batches=1"
)
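If useful while reviewing, the new test can be run on its own with something like ``pytest tests/tests_pytorch/trainer/flags/test_overfit_batches.py -k same_batch_for_train_and_val``.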