Skip to content

Commit b70271f

Browse files
committed
[feat] support l37 split point for YOLOX-Darknet53
1 parent 2339312 commit b70271f

File tree

3 files changed

+70
-4
lines changed

3 files changed

+70
-4
lines changed

cfgs/vision_model/default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ yolox_darknet53:
4646
conf_thres: 0.001
4747
nms_thres: 0.65
4848
weights: "weights/yolox/darknet53/yolox_darknet.pth"
49-
splits: "l13"
49+
splits: "l13" #"l37"

compressai_vision/model_wrappers/yolox.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __str__(self):
5454
return str(self.value)
5555

5656
Layer13_Single = "l13"
57+
Layer37_Single = "l37"
5758

5859

5960
@register_vision_model("yolox_darknet53")
@@ -86,6 +87,8 @@ def __init__(self, device: str, **kwargs):
8687
self.split_id = str(kwargs["splits"]).lower()
8788
if self.split_id == str(self.supported_split_points.Layer13_Single):
8889
self.split_layer_list = ["l13"]
90+
elif self.split_id == str(self.supported_split_points.Layer37_Single):
91+
self.split_layer_list = ["l37"]
8992
else:
9093
raise NotImplementedError
9194

@@ -111,6 +114,10 @@ def __init__(self, device: str, **kwargs):
111114
def SPLIT_L13(self):
112115
return str(self.supported_split_points.Layer13_Single)
113116

117+
@property
118+
def SPLIT_L37(self):
119+
return str(self.supported_split_points.Layer37_Single)
120+
114121
def input_to_features(self, x, device: str) -> Dict:
115122
"""Computes deep features at the intermediate layer(s) all the way from the input"""
116123

@@ -120,12 +127,14 @@ def input_to_features(self, x, device: str) -> Dict:
120127

121128
if self.split_id == self.SPLIT_L13:
122129
output = self._input_to_feature_at_l13(img)
123-
output["input_size"] = [input_size]
124-
return output
130+
elif self.split_id == self.SPLIT_L37:
131+
output = self._input_to_feature_at_l37(img)
125132
else:
126133
self.logger.error(f"Not supported split point {self.split_id}")
134+
raise NotImplementedError
127135

128-
raise NotImplementedError
136+
output["input_size"] = [input_size]
137+
return output
129138

130139
def features_to_output(self, x: Dict, device: str):
131140
"""Complete the downstream task from the intermediate deep features"""
@@ -136,6 +145,10 @@ def features_to_output(self, x: Dict, device: str):
136145
return self._feature_at_l13_to_output(
137146
x["data"], x["org_input_size"], x["input_size"]
138147
)
148+
elif self.split_id == self.SPLIT_L37:
149+
return self._feature_at_l37_to_output(
150+
x["data"], x["org_input_size"], x["input_size"]
151+
)
139152
else:
140153
self.logger.error(f"Not supported split points {self.split_id}")
141154

@@ -151,6 +164,17 @@ def _input_to_feature_at_l13(self, x):
151164

152165
return {"data": self.features_at_splits}
153166

167+
@torch.no_grad()
168+
def _input_to_feature_at_l37(self, x):
169+
"""Computes and return feature at layer 37 with 11th residual layer output all the way from the input"""
170+
171+
y = self.backbone.stem(x)
172+
y = self.backbone.dark2(y)
173+
y = self.backbone.dark3(y)
174+
self.features_at_splits[self.SPLIT_L37] = y
175+
176+
return {"data": self.features_at_splits}
177+
154178
@torch.no_grad()
155179
def _feature_at_l13_to_output(
156180
self, x: Dict, org_img_size: Dict, input_img_size: List
@@ -194,6 +218,45 @@ def _feature_at_l13_to_output(
194218

195219
return pred
196220

221+
@torch.no_grad()
222+
def _feature_at_l37_to_output(
223+
self, x: Dict, org_img_size: Dict, input_img_size: List
224+
):
225+
"""
226+
performs downstream task using the features from layer 37
227+
228+
YOLOX source codes are referenced for this function.
229+
<https://github.com/Megvii-BaseDetection/YOLOX/yolox/data/data_augment.py>
230+
231+
Unnecessary parts for split inference are removed or modified properly.
232+
233+
Please find the license statement in the downloaded original YOLOX source codes or at here:
234+
<https://github.com/Megvii-BaseDetection/YOLOX?tab=Apache-2.0-1-ov-file#readme>
235+
236+
"""
237+
238+
fp_lvl2 = x[self.SPLIT_L37]
239+
fp_lvl1 = self.backbone.dark4(fp_lvl2)
240+
fp_lvl0 = self.backbone.dark5(fp_lvl1)
241+
242+
# yolo branch 1
243+
b1_in = self.yolo_fpn.out1_cbl(fp_lvl0)
244+
b1_in = self.yolo_fpn.upsample(b1_in)
245+
b1_in = torch.cat([b1_in, fp_lvl1], 1)
246+
fp_lvl1 = self.yolo_fpn.out1(b1_in)
247+
248+
# yolo branch 2
249+
b2_in = self.yolo_fpn.out2_cbl(fp_lvl1)
250+
b2_in = self.yolo_fpn.upsample(b2_in)
251+
b2_in = torch.cat([b2_in, fp_lvl2], 1)
252+
fp_lvl2 = self.yolo_fpn.out2(b2_in)
253+
254+
outputs = self.head((fp_lvl2, fp_lvl1, fp_lvl0))
255+
256+
pred = postprocess(outputs, self.num_classes, self.conf_thres, self.nms_thres)
257+
258+
return pred
259+
197260
@torch.no_grad()
198261
def forward(self, x):
199262
"""Complete the downstream task with end-to-end manner all the way from the input"""

scripts/evaluation/default_yolox_darknet3_performance.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ fi
4040
COCO_2017_VAL_SRC="${TESTDATA_DIR}/coco2017"
4141

4242
# COCO 2017 Val - Detection with YOLOX-Darknet53
43+
44+
# option for split points "l13" or "l37"
45+
# ++vision_model.yolox_darknet53.splits="l37" \
4346
${ENTRY_CMD} --config-name=${CONF_NAME}.yaml \
4447
++pipeline.type=image \
4548
++pipeline.conformance.save_conformance_files=False \

0 commit comments

Comments
 (0)