Skip to content

Commit 17fc17a

Browse files
committed
(Fix) data_loader: Change Value Error to Warning for BatchSampler in sampler argument, and Fix Typos
1 parent 56647e3 commit 17fc17a

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/accelerate/data_loader.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ def set_sampler(self, sampler):
960960
elif isinstance(sampler, Sampler):
961961
self.sampler.sampler = sampler
962962
else:
963-
raise ValueError(f"{sampler} must be of type torch.utills.data.Sampler or torch.utils.data.BatchSampler")
963+
raise ValueError(f"{sampler} must be of type torch.utils.data.Sampler or torch.utils.data.BatchSampler")
964964

965965

966966
def get_sampler(dataloader):
@@ -1150,8 +1150,10 @@ def prepare_data_loader(
11501150
new_dataset = dataloader.dataset
11511151
# Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
11521152
if isinstance(dataloader.sampler, BatchSampler):
1153-
raise ValueError(
1154-
"Should not pass a BatchSampler do dataloader sampler argument. As per pytorch>2.1.0 documentation, please pass this to sampler instead"
1153+
logger.warning(
1154+
"BatchSampler was passed to sampler argument."
1155+
"If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead."
1156+
"For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
11551157
)
11561158

11571159
new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
@@ -1347,8 +1349,10 @@ def skip_first_batches(dataloader, num_batches=0):
13471349

13481350
dataset = dataloader.dataset
13491351
if isinstance(dataloader.sampler, BatchSampler):
1350-
raise ValueError(
1351-
"Should not pass a BatchSampler do dataloader sampler argument. As per the latest pytorch documentation, please pass this to sampler instead"
1352+
logger.warning(
1353+
"BatchSampler was passed to sampler argument."
1354+
"If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead."
1355+
"For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
13521356
)
13531357

13541358
if isinstance(dataset, IterableDataset):

0 commit comments

Comments
 (0)