
Commit f2fc43b

Revert "refactor: mv all dataloaders to engine.dataloader_dict"
This reverts commit 284e2a6.
1 parent a1e840e · commit f2fc43b

7 files changed: +47 additions, -60 deletions

ppcls/data/__init__.py

Lines changed: 2 additions & 31 deletions
@@ -187,37 +187,10 @@ def mix_collate_fn(batch):
         collate_fn=batch_collate_fn,
         worker_init_fn=init_fn)
 
-    total_samples = len(
-        data_loader.dataset) if not use_dali else data_loader.size
-    max_iter = len(data_loader) - 1 if platform.system() == "Windows" else len(
-        data_loader)
-    data_loader.max_iter = max_iter
-    data_loader.total_samples = total_samples
-
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
 
 
-# TODO(gaotingquan): perf
-class DataIterator(object):
-    def __init__(self, dataloader, use_dali=False):
-        self.dataloader = dataloader
-        self.use_dali = use_dali
-        self.iterator = iter(dataloader)
-
-    def get_batch(self):
-        # fetch data batch from dataloader
-        try:
-            batch = next(self.iterator)
-        except Exception:
-            # NOTE: reset DALI dataloader manually
-            if self.use_dali:
-                self.dataloader.reset()
-            self.iterator = iter(self.dataloader)
-            batch = next(self.iterator)
-        return batch
-
-
 def build_dataloader(engine):
     if "class_num" in engine.config["Global"]:
         global_class_num = engine.config["Global"]["class_num"]

@@ -249,15 +222,12 @@ def build_dataloader(engine):
     iter_per_epoch = len(train_dataloader) - 1 if platform.system(
     ) == "Windows" else len(train_dataloader)
     if engine.config["Global"].get("iter_per_epoch", None):
-        # TODO(gaotingquan): iter_per_epoch should be set in Dataloader.Train, not Global
         # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
         iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
     iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
-    # engine.iter_per_epoch = iter_per_epoch
+    engine.iter_per_epoch = iter_per_epoch
     train_dataloader.iter_per_epoch = iter_per_epoch
     dataloader_dict["Train"] = train_dataloader
-    # TODO(gaotingquan): set the iterator field in config, such as Dataloader.Train.convert_iterator=True
-    dataloader_dict["TrainIter"] = DataIterator(train_dataloader, use_dali)
 
     if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
         dataloader_dict["UnLabelTrain"] = build(

@@ -279,4 +249,5 @@ def build_dataloader(engine):
         engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
     dataloader_dict["Query"] = build(
         engine.config["DataLoader"]["Eval"], "Query", use_dali)
+
     return dataloader_dict
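
With this revert, build() no longer attaches total_samples/max_iter to the dataloader, and the DataIterator wrapper plus the "TrainIter" entry are gone; engine.iter_per_epoch is set again and is still truncated to a multiple of update_freq. A minimal, self-contained sketch of that truncation (round_iters is a hypothetical helper name, the numbers are invented):

def round_iters(iter_per_epoch, update_freq):
    # floor iter_per_epoch to a multiple of update_freq, mirroring
    # `iter_per_epoch // engine.update_freq * engine.update_freq` above
    return iter_per_epoch // update_freq * update_freq

assert round_iters(1003, 4) == 1000  # 3 trailing iterations are dropped
assert round_iters(1000, 1) == 1000  # update_freq == 1 leaves it unchanged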

ppcls/engine/engine.py

Lines changed: 10 additions & 2 deletions
@@ -63,7 +63,7 @@ def __init__(self, config, mode="train"):
 
         # init train_func and eval_func
         self.train_epoch_func = build_train_epoch_func(self.config)
-        self.eval_func = build_eval_func(self.config)
+        self.eval_epoch_func = build_eval_func(self.config)
 
         # set device
         self._init_device()

@@ -73,6 +73,12 @@ def __init__(self, config, mode="train"):
 
         # build dataloader
         self.dataloader_dict = build_dataloader(self)
+        self.train_dataloader, self.unlabel_train_dataloader, self.eval_dataloader = self.dataloader_dict[
+            "Train"], self.dataloader_dict[
+                "UnLabelTrain"], self.dataloader_dict["Eval"]
+        self.gallery_query_dataloader, self.gallery_dataloader, self.query_dataloader = self.dataloader_dict[
+            "GalleryQuery"], self.dataloader_dict[
+                "Gallery"], self.dataloader_dict["Query"]
 
         # build loss
         self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(

@@ -88,7 +94,9 @@ def __init__(self, config, mode="train"):
             self._init_pretrained()
 
         # build optimizer
-        self.optimizer, self.lr_sch = build_optimizer(self)
+        self.optimizer, self.lr_sch = build_optimizer(
+            self.config, self.train_dataloader,
+            [self.model, self.train_loss_func])
 
         # AMP training and evaluating
        self._init_amp()
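
Besides renaming eval_func back to eval_epoch_func, the Engine once again exposes each dataloader as a flat attribute and passes [self.model, self.train_loss_func] to build_optimizer, presumably so any trainable parameters inside the loss are optimized as well. A self-contained sketch of the dict-to-attribute unpacking (keys follow the diff; EngineSketch and the stand-in values are invented):

class EngineSketch:
    def __init__(self, dataloader_dict):
        # mirror of the unpacking restored in Engine.__init__
        self.train_dataloader = dataloader_dict["Train"]
        self.unlabel_train_dataloader = dataloader_dict["UnLabelTrain"]
        self.eval_dataloader = dataloader_dict["Eval"]
        self.gallery_query_dataloader = dataloader_dict["GalleryQuery"]
        self.gallery_dataloader = dataloader_dict["Gallery"]
        self.query_dataloader = dataloader_dict["Query"]

loaders = {k: None for k in
           ["Train", "UnLabelTrain", "Eval", "GalleryQuery", "Gallery", "Query"]}
loaders["Train"], loaders["Eval"] = "train_loader", "eval_loader"
engine = EngineSketch(loaders)
assert engine.train_dataloader == "train_loader"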

ppcls/engine/evaluation/classification.py

Lines changed: 13 additions & 10 deletions
@@ -35,10 +35,13 @@ def classification_eval(engine, epoch_id=0):
     print_batch_step = engine.config["Global"]["print_batch_step"]
 
     tic = time.time()
-    total_samples = engine.dataloader_dict["Eval"].total_samples
     accum_samples = 0
-    max_iter = engine.dataloader_dict["Eval"].max_iter
-    for iter_id, batch in enumerate(engine.dataloader_dict["Eval"]):
+    total_samples = len(
+        engine.eval_dataloader.
+        dataset) if not engine.use_dali else engine.eval_dataloader.size
+    max_iter = len(engine.eval_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(engine.eval_dataloader)
+    for iter_id, batch in enumerate(engine.eval_dataloader):
         if iter_id >= max_iter:
             break
         if iter_id == 5:

@@ -58,9 +61,9 @@ def classification_eval(engine, epoch_id=0):
                     "flatten_contiguous_range", "greater_than"
                 },
                 level=engine.amp_level):
-            out = engine.model(batch)
+            out = engine.model(batch[0])
         else:
-            out = engine.model(batch)
+            out = engine.model(batch[0])
 
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()

@@ -92,8 +95,7 @@ def classification_eval(engine, epoch_id=0):
                 paddle.distributed.all_gather(pred_list, out)
                 preds = paddle.concat(pred_list, 0)
 
-            if accum_samples > total_samples and not engine.config[
-                    "Global"].get("use_dali", False):
+            if accum_samples > total_samples and not engine.use_dali:
                 if isinstance(preds, list):
                     preds = [
                         pred[:total_samples + current_samples - accum_samples]

@@ -149,11 +151,12 @@ def classification_eval(engine, epoch_id=0):
             ])
             metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
             logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
-                epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))
+                epoch_id, iter_id,
+                len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
 
         tic = time.time()
-    if engine.config["Global"].get("use_dali", False):
-        engine.dataloader_dict["Eval"].reset()
+    if engine.use_dali:
+        engine.eval_dataloader.reset()
 
     if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
         metric_msg = ", ".join([
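
classification_eval now computes total_samples and max_iter inline from engine.eval_dataloader instead of reading precomputed attributes, and feeds the model batch[0] rather than the whole batch. A self-contained sketch of the restored bounds computation (SizedLoader is an invented stand-in for a real dataloader):

import platform

class SizedLoader:
    # invented stand-in: only the attributes the computation touches
    def __init__(self, n_batches, n_samples):
        self._n_batches = n_batches
        self.dataset = list(range(n_samples))
    def __len__(self):
        return self._n_batches

def eval_bounds(loader, use_dali=False):
    # DALI loaders report their size directly instead of exposing a dataset
    total_samples = len(loader.dataset) if not use_dali else loader.size
    # the diff keeps the Windows special case of skipping the last batch
    max_iter = len(loader) - 1 if platform.system() == "Windows" else len(loader)
    return total_samples, max_iter

print(eval_bounds(SizedLoader(n_batches=40, n_samples=1250)))  # e.g. (1250, 40)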

ppcls/engine/train/regular_train_epoch.py

Lines changed: 13 additions & 2 deletions
@@ -22,8 +22,19 @@
 def regular_train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
 
-    for iter_id in range(engine.dataloader_dict["Train"].iter_per_epoch):
-        batch = engine.dataloader_dict["TrainIter"].get_batch()
+    if not hasattr(engine, "train_dataloader_iter"):
+        engine.train_dataloader_iter = iter(engine.train_dataloader)
+
+    for iter_id in range(engine.iter_per_epoch):
+        # fetch data batch from dataloader
+        try:
+            batch = next(engine.train_dataloader_iter)
+        except Exception:
+            # NOTE: reset DALI dataloader manually
+            if engine.use_dali:
+                engine.train_dataloader.reset()
+            engine.train_dataloader_iter = iter(engine.train_dataloader)
+            batch = next(engine.train_dataloader_iter)
 
         profiler.add_profiler_step(engine.config["profiler_options"])
         if iter_id == 5:
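
The epoch loop goes back to driving an explicit iterator and rebuilding it when it runs dry, instead of delegating to the removed DataIterator. A runnable sketch of that fetch-and-restart pattern (ToyLoader is an invented stand-in; a real DALI loader would additionally need reset(), noted in the comment):

class ToyLoader:
    def __init__(self, batches):
        self.batches = batches
    def __iter__(self):
        return iter(self.batches)

loader = ToyLoader([[0], [1], [2]])
loader_iter = iter(loader)
seen = []
for iter_id in range(5):          # iter_per_epoch > len(loader)
    try:
        batch = next(loader_iter)
    except StopIteration:
        # NOTE: a DALI loader would need loader.reset() here
        loader_iter = iter(loader)
        batch = next(loader_iter)
    seen.append(batch[0])
print(seen)                       # [0, 1, 2, 0, 1]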

ppcls/engine/train/utils.py

Lines changed: 5 additions & 6 deletions
@@ -54,14 +54,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
     ips_msg = "ips: {:.5f} samples/s".format(
         batch_size / trainer.time_info["batch_cost"].avg)
 
-    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
-               ) * trainer.dataloader_dict["Train"].iter_per_epoch - iter_id
-               ) * trainer.time_info["batch_cost"].avg
+    eta_sec = (
+        (trainer.config["Global"]["epochs"] - epoch_id + 1) *
+        trainer.iter_per_epoch - iter_id) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
-        epoch_id, trainer.config["Global"][
-            "epochs"], iter_id, trainer.dataloader_dict["Train"]
-        .iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
+        epoch_id, trainer.config["Global"]["epochs"], iter_id, trainer.
+        iter_per_epoch, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
 
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(
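
log_info reads iter_per_epoch from the trainer again rather than from dataloader_dict["Train"]; the ETA arithmetic itself is unchanged. A worked example with invented numbers:

import datetime

epochs, epoch_id, iter_per_epoch, iter_id = 120, 3, 500, 120
avg_batch_cost = 0.25  # seconds, i.e. trainer.time_info["batch_cost"].avg
eta_sec = ((epochs - epoch_id + 1) * iter_per_epoch - iter_id) * avg_batch_cost
print("eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec)))))
# eta: 4:05:20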

ppcls/metric/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -70,9 +70,8 @@ def build_metrics(engine):
     if mode == 'train' and "Metric" in config and "Train" in config[
             "Metric"] and config["Metric"]["Train"]:
         metric_config = config["Metric"]["Train"]
-        if hasattr(engine.dataloader_dict["Train"],
-                   "collate_fn") and engine.dataloader_dict[
-                       "Train"].collate_fn is not None:
+        if hasattr(engine.train_dataloader, "collate_fn"
+                   ) and engine.train_dataloader.collate_fn is not None:
             for m_idx, m in enumerate(metric_config):
                 if "TopkAcc" in m:
                     msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."

ppcls/optimizer/__init__.py

Lines changed: 2 additions & 6 deletions
@@ -45,15 +45,11 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 
 
 # model_list is None in static graph
-def build_optimizer(engine):
-    if engine.mode != "train":
-        return None, None
-    config, iter_per_epoch, model_list = engine.config, engine.dataloader_dict[
-        "Train"].iter_per_epoch, [engine.mode, engine.train_loss_func]
+def build_optimizer(config, dataloader, model_list=None):
     optim_config = copy.deepcopy(config["Optimizer"])
     epochs = config["Global"]["epochs"]
     update_freq = config["Global"].get("update_freq", 1)
-    step_each_epoch = iter_per_epoch // update_freq
+    step_each_epoch = dataloader.iter_per_epoch // update_freq
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
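
build_optimizer regains its (config, dataloader, model_list=None) signature and derives step_each_epoch from the dataloader's iter_per_epoch attribute set in ppcls/data/__init__.py. A small sketch of that derivation (DummyLoader and the numbers are invented):

class DummyLoader:
    iter_per_epoch = 1000  # attribute attached by build_dataloader()

update_freq = 4            # from config["Global"].get("update_freq", 1)
step_each_epoch = DummyLoader.iter_per_epoch // update_freq
print(step_each_epoch)     # 250 LR-scheduler steps per epoch under accumulation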
