Skip to content

Commit 3f8c6ec

Browse files
committed
add ONNX deployment
1 parent f3df59e commit 3f8c6ec

File tree

21 files changed

+305
-89
lines changed

21 files changed

+305
-89
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ I have provided a bash file `train_ddp.sh` that enables DDP training. I hope som
139139
| Model | Backbone | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
140140
|---------------|--------------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
141141
| YOLOX-N | CSPDarkNet-N | 640 | 300 | 31.1 | 49.5 | 7.5 | 2.3 | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_n_coco.pth) |
142-
| YOLOX-S | CSPDarkNet-S | 640 | 300 | | | 27.1 | 9.0 | |
142+
| YOLOX-S | CSPDarkNet-S | 640 | 300 | | | 26.8 | 8.9 | |
143143
| YOLOX-M | CSPDarkNet-M | 640 | 300 | | | 74.3 | 25.4 | |
144144
| YOLOX-L | CSPDarkNet-L | 640 | 300 | | | 155.4 | 54.2 | |
145145

@@ -360,3 +360,7 @@ python track.py --mode video \
360360
Results:
361361

362362
![image](./img_files/video_tracking_demo.gif)
363+
364+
365+
## Deployment
366+
1. [ONNX export and an ONNXRuntime](./deployment/ONNXRuntime/)

