@@ -960,7 +960,7 @@ def set_sampler(self, sampler):
960
960
elif isinstance (sampler , Sampler ):
961
961
self .sampler .sampler = sampler
962
962
else :
963
- raise ValueError (f"{ sampler } must be of type torch.utills .data.Sampler or torch.utils.data.BatchSampler" )
963
+ raise ValueError (f"{ sampler } must be of type torch.utils .data.Sampler or torch.utils.data.BatchSampler" )
964
964
965
965
966
966
def get_sampler (dataloader ):
@@ -1150,8 +1150,10 @@ def prepare_data_loader(
1150
1150
new_dataset = dataloader .dataset
1151
1151
# Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
1152
1152
if isinstance (dataloader .sampler , BatchSampler ):
1153
- raise ValueError (
1154
- "Should not pass a BatchSampler do dataloader sampler argument. As per pytorch>2.1.0 documentation, please pass this to sampler instead"
1153
+ logger .warning (
1154
+ "BatchSampler was passed to sampler argument."
1155
+ "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead."
1156
+ "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
1155
1157
)
1156
1158
1157
1159
new_batch_sampler = dataloader .batch_sampler if not isinstance (new_dataset , IterableDataset ) else None
@@ -1347,8 +1349,10 @@ def skip_first_batches(dataloader, num_batches=0):
1347
1349
1348
1350
dataset = dataloader .dataset
1349
1351
if isinstance (dataloader .sampler , BatchSampler ):
1350
- raise ValueError (
1351
- "Should not pass a BatchSampler do dataloader sampler argument. As per the latest pytorch documentation, please pass this to sampler instead"
1352
+ logger .warning (
1353
+ "BatchSampler was passed to sampler argument."
1354
+ "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead."
1355
+ "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
1352
1356
)
1353
1357
1354
1358
if isinstance (dataset , IterableDataset ):
0 commit comments