
Commit 427b3ca
(fix) remove sampler_is_batch_sampler code
1 parent 9642a1a commit 427b3ca

1 file changed (+24, −40)

src/accelerate/data_loader.py

Lines changed: 24 additions & 40 deletions
@@ -19,7 +19,7 @@
 
 import torch
 from packaging import version
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler
 
 from .logging import get_logger
 from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
@@ -631,13 +631,10 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        if isinstance(sampler, BatchSampler):
+            self.sampler.batch_sampler = sampler
+        elif isinstance(sampler, Sampler):
             self.sampler.sampler = sampler
-        else:
-            self.batch_sampler.sampler = sampler
-            if hasattr(self.batch_sampler, "batch_sampler"):
-                self.batch_sampler.batch_sampler.sampler = sampler
 
 
 if is_torch_xla_available():
@@ -958,13 +955,12 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        if isinstance(sampler, BatchSampler):
+            self.sampler.batch_sampler = sampler
+        elif isinstance(sampler, Sampler):
             self.sampler.sampler = sampler
         else:
-            self.batch_sampler.sampler = sampler
-            if hasattr(self.batch_sampler, "batch_sampler"):
-                self.batch_sampler.batch_sampler.sampler = sampler
+            raise ValueError(f"{sampler} must be of type torch.utills.data.Sampler or torch.utils.data.BatchSampler")
 
 
 def get_sampler(dataloader):
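One detail worth noting about the new dispatch (this aside is not part of the commit): on recent PyTorch releases `BatchSampler` is itself a `Sampler` subclass, so the `BatchSampler` branch has to be tested before the `Sampler` branch, exactly as the rewritten `set_sampler` does. A minimal sketch:

```python
# Minimal sketch (not from the commit): a BatchSampler also passes an isinstance check
# against Sampler on recent PyTorch versions, so the check order above matters.
from torch.utils.data import BatchSampler, Sampler, SequentialSampler

batch_sampler = BatchSampler(SequentialSampler(range(8)), batch_size=4, drop_last=False)
print(isinstance(batch_sampler, BatchSampler))  # True
print(isinstance(batch_sampler, Sampler))       # also True, hence BatchSampler is checked first
```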
@@ -977,10 +973,8 @@ def get_sampler(dataloader):
     Returns:
         `torch.utils.data.Sampler`: The sampler associated to the dataloader
     """
-    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-    if sampler_is_batch_sampler:
-        sampler = getattr(dataloader.sampler, "sampler", None)
-    else:
+    sampler = getattr(dataloader.sampler, "sampler", None)
+    if not sampler:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     return sampler
 
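For illustration only (not from the diff), here is how the simplified lookup behaves on a stock `DataLoader`: a plain sampler such as `RandomSampler` has no `.sampler` attribute, so the first `getattr` returns `None` and the fallback reads the sampler off the default `BatchSampler`, which wraps the same object.

```python
# Illustration (not from the commit) of the simplified get_sampler() lookup.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
dl = DataLoader(dataset, batch_size=4, shuffle=True)

sampler = getattr(dl.sampler, "sampler", None)  # RandomSampler has no .sampler -> None
if not sampler:
    sampler = getattr(dl.batch_sampler, "sampler", None)  # fallback to the batch sampler

print(type(sampler).__name__)  # RandomSampler
print(sampler is dl.sampler)   # True: the default BatchSampler wraps the same sampler
```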

@@ -1155,11 +1149,16 @@ def prepare_data_loader(
 
     new_dataset = dataloader.dataset
     # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
+    if isinstance(dataloader.sampler, BatchSampler):
+        raise ValueError(
+            "Should not pass a BatchSampler do dataloader sampler argument. As per pytorch>2.1.0 documentation, please pass this to sampler instead"
+        )
+
     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
-    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+
     synchronized_generator = None
+    sampler = dataloader.sampler
 
-    sampler = get_sampler(dataloader)
     if isinstance(sampler, RandomSampler) and use_seedable_sampler:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
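The sketch below (not part of the commit) contrasts the construction the new guard rejects with the one PyTorch documents: a `BatchSampler` belongs in the `batch_sampler=` argument, not in `sampler=`.

```python
# Sketch (not from the commit) of the two routes a BatchSampler can take into a DataLoader.
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# Documented route: the loader iterates the index lists the BatchSampler yields.
dl = DataLoader(dataset, batch_sampler=batch_sampler)
print([batch[0].tolist() for batch in dl])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# Route now rejected by prepare_data_loader(): something like
#   DataLoader(dataset, sampler=batch_sampler, batch_size=None)
# would hit the new ValueError when the loader is prepared.
```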
@@ -1208,9 +1207,8 @@
                     seed = int(torch.empty((), dtype=torch.int64).random_().item())
                     sampler.generator.manual_seed(seed)
                 synchronized_generator = sampler.generator
-            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
-                batch_sampler,
+                dataloader.batch_sampler,
                 num_processes=num_processes,
                 process_index=process_index,
                 split_batches=split_batches,
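A rough sketch of what the simplified call does, assuming `BatchSamplerShard` from `accelerate.data_loader` behaves as documented (it splits batches of indices across processes); the loader's own `batch_sampler` is now always the object being sharded.

```python
# Rough sketch (assumes accelerate's documented BatchSamplerShard API) of sharding the
# loader's own batch_sampler, as the simplified call above now always does.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate.data_loader import BatchSamplerShard

dataset = TensorDataset(torch.arange(16))
dl = DataLoader(dataset, batch_size=4)

shard = BatchSamplerShard(dl.batch_sampler, num_processes=2, process_index=0, split_batches=False)
print(list(shard))  # the batches of indices this (mock) process 0 would iterate over
```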
@@ -1254,19 +1252,6 @@
             torch_device_mesh=torch_device_mesh,
             **kwargs,
         )
-    elif sampler_is_batch_sampler:
-        dataloader = DataLoaderShard(
-            new_dataset,
-            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
-            sampler=new_batch_sampler,
-            batch_size=dataloader.batch_size,
-            rng_types=rng_types,
-            _drop_last=dataloader.drop_last,
-            _non_blocking=non_blocking,
-            synchronized_generator=synchronized_generator,
-            use_stateful_dataloader=use_stateful_dataloader,
-            **kwargs,
-        )
     else:
         dataloader = DataLoaderShard(
             new_dataset,
@@ -1361,13 +1346,15 @@ def skip_first_batches(dataloader, num_batches=0):
         dataloader = dataloader.dataloader
 
     dataset = dataloader.dataset
-    sampler_is_batch_sampler = False
+    if isinstance(dataloader.sampler, BatchSampler):
+        raise ValueError(
+            "Should not pass a BatchSampler do dataloader sampler argument. As per the latest pytorch documentation, please pass this to sampler instead"
+        )
+
     if isinstance(dataset, IterableDataset):
         new_batch_sampler = None
     else:
-        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
-        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
+        new_batch_sampler = SkipBatchSampler(dataloader.batch_sampler, skip_batches=num_batches)
 
     # We ignore all of those since they are all dealt with by our new_batch_sampler
     ignore_kwargs = [
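For reference (not part of the commit), a small sketch of the simplified skip path, assuming `SkipBatchSampler` from `accelerate.data_loader` wraps a batch sampler and drops its first `skip_batches` batches:

```python
# Sketch (assumes accelerate's SkipBatchSampler drops the first N batches it is given).
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate.data_loader import SkipBatchSampler

dataset = TensorDataset(torch.arange(12))
dl = DataLoader(dataset, batch_size=4)

skipping = SkipBatchSampler(dl.batch_sampler, skip_batches=2)
print(list(skipping))  # only the batches after the first two remain, e.g. [[8, 9, 10, 11]]
```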
@@ -1404,9 +1391,6 @@
         if new_batch_sampler is None:
             # Need to manually skip batches in the dataloader
             kwargs["skip_batches"] = num_batches
-        elif sampler_is_batch_sampler:
-            kwargs["sampler"] = new_batch_sampler
-            kwargs["batch_size"] = dataloader.batch_size
         else:
             kwargs["batch_sampler"] = new_batch_sampler
         dataloader = DataLoaderShard(