deployment/ONNXRuntime/README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
## YOLO ONNXRuntime
2+
3+
4+
### Convert Your Model to ONNX
5+
6+
First, you should move to <FreeYOLO_HOME> by:
7+
```shell
8+
cd <FreeYOLO_HOME>
9+
cd tools/
10+
```
11+
Then, you can:
12+
13+
1. Convert a standard YOLO model by:
14+
```shell
15+
python3 export_onnx.py -m yolov1 --weight ../weight/coco/yolov1/yolov1_coco.pth -nc 80 --img_size 640
16+
```
17+
18+
Notes:
19+
* -n: specify a model name. The model name must be one of the [yolox-s,m,l,x and yolox-nano, yolox-tiny, yolov3]
20+
* -c: the model you have trained
21+
* -o: opset version, default 11. **However, if you will further convert your onnx model to [OpenVINO](https://github.com/Megvii-BaseDetection/YOLOX/demo/OpenVINO/), please specify the opset version to 10.**
22+
* --no-onnxsim: disable onnxsim
23+
* To customize an input shape for onnx model, modify the following code in tools/export_onnx.py:
24+
25+
```python
26+
dummy_input = torch.randn(args.batch_size, 3, args.img_size, args.img_size)
27+
```
28+
29+
### ONNXRuntime Demo
30+
31+
Step1.
32+
```shell
33+
cd <YOLOX_HOME>/deployment/ONNXRuntime
34+
```
35+
36+
Step2.
37+
```shell
38+
python3 onnx_inference.py --weight ../../weights/onnx/11/yolov1.onnx -i ../test_image.jpg -s 0.3 --img_size 640
39+
```
40+
Notes:
41+
* --weight: your converted onnx model
42+
* -i: input_image
43+
* -s: score threshold for visualization.
44+
* --img_size: should be consistent with the shape you used for onnx convertion.
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
# Copyright (c) Megvii, Inc. and its affiliates.
4+
5+
import argparse
6+
import os
7+
8+
import cv2
9+
import time
10+
import numpy as np
11+
import sys
12+
sys.path.append('../../')
13+
14+
import onnxruntime
15+
from utils.misc import PreProcessor, PostProcessor
16+
from utils.vis_tools import visualize
17+
18+
19+
def make_parser():
20+
parser = argparse.ArgumentParser("onnxruntime inference sample")
21+
parser.add_argument("-m", "--model", type=str, default="../../weights/onnx/11/yolov1.onnx",
22+
help="Input your onnx model.")
23+
parser.add_argument("-i", "--image_path", type=str, default='../test_image.jpg',
24+
help="Path to your input image.")
25+
parser.add_argument("-o", "--output_dir", type=str, default='../../det_results/onnx/',
26+
help="Path to your output directory.")
27+
parser.add_argument("-s", "--score_thr", type=float, default=0.35,
28+
help="Score threshould to filter the result.")
29+
parser.add_argument("-size", "--img_size", type=int, default=640,
30+
help="Specify an input shape for inference.")
31+
return parser
32+
33+
34+
if __name__ == '__main__':
35+
args = make_parser().parse_args()
36+
37+
# class color for better visualization
38+
np.random.seed(0)
39+
class_colors = [(np.random.randint(255),
40+
np.random.randint(255),
41+
np.random.randint(255)) for _ in range(80)]
42+
43+
# preprocessor
44+
prepocess = PreProcessor(img_size=args.img_size)
45+
46+
# postprocessor
47+
postprocess = PostProcessor(num_classes=80, conf_thresh=args.score_thr, nms_thresh=0.5)
48+
49+
# read an image
50+
input_shape = tuple([args.img_size, args.img_size])
51+
origin_img = cv2.imread(args.image_path)
52+
53+
# preprocess
54+
x, ratio = prepocess(origin_img)
55+
56+
t0 = time.time()
57+
# inference
58+
session = onnxruntime.InferenceSession(args.model)
59+
60+
ort_inputs = {session.get_inputs()[0].name: x[None, :, :, :]}
61+
output = session.run(None, ort_inputs)
62+
print("inference time: {:.1f} ms".format((time.time() - t0)*1000))
63+
64+
t0 = time.time()
65+
# post process
66+
bboxes, scores, labels = postprocess(output[0])
67+
bboxes /= ratio
68+
print("post-process time: {:.1f} ms".format((time.time() - t0)*1000))
69+
70+
# visualize detection
71+
origin_img = visualize(
72+
img=origin_img,
73+
bboxes=bboxes,
74+
scores=scores,
75+
labels=labels,
76+
vis_thresh=args.score_thr,
77+
class_colors=class_colors
78+
)
79+
80+
# show
81+
cv2.imshow('onnx detection', origin_img)
82+
cv2.waitKey(0)
83+
84+
# save results
85+
os.makedirs(args.output_dir, exist_ok=True)
86+
output_path = os.path.join(args.output_dir, os.path.basename(args.image_path))
87+
cv2.imwrite(output_path, origin_img)

deployment/test_image.jpg

314 KB
Loading

models/detectors/__init__.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,36 @@ def build_model(args,
1616
model_cfg,
1717
device,
1818
num_classes=80,
19-
trainable=False):
19+
trainable=False,
20+
deploy=False):
2021
# YOLOv1
2122
if args.model == 'yolov1':
2223
model, criterion = build_yolov1(
23-
args, model_cfg, device, num_classes, trainable)
24+
args, model_cfg, device, num_classes, trainable, deploy)
2425
# YOLOv2
2526
elif args.model == 'yolov2':
2627
model, criterion = build_yolov2(
27-
args, model_cfg, device, num_classes, trainable)
28+
args, model_cfg, device, num_classes, trainable, deploy)
2829
# YOLOv3
2930
elif args.model in ['yolov3', 'yolov3_t']:
3031
model, criterion = build_yolov3(
31-
args, model_cfg, device, num_classes, trainable)
32+
args, model_cfg, device, num_classes, trainable, deploy)
3233
# YOLOv4
3334
elif args.model in ['yolov4', 'yolov4_t']:
3435
model, criterion = build_yolov4(
35-
args, model_cfg, device, num_classes, trainable)
36+
args, model_cfg, device, num_classes, trainable, deploy)
3637
# YOLOv5
3738
elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
3839
model, criterion = build_yolov5(
39-
args, model_cfg, device, num_classes, trainable)
40+
args, model_cfg, device, num_classes, trainable, deploy)
4041
# YOLOv7
4142
elif args.model in ['yolov7_t', 'yolov7_l', 'yolov7_x']:
4243
model, criterion = build_yolov7(
43-
args, model_cfg, device, num_classes, trainable)
44+
args, model_cfg, device, num_classes, trainable, deploy)
4445
# YOLOX
4546
elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
4647
model, criterion = build_yolox(
47-
args, model_cfg, device, num_classes, trainable)
48+
args, model_cfg, device, num_classes, trainable, deploy)
4849

4950
if trainable:
5051
# Load pretrained weight

models/detectors/yolov1/build.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
# build object detector
12-
def build_yolov1(args, cfg, device, num_classes=80, trainable=False):
12+
def build_yolov1(args, cfg, device, num_classes=80, trainable=False, deploy=False):
1313
print('==============================')
1414
print('Build {} ...'.format(args.model.upper()))
1515

@@ -24,7 +24,8 @@ def build_yolov1(args, cfg, device, num_classes=80, trainable=False):
2424
num_classes = num_classes,
2525
conf_thresh = args.conf_thresh,
2626
nms_thresh = args.nms_thresh,
27-
trainable = trainable
27+
trainable = trainable,
28+
deploy = deploy
2829
)
2930

3031
# -------------- Initialize YOLO --------------

