fix: overfit_batches uses same batch for train and val #20731

Open · wants to merge 14 commits into base: master · changes shown from 4 commits
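For context, a minimal sketch of the flag this PR targets: with the real `Trainer(overfit_batches=...)` argument and the library's `BoringModel` demo class (the surrounding script is illustrative), a single batch is expected to be reused for both the training and validation loops, which is the behavior this PR sets out to guarantee.

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel

    # With overfit_batches=1 the Trainer should pin both loops to one batch;
    # the bug addressed here is that validation did not reuse the training batch.
    model = BoringModel()
    trainer = Trainer(max_epochs=2, overfit_batches=1, enable_model_summary=False)
    trainer.fit(model)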
55 changes: 51 additions & 4 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -244,19 +244,66 @@ def _get_distributed_sampler(


 def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
+    """Resolve overfit batches by ensuring the same batch is used for both training and validation."""
     all_have_sequential_sampler = all(
         isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
     )
     if all_have_sequential_sampler:
         return

     rank_zero_warn(
         f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
         f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
     )
-    updated = [
-        _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
-        for dl in combined_loader.flattened
-    ]
+
+    # Get the first batch from the training dataloader
+    first_batch = None
+    if mode == RunningStage.TRAINING:
+        for dl in combined_loader.flattened:
+            if hasattr(dl, "dataset"):
+                first_batch = next(iter(dl))
+                break
+
+    # Create new dataloaders with SequentialSampler
+    updated = []
+    for dl in combined_loader.flattened:
+        if hasattr(dl, "dataset"):
+            if mode == RunningStage.VALIDATING and first_batch is not None:
+                # For validation, create a custom sampler that always returns the first batch
+                class SingleBatchSampler(Sampler):
+                    def __init__(self, batch):
+                        self.batch = batch
+
+                    def __iter__(self):
+                        yield self.batch
+
+                    def __len__(self):
+                        return 1
+
+                sampler = SingleBatchSampler(first_batch)
+            else:
+                sampler = SequentialSampler(dl.dataset)
+
+            # Create a new dataloader with the new sampler
+            new_dl = DataLoader(
+                dataset=dl.dataset,
+                batch_size=dl.batch_size,
+                sampler=sampler,
+                num_workers=dl.num_workers,
+                collate_fn=dl.collate_fn,
+                pin_memory=dl.pin_memory,
+                drop_last=dl.drop_last,
+                timeout=dl.timeout,
+                worker_init_fn=dl.worker_init_fn,
+                multiprocessing_context=dl.multiprocessing_context,
+                generator=dl.generator,
+                prefetch_factor=dl.prefetch_factor,
+                persistent_workers=dl.persistent_workers,
+            )
+            updated.append(new_dl)
+        else:
+            updated.append(dl)
+
     combined_loader.flattened = updated


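As background for the diff above: a `Sampler` handed to `torch.utils.data.DataLoader` yields dataset indices, which the loader then batches and collates. The following is a self-contained sketch of that protocol, independent of this PR (the `FirstIndicesSampler` and `ToyDataset` names are made up for illustration), showing two loaders pinned to the identical first batch.

    import torch
    from torch.utils.data import DataLoader, Dataset, Sampler


    class FirstIndicesSampler(Sampler):
        """Yield only the first `n` dataset indices, so every epoch sees the same data."""

        def __init__(self, n):
            self.n = n

        def __iter__(self):
            return iter(range(self.n))

        def __len__(self):
            return self.n


    class ToyDataset(Dataset):
        def __len__(self):
            return 64

        def __getitem__(self, idx):
            return torch.tensor([float(idx)])


    # Both loaders draw indices 0..3, so they produce the identical batch.
    train_loader = DataLoader(ToyDataset(), batch_size=4, sampler=FirstIndicesSampler(4))
    val_loader = DataLoader(ToyDataset(), batch_size=4, sampler=FirstIndicesSampler(4))
    assert torch.equal(next(iter(train_loader)), next(iter(val_loader)))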
6 changes: 6 additions & 0 deletions tests/tests_pytorch/conftest.py
@@ -95,6 +95,12 @@ def restore_env_variables():
"TF_GRPC_DEFAULT_OPTIONS",
"XLA_FLAGS",
"TORCHINDUCTOR_CACHE_DIR", # leaked by torch.compile
# TensorFlow and TPU related variables
"TF2_BEHAVIOR",
"TPU_ML_PLATFORM",
"TPU_ML_PLATFORM_VERSION",
"LD_LIBRARY_PATH",
"ENABLE_RUNTIME_UPTIME_TELEMETRY",
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
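For context, the allowlist extended above belongs to an autouse fixture that fails any test leaking environment variables; below is a simplified sketch of that pattern (abbreviated, not the exact `conftest.py` contents).

    import os

    import pytest


    @pytest.fixture(autouse=True)
    def restore_env_variables():
        """Snapshot os.environ, restore it after the test, and fail on non-allowlisted leaks."""
        env_backup = os.environ.copy()
        yield
        leaked_vars = os.environ.keys() - env_backup.keys()
        # Restore the environment exactly as it was before the test ran.
        os.environ.clear()
        os.environ.update(env_backup)
        allowlist = {"TF2_BEHAVIOR", "TPU_ML_PLATFORM"}  # abbreviated for illustration
        leaked_vars.difference_update(allowlist)
        assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"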
41 changes: 41 additions & 0 deletions tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -170,3 +170,44 @@ def test_distributed_sampler_with_overfit_batches():
     train_sampler = trainer.train_dataloader.sampler
     assert isinstance(train_sampler, DistributedSampler)
     assert train_sampler.shuffle is False
+
+
+def test_overfit_batches_same_batch_for_train_and_val(tmp_path):
+    """Test that when overfit_batches=1, the same batch is used for both training and validation."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.train_batches = []
+            self.val_batches = []
+
+        def training_step(self, batch, batch_idx):
+            self.train_batches.append(batch)
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            self.val_batches.append(batch)
+            return super().validation_step(batch, batch_idx)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        overfit_batches=1,
+        check_val_every_n_epoch=1,
+        enable_model_summary=False,
+    )
+    trainer.fit(model)
+
+    # Verify that the same batch was used for both training and validation
+    assert len(model.train_batches) > 0
+    assert len(model.val_batches) > 0
+
+    # Compare the actual batch contents
+    train_batch = model.train_batches[0]
+    val_batch = model.val_batches[0]
+
+    # Check if the batches are identical
+    assert torch.equal(train_batch, val_batch), (
+        "Training and validation batches should be identical when overfit_batches=1"
+    )