Skip to content

Commit e7e4f68

Browse files
committed
Revert "refactor: build_train_func & build_eval_func"
This reverts commit 6bed0f5.
1 parent 6245b64 commit e7e4f68

File tree

6 files changed

+57
-51
lines changed

6 files changed

+57
-51
lines changed

ppcls/data/__init__.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
8888
random.seed(worker_seed)
8989

9090

91-
def build(config, mode, use_dali=False, seed=None):
91+
def build(config, mode, device, use_dali=False, seed=None):
9292
assert mode in [
9393
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
9494
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
@@ -167,7 +167,7 @@ def mix_collate_fn(batch):
167167
if batch_sampler is None:
168168
data_loader = DataLoader(
169169
dataset=dataset,
170-
places=paddle.device.get_device(),
170+
places=device,
171171
num_workers=num_workers,
172172
return_list=True,
173173
use_shared_memory=use_shared_memory,
@@ -179,7 +179,7 @@ def mix_collate_fn(batch):
179179
else:
180180
data_loader = DataLoader(
181181
dataset=dataset,
182-
places=paddle.device.get_device(),
182+
places=device,
183183
num_workers=num_workers,
184184
return_list=True,
185185
use_shared_memory=use_shared_memory,
@@ -218,7 +218,11 @@ def build_dataloader(engine):
218218
}
219219
if engine.mode == 'train':
220220
train_dataloader = build(
221-
engine.config["DataLoader"], "Train", use_dali, seed=None)
221+
engine.config["DataLoader"],
222+
"Train",
223+
engine.device,
224+
use_dali,
225+
seed=None)
222226
iter_per_epoch = len(train_dataloader) - 1 if platform.system(
223227
) == "Windows" else len(train_dataloader)
224228
if engine.config["Global"].get("iter_per_epoch", None):
@@ -231,23 +235,33 @@ def build_dataloader(engine):
231235

232236
if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
233237
dataloader_dict["UnLabelTrain"] = build(
234-
engine.config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
238+
engine.config["DataLoader"],
239+
"UnLabelTrain",
240+
engine.device,
241+
use_dali,
242+
seed=None)
235243

236244
if engine.mode == "eval" or (engine.mode == "train" and
237245
engine.config["Global"]["eval_during_train"]):
238-
if engine.config["Global"][
239-
"eval_mode"] in ["classification", "adaface"]:
246+
if engine.eval_mode in ["classification", "adaface"]:
240247
dataloader_dict["Eval"] = build(
241-
engine.config["DataLoader"], "Eval", use_dali, seed=None)
242-
elif engine.config["Global"]["eval_mode"] == "retrieval":
248+
engine.config["DataLoader"],
249+
"Eval",
250+
engine.device,
251+
use_dali,
252+
seed=None)
253+
elif engine.eval_mode == "retrieval":
243254
if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
244255
key = list(engine.config["DataLoader"]["Eval"].keys())[0]
245-
dataloader_dict["GalleryQuery"] = build(
246-
engine.config["DataLoader"]["Eval"], key, use_dali)
256+
dataloader_dict["GalleryQuery"] = build_dataloader(
257+
engine.config["DataLoader"]["Eval"], key, engine.device,
258+
use_dali)
247259
else:
248-
dataloader_dict["Gallery"] = build(
249-
engine.config["DataLoader"]["Eval"], "Gallery", use_dali)
250-
dataloader_dict["Query"] = build(
251-
engine.config["DataLoader"]["Eval"], "Query", use_dali)
260+
dataloader_dict["Gallery"] = build_dataloader(
261+
engine.config["DataLoader"]["Eval"], "Gallery",
262+
engine.device, use_dali)
263+
dataloader_dict["Query"] = build_dataloader(
264+
engine.config["DataLoader"]["Eval"], "Query",
265+
engine.device, use_dali)
252266

253267
return dataloader_dict

ppcls/engine/engine.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@
3939
from ppcls.data.utils.get_image_list import get_image_list
4040
from ppcls.data.postprocess import build_postprocess
4141
from ppcls.data import create_operators
42-
from .train import build_train_epoch_func
43-
from .evaluation import build_eval_func
42+
from ppcls.engine import train as train_method
4443
from ppcls.engine.train.utils import type_name
4544
from ppcls.engine import evaluation
4645
from ppcls.arch.gears.identity_head import IdentityHead
@@ -62,11 +61,22 @@ def __init__(self, config, mode="train"):
6261
self.vdl_writer = self._init_vdl()
6362

