 import torch
 from packaging import version
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler
 
 from .logging import get_logger
 from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
@@ -631,13 +631,10 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        if isinstance(sampler, BatchSampler):
+            self.sampler.batch_sampler = sampler
+        elif isinstance(sampler, Sampler):
             self.sampler.sampler = sampler
-        else:
-            self.batch_sampler.sampler = sampler
-            if hasattr(self.batch_sampler, "batch_sampler"):
-                self.batch_sampler.batch_sampler.sampler = sampler
 
 
 if is_torch_xla_available():
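Note: `BatchSampler` subclasses `Sampler` in recent PyTorch releases, so the order of the `isinstance` checks above matters; the `BatchSampler` branch must come first, or batch samplers would fall into the plain-sampler branch. A minimal stand-alone sketch of the same dispatch order (illustrative only, not part of the commit):

# Illustrative sketch only: the dispatch order used by the new set_sampler.
from torch.utils.data import BatchSampler, RandomSampler, Sampler, SequentialSampler

def classify(sampler):
    # BatchSampler is itself a Sampler subclass, so it must be checked first.
    if isinstance(sampler, BatchSampler):
        return "batch_sampler"
    elif isinstance(sampler, Sampler):
        return "sampler"
    return "unsupported"

data = list(range(8))
print(classify(RandomSampler(data)))                                                    # sampler
print(classify(BatchSampler(SequentialSampler(data), batch_size=4, drop_last=False)))   # batch_sampler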
@@ -958,13 +955,12 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        sampler_is_batch_sampler = isinstance(self.sampler, BatchSampler)
-        if sampler_is_batch_sampler:
+        if isinstance(sampler, BatchSampler):
+            self.sampler.batch_sampler = sampler
+        elif isinstance(sampler, Sampler):
             self.sampler.sampler = sampler
         else:
-            self.batch_sampler.sampler = sampler
-            if hasattr(self.batch_sampler, "batch_sampler"):
-                self.batch_sampler.batch_sampler.sampler = sampler
+            raise ValueError(f"{sampler} must be of type torch.utils.data.Sampler or torch.utils.data.BatchSampler")
 
 
 def get_sampler(dataloader):
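For the fall-through case, a hedged stand-alone reproduction of the new validation (the helper name `_check_sampler` is hypothetical, not part of the library): anything that is neither a `Sampler` nor a `BatchSampler`, such as a bare list of indices, is now rejected.

# Hypothetical reproduction of the validation added above; not the library's API.
from torch.utils.data import BatchSampler, Sampler

def _check_sampler(sampler):
    if isinstance(sampler, (Sampler, BatchSampler)):
        return sampler
    raise ValueError(f"{sampler} must be of type torch.utils.data.Sampler or torch.utils.data.BatchSampler")

try:
    _check_sampler([0, 1, 2])  # a bare list of indices is not a Sampler
except ValueError as err:
    print(err)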
@@ -977,10 +973,8 @@ def get_sampler(dataloader):
     Returns:
         `torch.utils.data.Sampler`: The sampler associated to the dataloader
     """
-    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-    if sampler_is_batch_sampler:
-        sampler = getattr(dataloader.sampler, "sampler", None)
-    else:
+    sampler = getattr(dataloader.sampler, "sampler", None)
+    if not sampler:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     return sampler
 
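The helper now probes `dataloader.sampler.sampler` first and only falls back to `dataloader.batch_sampler.sampler`. For a stock `DataLoader` the first lookup returns `None` (a plain `RandomSampler` has no `.sampler` attribute), so the fallback is what usually resolves. An illustrative sketch of the same two-step lookup (not part of the commit):

# Illustrative sketch: what the two-step lookup resolves to for a stock DataLoader with shuffle=True.
from torch.utils.data import DataLoader, RandomSampler

loader = DataLoader(list(range(10)), batch_size=2, shuffle=True)

sampler = getattr(loader.sampler, "sampler", None)   # None: RandomSampler has no .sampler attribute
if not sampler:
    sampler = getattr(loader.batch_sampler, "sampler", None)  # the default BatchSampler wraps it

assert isinstance(sampler, RandomSampler)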
@@ -1155,11 +1149,16 @@ def prepare_data_loader(
 
     new_dataset = dataloader.dataset
     # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
+    if isinstance(dataloader.sampler, BatchSampler):
+        raise ValueError(
+            "Should not pass a BatchSampler to the dataloader sampler argument. As per the pytorch>2.1.0 documentation, please pass this to batch_sampler instead"
+        )
+
     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
-    sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
+
     synchronized_generator = None
+    sampler = dataloader.sampler
 
-    sampler = get_sampler(dataloader)
     if isinstance(sampler, RandomSampler) and use_seedable_sampler:
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
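`prepare_data_loader` now rejects dataloaders that were built by passing a `BatchSampler` through the `sampler=` argument; the construction that remains supported routes it through `batch_sampler=`. A short sketch of both patterns (illustrative only, not part of the commit):

# Illustrative sketch: the DataLoader construction the new check expects.
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

dataset = list(range(16))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# Supported: the BatchSampler goes through batch_sampler=, so loader.sampler stays a plain sampler.
loader = DataLoader(dataset, batch_sampler=batch_sampler)

# Legacy pattern now rejected by prepare_data_loader: the BatchSampler passed as sampler=.
# loader = DataLoader(dataset, sampler=batch_sampler, batch_size=None)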
@@ -1208,9 +1207,8 @@ def prepare_data_loader(
                 seed = int(torch.empty((), dtype=torch.int64).random_().item())
                 sampler.generator.manual_seed(seed)
                 synchronized_generator = sampler.generator
-            batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
             new_batch_sampler = BatchSamplerShard(
-                batch_sampler,
+                dataloader.batch_sampler,
                 num_processes=num_processes,
                 process_index=process_index,
                 split_batches=split_batches,
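With the legacy branch gone, `BatchSamplerShard` always wraps `dataloader.batch_sampler`. As a rough illustration (assuming `accelerate` is installed and `BatchSamplerShard` keeps the keyword arguments shown in the hunk above), this is how a batch sampler is sharded across two processes:

# Rough illustration only: shard an index batch sampler across two processes.
from torch.utils.data import BatchSampler, SequentialSampler

from accelerate.data_loader import BatchSamplerShard

batch_sampler = BatchSampler(SequentialSampler(range(16)), batch_size=4, drop_last=False)
for process_index in range(2):
    shard = BatchSamplerShard(batch_sampler, num_processes=2, process_index=process_index, split_batches=False)
    print(process_index, list(shard))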
@@ -1254,19 +1252,6 @@ def prepare_data_loader(
             torch_device_mesh=torch_device_mesh,
             **kwargs,
         )
-    elif sampler_is_batch_sampler:
-        dataloader = DataLoaderShard(
-            new_dataset,
-            device=device if put_on_device and state.distributed_type != DistributedType.XLA else None,
-            sampler=new_batch_sampler,
-            batch_size=dataloader.batch_size,
-            rng_types=rng_types,
-            _drop_last=dataloader.drop_last,
-            _non_blocking=non_blocking,
-            synchronized_generator=synchronized_generator,
-            use_stateful_dataloader=use_stateful_dataloader,
-            **kwargs,
-        )
     else:
         dataloader = DataLoaderShard(
             new_dataset,
@@ -1361,13 +1346,15 @@ def skip_first_batches(dataloader, num_batches=0):
         dataloader = dataloader.dataloader
 
     dataset = dataloader.dataset
-    sampler_is_batch_sampler = False
+    if isinstance(dataloader.sampler, BatchSampler):
+        raise ValueError(
+            "Should not pass a BatchSampler to the dataloader sampler argument. As per the latest pytorch documentation, please pass this to batch_sampler instead"
+        )
+
     if isinstance(dataset, IterableDataset):
         new_batch_sampler = None
     else:
-        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
-        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
-        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
+        new_batch_sampler = SkipBatchSampler(dataloader.batch_sampler, skip_batches=num_batches)
 
     # We ignore all of those since they are all dealt with by our new_batch_sampler
     ignore_kwargs = [
@@ -1404,9 +1391,6 @@ def skip_first_batches(dataloader, num_batches=0):
         if new_batch_sampler is None:
             # Need to manually skip batches in the dataloader
            kwargs["skip_batches"] = num_batches
-        elif sampler_is_batch_sampler:
-            kwargs["sampler"] = new_batch_sampler
-            kwargs["batch_size"] = dataloader.batch_size
         else:
             kwargs["batch_sampler"] = new_batch_sampler
         dataloader = DataLoaderShard(
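A hedged usage sketch for `skip_first_batches` after this change, assuming a non-distributed setup and a standard `DataLoader` whose batching lives in `dataloader.batch_sampler` (the only construction the new check permits):

# Hedged usage sketch: resume mid-epoch with skip_first_batches on a plain DataLoader.
import torch
from torch.utils.data import DataLoader

from accelerate.data_loader import skip_first_batches

loader = DataLoader(torch.arange(12), batch_size=4)
resumed = skip_first_batches(loader, num_batches=1)
print([batch.tolist() for batch in resumed])  # the first batch [0, 1, 2, 3] is skipped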