@@ -54,6 +54,7 @@ def __str__(self):
54
54
return str (self .value )
55
55
56
56
Layer13_Single = "l13"
57
+ Layer37_Single = "l37"
57
58
58
59
59
60
@register_vision_model ("yolox_darknet53" )
@@ -86,6 +87,8 @@ def __init__(self, device: str, **kwargs):
86
87
self .split_id = str (kwargs ["splits" ]).lower ()
87
88
if self .split_id == str (self .supported_split_points .Layer13_Single ):
88
89
self .split_layer_list = ["l13" ]
90
+ elif self .split_id == str (self .supported_split_points .Layer37_Single ):
91
+ self .split_layer_list = ["l37" ]
89
92
else :
90
93
raise NotImplementedError
91
94
@@ -111,6 +114,10 @@ def __init__(self, device: str, **kwargs):
111
114
def SPLIT_L13 (self ):
112
115
return str (self .supported_split_points .Layer13_Single )
113
116
117
@property
def SPLIT_L37(self):
    """Return the identifier of the single layer-37 split point as a string.

    The value is ``str()`` applied to the ``Layer37_Single`` member of
    ``self.supported_split_points``.
    """
    layer37_split = self.supported_split_points.Layer37_Single
    return str(layer37_split)
120
+
114
121
def input_to_features (self , x , device : str ) -> Dict :
115
122
"""Computes deep features at the intermediate layer(s) all the way from the input"""
116
123
@@ -120,12 +127,14 @@ def input_to_features(self, x, device: str) -> Dict:
120
127
121
128
if self .split_id == self .SPLIT_L13 :
122
129
output = self ._input_to_feature_at_l13 (img )
123
- output [ "input_size" ] = [ input_size ]
124
- return output
130
+ elif self . split_id == self . SPLIT_L37 :
131
+ output = self . _input_to_feature_at_l37 ( img )
125
132
else :
126
133
self .logger .error (f"Not supported split point { self .split_id } " )
134
+ raise NotImplementedError
127
135
128
- raise NotImplementedError
136
+ output ["input_size" ] = [input_size ]
137
+ return output
129
138
130
139
def features_to_output (self , x : Dict , device : str ):
131
140
"""Complete the downstream task from the intermediate deep features"""
@@ -136,6 +145,10 @@ def features_to_output(self, x: Dict, device: str):
136
145
return self ._feature_at_l13_to_output (
137
146
x ["data" ], x ["org_input_size" ], x ["input_size" ]
138
147
)
148
+ elif self .split_id == self .SPLIT_L37 :
149
+ return self ._feature_at_l37_to_output (
150
+ x ["data" ], x ["org_input_size" ], x ["input_size" ]
151
+ )
139
152
else :
140
153
self .logger .error (f"Not supported split points { self .split_id } " )
141
154
@@ -151,6 +164,17 @@ def _input_to_feature_at_l13(self, x):
151
164
152
165
return {"data" : self .features_at_splits }
153
166
167
@torch.no_grad()
def _input_to_feature_at_l37(self, x):
    """Compute the layer-37 split feature from the input and return it.

    The input is passed through the backbone stages ``stem``, ``dark2`` and
    ``dark3`` in that order. The result is stored in
    ``self.features_at_splits`` under the ``SPLIT_L37`` key.
    NOTE(review): the original author describes this tensor as the output
    of the 11th residual layer -- confirm against the backbone definition.

    Args:
        x: input handed to ``self.backbone.stem``.

    Returns:
        A dict with a single ``"data"`` entry that refers to
        ``self.features_at_splits`` (the same object, not a copy).
    """
    feature = x
    # Stages are looked up and applied one at a time, in network order.
    for stage_name in ("stem", "dark2", "dark3"):
        feature = getattr(self.backbone, stage_name)(feature)

    self.features_at_splits[self.SPLIT_L37] = feature
    return {"data": self.features_at_splits}
177
+
154
178
@torch .no_grad ()
155
179
def _feature_at_l13_to_output (
156
180
self , x : Dict , org_img_size : Dict , input_img_size : List
@@ -194,6 +218,45 @@ def _feature_at_l13_to_output(
194
218
195
219
return pred
196
220
221
@torch.no_grad()
def _feature_at_l37_to_output(
    self, x: Dict, org_img_size: Dict, input_img_size: List
):
    """Run the downstream task starting from the layer-37 split feature.

    YOLOX source code is referenced for this function:
    <https://github.com/Megvii-BaseDetection/YOLOX/yolox/data/data_augment.py>

    Parts unnecessary for split inference are removed or modified.

    The license statement is in the downloaded original YOLOX sources or here:
    <https://github.com/Megvii-BaseDetection/YOLOX?tab=Apache-2.0-1-ov-file#readme>

    Args:
        x: mapping that holds the split feature under the ``SPLIT_L37`` key.
        org_img_size: not used by this method body.
        input_img_size: not used by this method body.

    Returns:
        The result of ``postprocess`` applied to the head outputs with
        ``self.num_classes``, ``self.conf_thres`` and ``self.nms_thres``.
    """
    # Finish the remaining backbone stages from the split feature.
    split_feat = x[self.SPLIT_L37]
    dark4_feat = self.backbone.dark4(split_feat)
    dark5_feat = self.backbone.dark5(dark4_feat)

    # YOLO branch 1: upsample the deepest feature and merge it with dark4.
    branch1 = self.yolo_fpn.out1_cbl(dark5_feat)
    branch1 = self.yolo_fpn.upsample(branch1)
    branch1 = torch.cat([branch1, dark4_feat], 1)
    mid_feat = self.yolo_fpn.out1(branch1)

    # YOLO branch 2: upsample the branch-1 output and merge it with the split feature.
    branch2 = self.yolo_fpn.out2_cbl(mid_feat)
    branch2 = self.yolo_fpn.upsample(branch2)
    branch2 = torch.cat([branch2, split_feat], 1)
    fine_feat = self.yolo_fpn.out2(branch2)

    head_outputs = self.head((fine_feat, mid_feat, dark5_feat))

    return postprocess(
        head_outputs, self.num_classes, self.conf_thres, self.nms_thres
    )
259
+
197
260
@torch .no_grad ()
198
261
def forward (self , x ):
199
262
"""Complete the downstream task with end-to-end manner all the way from the input"""
0 commit comments