models/detectors/yolov1/yolov1.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def __init__(self,
1818
num_classes=20,
1919
conf_thresh=0.01,
2020
nms_thresh=0.5,
21-
trainable=False):
21+
trainable=False,
22+
deploy=False):
2223
super(YOLOv1, self).__init__()
2324
# ------------------- Basic parameters -------------------
2425
self.cfg = cfg # 模型配置文件
@@ -29,6 +30,7 @@ def __init__(self,
2930
self.conf_thresh = conf_thresh # 得分阈值
3031
self.nms_thresh = nms_thresh # NMS阈值
3132
self.stride = 32 # 网络的最大步长
33+
self.deploy = deploy
3234

3335
# ------------------- Network Structure -------------------
3436
## 主干网络
@@ -148,12 +150,18 @@ def inference(self, x):
148150
# 解算边界框, 并归一化边界框: [H*W, 4]
149151
bboxes = self.decode_boxes(reg_pred, fmp_size)
150152

151-
# 将预测放在cpu处理上,以便进行后处理
152-
scores = scores.cpu().numpy()
153-
bboxes = bboxes.cpu().numpy()
154-
155-
# 后处理
156-
bboxes, scores, labels = self.postprocess(bboxes, scores)
153+
if self.deploy:
154+
# [n_anchors_all, 4 + C]
155+
outputs = torch.cat([bboxes, scores], dim=-1)
156+
157+
return outputs
158+
else:
159+
# 将预测放在cpu处理上,以便进行后处理
160+
scores = scores.cpu().numpy()
161+
bboxes = bboxes.cpu().numpy()
162+
163+
# 后处理
164+
bboxes, scores, labels = self.postprocess(bboxes, scores)
157165

158166
return bboxes, scores, labels
159167

models/detectors/yolov2/build.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
# build object detector
12-
def build_yolov2(args, cfg, device, num_classes=80, trainable=False):
12+
def build_yolov2(args, cfg, device, num_classes=80, trainable=False, deploy=False):
1313
print('==============================')
1414
print('Build {} ...'.format(args.model.upper()))
1515

@@ -25,6 +25,7 @@ def build_yolov2(args, cfg, device, num_classes=80, trainable=False):
2525
conf_thresh=args.conf_thresh,
2626
nms_thresh=args.nms_thresh,
2727
topk=args.topk,
28+
deploy=deploy
2829
)
2930

3031
# -------------- Initialize YOLO --------------

models/detectors/yolov2/yolov2.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def __init__(self,
1818
conf_thresh=0.01,
1919
nms_thresh=0.5,
2020
topk=100,
21-
trainable=False):
21+
trainable=False,
22+
deploy=False):
2223
super(YOLOv2, self).__init__()
2324
# ------------------- Basic parameters -------------------
2425
self.cfg = cfg # 模型配置文件
@@ -29,8 +30,9 @@ def __init__(self,
2930
self.nms_thresh = nms_thresh # NMS阈值
3031
self.topk = topk # topk
3132
self.stride = 32 # 网络的最大步长
33+
self.deploy = deploy
3234
# ------------------- Anchor box -------------------
33-
self.anchor_size = torch.as_tensor(cfg['anchor_size']).view(-1, 2) # [A, 2]
35+
self.anchor_size = torch.as_tensor(cfg['anchor_size']).float().view(-1, 2) # [A, 2]
3436
self.num_anchors = self.anchor_size.shape[0]
3537

3638
# ------------------- Network Structure -------------------
@@ -179,11 +181,19 @@ def inference(self, x):
179181
cls_pred = cls_pred[0] # [H*W*A, NC]
180182
reg_pred = reg_pred[0] # [H*W*A, 4]
181183

182-
# post process
183-
bboxes, scores, labels = self.postprocess(
184-
obj_pred, cls_pred, reg_pred, anchors)
184+
if self.deploy:
185+
scores = torch.sqrt(obj_pred.sigmoid() * cls_pred.sigmoid())
186+
bboxes = self.decode_boxes(anchors, reg_pred)
187+
# [n_anchors_all, 4 + C]
188+
outputs = torch.cat([bboxes, scores], dim=-1)
185189

186-
return bboxes, scores, labels
190+
return outputs
191+
else:
192+
# post process
193+
bboxes, scores, labels = self.postprocess(
194+
obj_pred, cls_pred, reg_pred, anchors)
195+
196+
return bboxes, scores, labels
187197

188198

189199
def forward(self, x):

models/detectors/yolov3/build.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
# build object detector
12-
def build_yolov3(args, cfg, device, num_classes=80, trainable=False):
12+
def build_yolov3(args, cfg, device, num_classes=80, trainable=False, deploy=False):
1313
print('==============================')
1414
print('Build {} ...'.format(args.model.upper()))
1515

@@ -25,6 +25,7 @@ def build_yolov3(args, cfg, device, num_classes=80, trainable=False):
2525
conf_thresh=args.conf_thresh,
2626
nms_thresh=args.nms_thresh,
2727
topk=args.topk,
28+
deploy = deploy
2829
)
2930

3031
# -------------- Initialize YOLO --------------

0 commit comments

Comments
 (0)