@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import copy
-
+import importlib
 import paddle.nn as nn
 from paddle.jit import to_static
 from paddle.static import InputSpec
 
-from . import backbone as backbone_zoo
+from . import backbone, gears
+from .backbone import *
 from .gears import build_gear
 from .utils import *
 from .backbone.base.theseus_layer import TheseusLayer
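The import hunk above swaps the explicit `backbone_zoo` alias for a star import plus an `importlib`-based self-lookup. A minimal standalone sketch of that resolution pattern (the `RecModel` stub and `resolve` helper here are illustrative, not names from this patch):

```python
# Standalone sketch of the getattr-on-own-module lookup used below.
import importlib


class RecModel:
    """Stand-in for an architecture class star-imported at module level."""

    def __init__(self, **config):
        self.config = config


def resolve(model_type):
    # importlib.import_module(__name__) returns *this* module object, so any
    # name bound at module scope (including star-imported backbones) is
    # reachable via getattr -- no sys.modules[__name__] needed.
    mod = importlib.import_module(__name__)
    return getattr(mod, model_type)


if __name__ == "__main__":
    arch = resolve("RecModel")(Backbone={"name": "ResNet50"})
    print(type(arch).__name__)  # RecModel
```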
@@ -35,28 +35,20 @@ def build_model(config, mode="train"):
     arch_config = copy.deepcopy(config["Arch"])
     model_type = arch_config.pop("name")
     use_sync_bn = arch_config.pop("use_sync_bn", False)
-
-    if hasattr(backbone_zoo, model_type):
-        model = ClassModel(model_type, **arch_config)
-    else:
-        model = getattr(sys.modules[__name__], model_type)("ClassModel",
-                                                           **arch_config)
-
+    mod = importlib.import_module(__name__)
+    arch = getattr(mod, model_type)(**arch_config)
     if use_sync_bn:
         if config["Global"]["device"] == "gpu":
-            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            arch = nn.SyncBatchNorm.convert_sync_batchnorm(arch)
         else:
             msg = "SyncBatchNorm can only be used on GPU device. The related setting has been ignored."
             logger.warning(msg)
 
-    if isinstance(model, TheseusLayer):
-        prune_model(config, model)
-        quantize_model(config, model, mode)
+    if isinstance(arch, TheseusLayer):
+        prune_model(config, arch)
+        quantize_model(config, arch, mode)
 
-    # set @to_static for benchmark, skip this by default.
-    model = apply_to_static(config, model)
-
-    return model
+    return arch
 
 
 def apply_to_static(config, model):
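With this hunk `build_model` drops the `ClassModel` wrapper path: every `Arch.name` must now resolve to a class bound in this module. Note that the `apply_to_static` call is also removed from `build_model`, so static-graph conversion presumably happens at the call site now. A hypothetical config sketch matching the keys the function reads (`ResNet50` is illustrative):

```python
# Hypothetical config; the key layout mirrors what build_model() reads.
config = {
    "Global": {"device": "gpu"},           # gates the SyncBatchNorm branch
    "Arch": {
        "name": "RecModel",                # resolved via getattr on this module
        "use_sync_bn": True,               # popped before construction
        "Backbone": {"name": "ResNet50"},  # consumed by RecModel.__init__
    },
}
model = build_model(config, mode="train")
```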
@@ -73,29 +65,12 @@ def apply_to_static(config, model):
     return model
 
 
-# TODO(gaotingquan): export model
-class ClassModel(TheseusLayer):
-    def __init__(self, model_type, **config):
-        super().__init__()
-        if model_type == "ClassModel":
-            backbone_config = config["Backbone"]
-            backbone_name = backbone_config.pop("name")
-        else:
-            backbone_name = model_type
-            backbone_config = config
-        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
-
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
-        return self.backbone(x)
-
-
 class RecModel(TheseusLayer):
     def __init__(self, **config):
         super().__init__()
         backbone_config = config["Backbone"]
         backbone_name = backbone_config.pop("name")
-        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
+        self.backbone = eval(backbone_name)(**backbone_config)
         self.head_feature_from = config.get('head_feature_from', 'neck')
 
         if "BackboneStopLayer" in config:
@@ -112,8 +87,8 @@ def __init__(self, **config):
         else:
             self.head = None
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
+
         out = dict()
         x = self.backbone(x)
         out["backbone"] = x
@@ -165,8 +140,7 @@ def __init__(self,
                 load_dygraph_pretrain(
                     self.model_name_list[idx], path=pretrained)
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         result_dict = dict()
         for idx, model_name in enumerate(self.model_name_list):
             if label is None:
@@ -184,8 +158,7 @@ def __init__(self,
                  **kargs):
         super().__init__(models, pretrained_list, freeze_params_list, **kargs)
 
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
+    def forward(self, x, label=None):
         result_dict = dict()
         out = x
         for idx, model_name in enumerate(self.model_name_list):
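Both wrapper classes now share the same label-optional dispatch in their loops. A minimal standalone sketch of that pattern (`run_submodels` and the lambda models are illustrative stand-ins for `self.model_list`, not code from this patch):

```python
def run_submodels(model_list, x, label=None):
    result_dict = dict()
    for model in model_list:
        # forward the label only when the caller supplied one
        out = model(x) if label is None else model(x, label)
        result_dict.update(out)
    return result_dict


# toy usage: each "model" returns a dict, mirroring the wrappers above
models = [lambda x, label=None: {"Student": x},
          lambda x, label=None: {"Teacher": x}]
print(run_submodels(models, x=1.0))  # {'Student': 1.0, 'Teacher': 1.0}
```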
@@ -195,4 +168,4 @@ def forward(self, batch):
             else:
                 out = self.model_list[idx](out, label)
             result_dict.update(out)
-        return result_dict
+        return result_dict