6463
# init train_func and eval_func
65-
self.train_epoch_func = build_train_epoch_func(self.config)
66-
self.eval_epoch_func = build_eval_func(self.config)
64+
self.train_mode = self.config["Global"].get("train_mode", None)
65+
if self.train_mode is None:
66+
self.train_epoch_func = train_method.train_epoch
67+
else:
68+
self.train_epoch_func = getattr(train_method,
69+
"train_epoch_" + self.train_mode)
70+
71+
self.eval_mode = self.config["Global"].get("eval_mode",
72+
"classification")
73+
assert self.eval_mode in [
74+
"classification", "retrieval", "adaface"
75+
], logger.error("Invalid eval mode: {}".format(self.eval_mode))
76+
self.eval_func = getattr(evaluation, self.eval_mode + "_eval")
6777

6878
# set device
69-
self._init_device()
79+
self.device = self._init_device()
7080

7181
# gradient accumulation
7282
self.update_freq = self.config["Global"].get("update_freq", 1)
@@ -385,7 +395,7 @@ def _init_device(self):
385395
assert device in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
386396
logger.info('train with paddle {} and device {}'.format(
387397
paddle.__version__, device))
388-
paddle.set_device(device)
398+
return paddle.set_device(device)
389399

390400
def _init_pretrained(self):
391401
if self.config["Global"]["pretrained_model"] is not None:

ppcls/engine/evaluation/__init__.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .classification import classification_eval
16-
from .retrieval import retrieval_eval
17-
from .adaface import adaface_eval
18-
19-
20-
def build_eval_func(config):
21-
eval_mode = config["Global"].get("eval_mode", None)
22-
if eval_mode is None:
23-
config["Global"]["eval_mode"] = "classification"
24-
return classification_eval
25-
else:
26-
return getattr(sys.modules[__name__], eval_mode + "_eval")
15+
from ppcls.engine.evaluation.classification import classification_eval
16+
from ppcls.engine.evaluation.retrieval import retrieval_eval
17+
from ppcls.engine.evaluation.adaface import adaface_eval

ppcls/engine/train/__init__.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
from .train_metabin import train_epoch_metabin
16-
from .regular_train_epoch import regular_train_epoch
17-
from .train_fixmatch import train_epoch_fixmatch
18-
from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
19-
from .train_progressive import train_epoch_progressive
20-
21-
22-
def build_train_epoch_func(config):
23-
train_mode = config["Global"].get("train_mode", None)
24-
if train_mode is None:
25-
config["Global"]["train_mode"] = "regular_train"
26-
return regular_train_epoch
27-
else:
28-
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)
14+
from ppcls.engine.train.train import train_epoch
15+
from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch
16+
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
17+
from ppcls.engine.train.train_progressive import train_epoch_progressive
18+
from ppcls.engine.train.train_metabin import train_epoch_metabin

ppcls/engine/train/regular_train_epoch.py renamed to ppcls/engine/train/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ppcls.utils import profiler
2020

2121

22-
def regular_train_epoch(engine, epoch_id, print_batch_step):
22+
def train_epoch(engine, epoch_id, print_batch_step):
2323
tic = time.time()
2424

2525
if not hasattr(engine, "train_dataloader_iter"):

ppcls/engine/train/train_progressive.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from ppcls.data import build_dataloader
1717
from ppcls.engine.train.utils import type_name
1818
from ppcls.utils import logger
19-
from .regular_train_epoch import regular_train_epoch
19+
20+
from .train import train_epoch
2021

2122

2223
def train_epoch_progressive(engine, epoch_id, print_batch_step):
@@ -68,4 +69,4 @@ def _change_dp_func(m):
6869
f")")
6970

7071
# 3. Train one epoch as usual at current stage
71-
regular_train_epoch(engine, epoch_id, print_batch_step)
72+
train_epoch(engine, epoch_id, print_batch_step)

0 commit comments

Comments
 (0)