
(fix) remove sampler_is_batch_sampler code in prepare_data_loader(..) #3469

Open · wants to merge 3 commits into base: main
54 changes: 22 additions & 32 deletions src/accelerate/data_loader.py
@@ -631,8 +631,7 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        if isinstance(self.sampler, BatchSampler):
             self.sampler.sampler = sampler
         else:
             self.batch_sampler.sampler = sampler
@@ -958,8 +957,7 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        if isinstance(self.sampler, BatchSampler):
             self.sampler.sampler = sampler
         else:
             self.batch_sampler.sampler = sampler
@@ -977,10 +975,8 @@ def get_sampler(dataloader):
     Returns:
         `torch.utils.data.Sampler`: The sampler associated to the dataloader
     """
-    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-    if sampler_is_batch_sampler:
Comment on lines -980 to -981 (Member):

let's keep the isinstance(dataloader.sampler, BatchSampler) check
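A minimal sketch of what keeping that explicit check could look like (illustrative only, not part of this diff; assumes `BatchSampler` is imported from `torch.utils.data` as in the rest of the module):

```python
from torch.utils.data import BatchSampler


def get_sampler(dataloader):
    """Return the sampler associated with `dataloader`."""
    # Keep the explicit check suggested in review: if a BatchSampler was
    # passed as `sampler`, look inside it; otherwise look inside batch_sampler.
    if isinstance(dataloader.sampler, BatchSampler):
        sampler = getattr(dataloader.sampler, "sampler", None)
    else:
        sampler = getattr(dataloader.batch_sampler, "sampler", None)
    return sampler
```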

-        sampler = getattr(dataloader.sampler, "sampler", None)
-    else:
+    sampler = getattr(dataloader.sampler, "sampler", None)
+    if not sampler:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     return sampler

@@ -1155,11 +1151,18 @@ def prepare_data_loader(
 
     new_dataset = dataloader.dataset
     # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
+    if isinstance(dataloader.sampler, BatchSampler):
+        logger.warning(
+            "BatchSampler was passed to sampler argument."
+            "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead."
+            "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
+        )
+
Comment on lines +1154 to +1160 (Member):

The warning is nice, maybe we should add a deprecation message also saying that we won't allow passing BatchSampler to sampler anymore?
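One possible shape for such a notice, sketched with Python's standard `warnings` module; the helper name `_warn_batch_sampler_deprecation`, the wording, and the choice of mechanism are all hypothetical and not part of this diff:

```python
import warnings

from torch.utils.data import BatchSampler


def _warn_batch_sampler_deprecation(dataloader):
    # Hypothetical helper: could be called from prepare_data_loader next to
    # the logger.warning added in this PR.
    if isinstance(dataloader.sampler, BatchSampler):
        warnings.warn(
            "Passing a `BatchSampler` as the `sampler` argument is deprecated "
            "and will no longer be supported in a future release. "
            "Pass it as the `batch_sampler` argument instead.",
            FutureWarning,
        )
```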

     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
-    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
 
     synchronized_generator = None
-    sampler = dataloader.sampler
+
+    sampler = get_sampler(dataloader)
     if isinstance(sampler, RandomSampler) and use_seedable_sampler:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
@@ -1208,9 +1211,8 @@ def prepare_data_loader(
                     seed = int(torch.empty((), dtype=torch.int64).random_().item())
                     sampler.generator.manual_seed(seed)
                 synchronized_generator = sampler.generator
-            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
-                batch_sampler,
+                dataloader.batch_sampler,
                 num_processes=num_processes,
                 process_index=process_index,
                 split_batches=split_batches,
@@ -1254,19 +1256,6 @@ def prepare_data_loader(
             torch_device_mesh=torch_device_mesh,
             **kwargs,
         )
-    elif sampler_is_batch_sampler:
-        dataloader = DataLoaderShard(
-            new_dataset,
-            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
-            sampler=new_batch_sampler,
-            batch_size=dataloader.batch_size,
-            rng_types=rng_types,
-            _drop_last=dataloader.drop_last,
-            _non_blocking=non_blocking,
-            synchronized_generator=synchronized_generator,
-            use_stateful_dataloader=use_stateful_dataloader,
-            **kwargs,
-        )
     else:
         dataloader = DataLoaderShard(
             new_dataset,
@@ -1361,13 +1350,17 @@ def skip_first_batches(dataloader, num_batches=0):
         dataloader = dataloader.dataloader
 
     dataset = dataloader.dataset
-    sampler_is_batch_sampler = False
+    if isinstance(dataloader.sampler, BatchSampler):
+        logger.warning(
+            "BatchSampler was passed to sampler argument."
+            "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead."
+            "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
+        )
+
     if isinstance(dataset, IterableDataset):
         new_batch_sampler = None
     else:
-        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
-        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
+        new_batch_sampler = SkipBatchSampler(dataloader.batch_sampler, skip_batches=num_batches)
 
     # We ignore all of those since they are all dealt with by our new_batch_sampler
     ignore_kwargs = [
@@ -1404,9 +1397,6 @@ def skip_first_batches(dataloader, num_batches=0):
         if new_batch_sampler is None:
             # Need to manually skip batches in the dataloader
             kwargs["skip_batches"] = num_batches
-        elif sampler_is_batch_sampler:
-            kwargs["sampler"] = new_batch_sampler
-            kwargs["batch_size"] = dataloader.batch_size
         else:
             kwargs["batch_sampler"] = new_batch_sampler
         dataloader = DataLoaderShard(
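For readers wondering what the new warning is steering users toward: PyTorch's `DataLoader` treats `sampler` and `batch_sampler` differently, and a sampler that yields lists of indices belongs in `batch_sampler`. A standalone sketch of the recommended usage (not part of this PR; the toy dataset and batch size are arbitrary):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# Recommended: pass it as `batch_sampler`, so each yielded list of indices
# becomes exactly one batch. Passing it as `sampler` (the pattern the new
# warning flags) makes the DataLoader wrap it in another default BatchSampler.
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([4]), torch.Size([4]), torch.Size([2])
```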