Commit 6aabb94

Revert "refactor: add ClassModel to unify model forward interface"
This reverts commit 75a20ba.
1 parent e7e4f68 · commit 6aabb94

3 files changed: 36 additions (+), 47 deletions (−)


ppcls/arch/__init__.py

Lines changed: 16 additions & 43 deletions
@@ -12,14 +12,14 @@
 #See the License for the specific language governing permissions and
 #limitations under the License.

-import sys
 import copy
-
+import importlib
 import paddle.nn as nn
 from paddle.jit import to_static
 from paddle.static import InputSpec

-from . import backbone as backbone_zoo
+from . import backbone, gears
+from .backbone import *
 from .gears import build_gear
 from .utils import *
 from .backbone.base.theseus_layer import TheseusLayer
@@ -35,28 +35,20 @@ def build_model(config, mode="train"):
     arch_config = copy.deepcopy(config["Arch"])
     model_type = arch_config.pop("name")
     use_sync_bn = arch_config.pop("use_sync_bn", False)
-
-    if hasattr(backbone_zoo, model_type):
-        model = ClassModel(model_type, **arch_config)
-    else:
-        model = getattr(sys.modules[__name__], model_type)("ClassModel",
-                                                           **arch_config)
-
+    mod = importlib.import_module(__name__)
+    arch = getattr(mod, model_type)(**arch_config)
     if use_sync_bn:
         if config["Global"]["device"] == "gpu":
-            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
         else:
             msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored."
             logger.warning(msg)

-    if isinstance(model, TheseusLayer):
-        prune_model(config, model)
-        quantize_model(config, model, mode)
+    if isinstance(arch, TheseusLayer):
+        prune_model(config, arch)
+        quantize_model(config, arch, mode)

-    # set @to_static for benchmark, skip this by default.
-    model = apply_to_static(config, model)
-
-    return model
+    return arch


 def apply_to_static(config, model):
@@ -73,29 +65,12 @@ def apply_to_static(config, model):
     return model


-# TODO(gaotingquan): export model
-class ClassModel(TheseusLayer):
-    def __init__(self, model_type, **config):
-        super().__init__()
-        if model_type == "ClassModel":
-            backbone_config = config["Backbone"]
-            backbone_name = backbone_config.pop("name")
-        else:
-            backbone_name = model_type
-            backbone_config = config
-        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
-
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
-        return self.backbone(x)
-
-
 class RecModel(TheseusLayer):
     def __init__(self, **config):
         super().__init__()
         backbone_config = config["Backbone"]
         backbone_name = backbone_config.pop("name")
-        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
+        self.backbone = eval(backbone_name)(**backbone_config)
         self.head_feature_from = config.get('head_feature_from', 'neck')

         if "BackboneStopLayer" in config:
@@ -112,8 +87,8 @@ def __init__(self, **config):
         else:
             self.head = None

-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
+
         out = dict()
         x = self.backbone(x)
         out["backbone"] = x
@@ -165,8 +140,7 @@ def __init__(self,
                    load_dygraph_pretrain(
                        self.model_name_list[idx], path=pretrained)

-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         result_dict = dict()
         for idx, model_name in enumerate(self.model_name_list):
             if label is None:
@@ -184,8 +158,7 @@ def __init__(self,
                  **kargs):
         super().__init__(models, pretrained_list, freeze_params_list, **kargs)

-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         result_dict = dict()
         out = x
         for idx, model_name in enumerate(self.model_name_list):
@@ -195,4 +168,4 @@ def forward(self, batch):
             else:
                 out = self.model_list[idx](out, label)
             result_dict.update(out)
-        return result_dict
+        return result_dict
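For context on the restored build_model: it resolves the architecture class by name from whatever the ppcls.arch module namespace exposes (backbones pulled in via "from .backbone import *", plus RecModel and DistillationModel). A minimal, standalone sketch of that lookup pattern follows; the module and class names in it are placeholders, not PaddleClas's actual registry:

import importlib


def build_by_name(module_name, class_name, **kwargs):
    # Same pattern as the restored build_model: import the module, resolve
    # the class from its attributes by name, then instantiate it with the
    # remaining config entries as keyword arguments.
    mod = importlib.import_module(module_name)
    cls = getattr(mod, class_name)
    return cls(**kwargs)


# Hypothetical usage ("my_models" and "TinyNet" are placeholders):
#   net = build_by_name("my_models", "TinyNet", num_classes=10)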

ppcls/engine/engine.py

Lines changed: 11 additions & 2 deletions
@@ -28,6 +28,7 @@
 from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
+from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
@@ -56,10 +57,18 @@ def __init__(self, config, mode="train"):

         # init logger
         init_logger(self.config, mode=mode)
+        print_config(config)

         # for visualdl
         self.vdl_writer = self._init_vdl()

+        # is_rec
+        if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
+                                                                    False):
+            self.is_rec = True
+        else:
+            self.is_rec = False
+
         # init train_func and eval_func
         self.train_mode = self.config["Global"].get("train_mode", None)
         if self.train_mode is None:
@@ -99,6 +108,8 @@ def __init__(self, config, mode="train"):

         # build model
         self.model = build_model(self.config, self.mode)
+        # set @to_static for benchmark, skip this by default.
+        apply_to_static(self.config, self.model)

         # load_pretrain
         self._init_pretrained()
@@ -114,8 +125,6 @@ def __init__(self, config, mode="train"):
         # for distributed
         self._init_dist()

-        print_config(config)
-
     def train(self):
         assert self.mode == "train"
         print_batch_step = self.config['Global']['print_batch_step']
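A note on the is_rec flag added above: it only inspects the Arch section of the config. A small sketch of the same check with illustrative config dicts (the backbone and head names are made up for the example):

def is_rec_arch(arch_cfg):
    # Same condition as in Engine.__init__: a "Head" entry or an explicit
    # is_rec=True marks the model as a recognition/retrieval architecture.
    return "Head" in arch_cfg or arch_cfg.get("is_rec", False)


rec_arch = {"name": "RecModel",
            "Backbone": {"name": "SomeBackbone"},
            "Head": {"name": "SomeMarginHead"}}
cls_arch = {"name": "SomeBackbone", "class_num": 1000}

print(is_rec_arch(rec_arch))  # True  -> the model will receive (images, labels)
print(is_rec_arch(cls_arch))  # False -> the model will receive images only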

ppcls/engine/train/train.py

Lines changed: 9 additions & 2 deletions
@@ -55,10 +55,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=amp_level):
-                out = engine.model(batch)
+                out = forward(engine, batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
         else:
-            out = engine.model(batch)
+            out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])

         # loss
@@ -104,3 +104,10 @@ def train_epoch(engine, epoch_id, print_batch_step):
         if getattr(engine.lr_sch[i], "by_epoch", False) and \
                 type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
             engine.lr_sch[i].step()
+
+
+def forward(engine, batch):
+    if not engine.is_rec:
+        return engine.model(batch[0])
+    else:
+        return engine.model(batch[0], batch[1])
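To make the dispatch concrete, here is a self-contained sketch of the two call paths using toy stand-in models and engine; none of these classes are PaddleClas code, they only mimic the restored forward signatures:

import paddle


class ToyClsModel(paddle.nn.Layer):
    # After the revert, classification models are called with images only.
    def forward(self, x):
        return {"logits": paddle.mean(x)}


class ToyRecModel(paddle.nn.Layer):
    # Rec models keep the restored forward(self, x, label=None) signature.
    def forward(self, x, label=None):
        return {"features": paddle.mean(x), "label": label}


class ToyEngine:
    def __init__(self, model, is_rec):
        self.model = model
        self.is_rec = is_rec


def forward(engine, batch):
    # Mirrors the helper added to ppcls/engine/train/train.py above.
    if not engine.is_rec:
        return engine.model(batch[0])
    else:
        return engine.model(batch[0], batch[1])


batch = (paddle.rand([4, 3, 32, 32]), paddle.to_tensor([0, 1, 2, 3]))
print(forward(ToyEngine(ToyClsModel(), is_rec=False), batch))
print(forward(ToyEngine(ToyRecModel(), is_rec=True), batch))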
