4
4
from transforms import *
5
5
from torch .nn .init import normal , constant
6
6
7
+ import pretrainedmodels
7
8
import MLPmodule
8
9
9
10
class TSN (nn .Module ):
@@ -102,9 +103,13 @@ def _prepare_tsn(self, num_class):
102
103
103
104
def _prepare_base_model (self , base_model ):
104
105
105
- if 'resnet' in base_model or 'vgg' in base_model :
106
- self .base_model = getattr (torchvision .models , base_model )(True )
107
- self .base_model .last_layer_name = 'fc'
106
+ if 'resnet' in base_model or 'vgg' in base_model or 'squeezenet1_1' in base_model :
107
+ self .base_model = pretrainedmodels .__dict__ [base_model ](num_classes = 1000 , pretrained = 'imagenet' )
108
+ if base_model == 'squeezenet1_1' :
109
+ self .base_model = self .base_model .features
110
+ self .base_model .last_layer_name = '12'
111
+ else :
112
+ self .base_model .last_layer_name = 'fc'
108
113
self .input_size = 224
109
114
self .input_mean = [0.485 , 0.456 , 0.406 ]
110
115
self .input_std = [0.229 , 0.224 , 0.225 ]
@@ -116,38 +121,26 @@ def _prepare_base_model(self, base_model):
116
121
self .input_mean = [0.485 , 0.456 , 0.406 ] + [0 ] * 3 * self .new_length
117
122
self .input_std = self .input_std + [np .mean (self .input_std ) * 2 ] * 3 * self .new_length
118
123
elif base_model == 'BNInception' :
119
- import model_zoo
120
- self .base_model = getattr (model_zoo , base_model )()
121
- self .base_model .last_layer_name = 'fc'
124
+ self .base_model = pretrainedmodels .__dict__ ['bninception' ](num_classes = 1000 , pretrained = 'imagenet' )
125
+ self .base_model .last_layer_name = 'last_linear'
122
126
self .input_size = 224
123
127
self .input_mean = [104 , 117 , 128 ]
124
128
self .input_std = [1 ]
125
-
126
129
if self .modality == 'Flow' :
127
130
self .input_mean = [128 ]
128
131
elif self .modality == 'RGBDiff' :
129
132
self .input_mean = self .input_mean * (1 + self .new_length )
130
- elif self .modality == 'RGBFlow' :
131
- self .input_mean = self .input_mean * (self .new_length ) # NOTE: Check here if can be modified properly!
132
- elif base_model == 'InceptionV3' :
133
- import model_zoo
134
- self .base_model = getattr (model_zoo , base_model )()
135
- self .base_model .last_layer_name = 'top_cls_fc'
136
- self .input_size = 299
137
- self .input_mean = [104 ,117 ,128 ]
138
- self .input_std = [1 ]
133
+ elif 'resnext101' in base_model :
134
+ self .base_model = pretrainedmodels .__dict__ [base_model ](num_classes = 1000 , pretrained = 'imagenet' )
135
+ print (self .base_model )
136
+ self .base_model .last_layer_name = 'last_linear'
137
+ self .input_size = 224
138
+ self .input_mean = [0.485 , 0.456 , 0.406 ]
139
+ self .input_std = [0.229 , 0.224 , 0.225 ]
139
140
if self .modality == 'Flow' :
140
141
self .input_mean = [128 ]
141
142
elif self .modality == 'RGBDiff' :
142
- self .input_mean = self .input_mean * (1 + self .new_length )
143
-
144
- elif 'inception' in base_model :
145
- import model_zoo
146
- self .base_model = getattr (model_zoo , base_model )()
147
- self .base_model .last_layer_name = 'classif'
148
- self .input_size = 299
149
- self .input_mean = [0.5 ]
150
- self .input_std = [0.5 ]
143
+ self .input_mean = self .input_mean * (1 + self .new_length )
151
144
else :
152
145
raise ValueError ('Unknown base model: {}' .format (base_model ))
153
146
0 commit comments