#include "yolov8_pose.hpp"

namespace tensorrt_infer
{
    namespace yolov8_cuda
    {
        void YOLOv8Pose::initParameters(const std::string &engine_file, float score_thr, float nms_thr)
        {
            if (!file_exist(engine_file))
            {
                INFO("Error: engine_file does not exist!");
                exit(1); // exit with a non-zero status, since this is an error path
            }

            this->model_info = std::make_shared<ModelInfo>();
            // Store the user-supplied parameters
            model_info->m_modelPath = engine_file;
            model_info->m_postProcCfg.confidence_threshold_ = score_thr;
            model_info->m_postProcCfg.nms_threshold_ = nms_thr;

            this->model_ = trt::infer::load(engine_file); // load the inference object
            this->model_->print();                        // print basic information about the engine

            // Query the input shape
            auto input_dim = this->model_->get_network_dims(0); // input tensor dimensions
            model_info->m_preProcCfg.infer_batch_size = input_dim[0];
            model_info->m_preProcCfg.network_input_channels_ = input_dim[1];
            model_info->m_preProcCfg.network_input_height_ = input_dim[2];
            model_info->m_preProcCfg.network_input_width_ = input_dim[3];
            model_info->m_preProcCfg.network_input_numel = input_dim[1] * input_dim[2] * input_dim[3];
            model_info->m_preProcCfg.isdynamic_model_ = this->model_->has_dynamic_dim();
            // Configure the input preprocessing: YOLOv8 divides pixel values by 255 and expects RGB input
            model_info->m_preProcCfg.normalize_ = Norm::alpha_beta(1 / 255.0f, 0.0f, ChannelType::RGB);

            // Query the output shape
            auto output_dim = this->model_->get_network_dims(1);
            model_info->m_postProcCfg.bbox_head_dims_ = output_dim;
            model_info->m_postProcCfg.bbox_head_dims_output_numel_ = output_dim[1] * output_dim[2];
            if (model_info->m_postProcCfg.pose_num_ == 0)
                model_info->m_postProcCfg.pose_num_ = (int)((output_dim[2] - 5) / 3); // YOLOv8 pose, 5: xmin, ymin, xmax, ymax, score
            model_info->m_postProcCfg.NUM_BOX_ELEMENT += model_info->m_postProcCfg.pose_num_ * 3; // 3: pose_x, pose_y, pose_score
            model_info->m_postProcCfg.IMAGE_MAX_BOXES_ADD_ELEMENT = model_info->m_postProcCfg.MAX_IMAGE_BOXES * model_info->m_postProcCfg.NUM_BOX_ELEMENT;
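            // Worked example: for the standard 17-keypoint COCO pose model the (transposed)
            // head is [batch, 8400, 56], so pose_num_ = (56 - 5) / 3 = 17 and each decoded box
            // occupies NUM_BOX_ELEMENT = 7 + 17 * 3 floats laid out as
            //   [left, top, right, bottom, score, label, keepflag, kpt_x, kpt_y, kpt_score, ...]
            // (the 7-element base layout follows from how parser_box reads pbox[0..6] below;
            // the 8400x56 head shape is the usual export and an assumption here).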

            CHECK(cudaStreamCreate(&cu_stream)); // create the CUDA stream
        }

        YOLOv8Pose::~YOLOv8Pose()
        {
            CHECK(cudaStreamDestroy(cu_stream)); // destroy the CUDA stream
        }

        void YOLOv8Pose::adjust_memory(int batch_size)
        {
            // Allocate the memory needed for the model input and output
            input_buffer_.gpu(batch_size * model_info->m_preProcCfg.network_input_numel);           // GPU memory for batch_size model inputs
            bbox_predict_.gpu(batch_size * model_info->m_postProcCfg.bbox_head_dims_output_numel_); // GPU memory for batch_size model outputs

            // Allocate the memory used when decoding the output into boxes; the extra 32 elements are
            // reserved because the first element holds the box count, guarding against overflow
            output_boxarray_.gpu(batch_size * (32 + model_info->m_postProcCfg.IMAGE_MAX_BOXES_ADD_ELEMENT));
            output_boxarray_.cpu(batch_size * (32 + model_info->m_postProcCfg.IMAGE_MAX_BOXES_ADD_ELEMENT));
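            // Resulting per-image layout of output_boxarray_ (see postprocess_gpu/parser_box):
            //   float[0]                            -> number of decoded boxes
            //   float[1 .. count * NUM_BOX_ELEMENT] -> the boxes themselves
            // with image ib starting at offset ib * (32 + IMAGE_MAX_BOXES_ADD_ELEMENT).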

            if ((int)preprocess_buffers_.size() < batch_size)
            {
                for (int i = preprocess_buffers_.size(); i < batch_size; ++i)
                    preprocess_buffers_.push_back(make_shared<Memory<unsigned char>>()); // add one Memory object per batch item
            }

            // Allocate batch_size affine matrices; since the batch is also set dynamically, this is done here as well
            if ((int)affine_matrixs.size() < batch_size)
            {
                for (int i = affine_matrixs.size(); i < batch_size; ++i)
                    affine_matrixs.push_back(AffineMatrix()); // add one AffineMatrix per batch item
            }
        }
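
        // Note: adjust_memory only ever grows the cached buffers, so calling it again with the
        // same or a smaller batch size is cheap (this assumes Memory::gpu()/cpu() reuse their
        // allocation when the requested size does not exceed the current one).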

        void YOLOv8Pose::preprocess_gpu(int ibatch, const Image &image,
                                        shared_ptr<Memory<unsigned char>> preprocess_buffer, AffineMatrix &affine,
                                        cudaStream_t stream_)
        {
            if (image.channels != model_info->m_preProcCfg.network_input_channels_)
            {
                INFO("Warning: the number of channels the network expects differs from that of the actual image\n");
                exit(-1);
            }

            affine.compute(make_tuple(image.width, image.height),
                           make_tuple(model_info->m_preProcCfg.network_input_width_, model_info->m_preProcCfg.network_input_height_));
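            // Sketch of what affine.compute presumably produces (the usual letterbox warp;
            // an assumption, not read from AffineMatrix itself):
            //   scale = min(input_w / image_w, input_h / image_h)
            //   i2d   = [scale, 0, -scale * image_w * 0.5 + input_w * 0.5,
            //            0, scale, -scale * image_h * 0.5 + input_h * 0.5]
            // d2i is its inverse and is what maps boxes/keypoints back onto the original image.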
            float *input_device = input_buffer_.gpu() + ibatch * model_info->m_preProcCfg.network_input_numel; // GPU input pointer for the current batch item
            size_t size_image = image.width * image.height * image.channels;
            size_t size_matrix = upbound(sizeof(affine.d2i), 32);                      // round up to a multiple of 32
            uint8_t *gpu_workspace = preprocess_buffer->gpu(size_matrix + size_image); // one GPU allocation holding affine matrix + image
            float *affine_matrix_device = (float *)gpu_workspace;
            uint8_t *image_device = gpu_workspace + size_matrix; // the image region starts right after the affine matrix

            // Same as above, but allocating (pinned) CPU memory
            uint8_t *cpu_workspace = preprocess_buffer->cpu(size_matrix + size_image);
            float *affine_matrix_host = (float *)cpu_workspace;
            uint8_t *image_host = cpu_workspace + size_matrix;

            // This copy is not redundant: it moves the data from pageable memory into pinned
            // (page-locked) memory, which speeds up the subsequent host-to-device transfer
            memcpy(image_host, image.bgrptr, size_image);               // fill the image buffer
            memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i)); // fill the affine-matrix buffer

            // CPU --> GPU; image_host could be replaced with image.bgrptr (dropping the lines above),
            // but that turns out to be about 0.02 ms slower
            checkRuntime(cudaMemcpyAsync(image_device, image_host, size_image, cudaMemcpyHostToDevice, stream_)); // upload the image to GPU memory
            checkRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i),
                                         cudaMemcpyHostToDevice, stream_)); // upload the affine matrix to GPU memory
            // Perform the resize + fill[114] (letterbox padding)
            warp_affine_bilinear_and_normalize_plane(image_device, image.width * image.channels, image.width,
                                                     image.height, input_device, model_info->m_preProcCfg.network_input_width_,
                                                     model_info->m_preProcCfg.network_input_height_, affine_matrix_device, const_value,
                                                     model_info->m_preProcCfg.normalize_, stream_);
        }

        void YOLOv8Pose::postprocess_gpu(int ibatch, cudaStream_t stream_)
        {
            // boxarray_device: GPU pointer where the decoded results will be stored
            float *boxarray_device = output_boxarray_.gpu() + ibatch * (32 + model_info->m_postProcCfg.IMAGE_MAX_BOXES_ADD_ELEMENT);
            // affine_matrix_device: GPU pointer to the affine matrix (+ image) workspace, used to rescale
            // the boxes from network-input coordinates back to original-image coordinates
            float *affine_matrix_device = (float *)preprocess_buffers_[ibatch]->gpu();
            // image_based_bbox_output: GPU pointer to all predicted boxes produced by inference
            float *image_based_bbox_output = bbox_predict_.gpu() + ibatch * model_info->m_postProcCfg.bbox_head_dims_output_numel_;

            checkRuntime(cudaMemsetAsync(boxarray_device, 0, sizeof(int), stream_));
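            // Only sizeof(int) bytes are cleared: the first element serves as the box counter,
            // which the decode kernel presumably increments atomically for each candidate that
            // passes the confidence threshold.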
            decode_pose_yolov8_kernel_invoker(image_based_bbox_output, model_info->m_postProcCfg.bbox_head_dims_[1], model_info->m_postProcCfg.pose_num_,
                                              model_info->m_postProcCfg.bbox_head_dims_[2], model_info->m_postProcCfg.confidence_threshold_,
                                              affine_matrix_device, boxarray_device, model_info->m_postProcCfg.MAX_IMAGE_BOXES,
                                              model_info->m_postProcCfg.NUM_BOX_ELEMENT, stream_);

            // Run NMS on the boxes that survived the confidence filter
            nms_kernel_invoker(boxarray_device, model_info->m_postProcCfg.nms_threshold_, model_info->m_postProcCfg.MAX_IMAGE_BOXES,
                               model_info->m_postProcCfg.NUM_BOX_ELEMENT, stream_);
        }

        BatchPoseBoxArray YOLOv8Pose::parser_box(int num_image)
        {
            BatchPoseBoxArray arrout(num_image);
            for (int ib = 0; ib < num_image; ++ib)
            {
                float *parray = output_boxarray_.cpu() + ib * (32 + model_info->m_postProcCfg.IMAGE_MAX_BOXES_ADD_ELEMENT);
                int count = min(model_info->m_postProcCfg.MAX_IMAGE_BOXES, (int)*parray);
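                // The counter can exceed MAX_IMAGE_BOXES when many candidates pass the confidence
                // threshold, so it is clamped before the box array is read.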
                PoseBoxArray &output = arrout[ib];
                output.reserve(count); // grow the vector capacity to at least count
                for (int i = 0; i < count; ++i)
                {
                    float *pbox = parray + 1 + i * model_info->m_postProcCfg.NUM_BOX_ELEMENT;
                    int label = pbox[5];
                    int keepflag = pbox[6];
                    if (keepflag == 1)
                    {
                        PoseBox result_object_box(pbox[0], pbox[1], pbox[2], pbox[3], pbox[4], label);
                        result_object_box.pose = make_shared<InstancePose>();
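                        // Collect the keypoints: triples of (x, y, score) starting at offset 7,
                        // already mapped back to original-image coordinates by the decode kernel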
                        for (int pindex = 7; pindex < model_info->m_postProcCfg.NUM_BOX_ELEMENT; pindex += 3)
                            result_object_box.pose->pose_data.push_back({pbox[pindex], pbox[pindex + 1], pbox[pindex + 2]});
                        output.emplace_back(result_object_box);
                    }
                }
            }

            return arrout;
        }

        PoseBoxArray YOLOv8Pose::forward(const Image &image)
        {
            auto output = forwards({image});
            if (output.empty())
                return {};
            return output[0];
        }

        BatchPoseBoxArray YOLOv8Pose::forwards(const std::vector<Image> &images)
        {
            int num_image = images.size();
            if (num_image == 0)
                return {};

            // Set the batch size dynamically
            auto input_dims = model_->get_network_dims(0);
            if (model_info->m_preProcCfg.infer_batch_size != num_image)
            {
                if (model_info->m_preProcCfg.isdynamic_model_)
                {
                    model_info->m_preProcCfg.infer_batch_size = num_image;
                    input_dims[0] = num_image;
                    if (!model_->set_network_dims(0, input_dims)) // rebind the input batch; returns false on failure
                        return {};
                }
                else
                {
                    if (model_info->m_preProcCfg.infer_batch_size < num_image)
                    {
                        INFO(
                            "When using a static-shape model, the number of images[%d] must be "
                            "less than or equal to the maximum batch[%d].",
                            num_image, model_info->m_preProcCfg.infer_batch_size);
                        return {};
                    }
                }
            }

            // Since the batch size is dynamic, GPU/CPU memory has to be allocated dynamically as well
            adjust_memory(model_info->m_preProcCfg.infer_batch_size);

            // Preprocess the images
            for (int i = 0; i < num_image; ++i)
                preprocess_gpu(i, images[i], preprocess_buffers_[i], affine_matrixs[i], cu_stream); // input_buffer_ receives the preprocessed image data

            // Run inference
            float *bbox_output_device = bbox_predict_.gpu();                  // GPU pointer that will receive the inference results
            vector<void *> bindings{input_buffer_.gpu(), bbox_output_device}; // the bindings passed to forward
            if (!model_->forward(bindings, cu_stream))
            {
                INFO("TensorRT forward failed.");
                return {};
            }

            // Decode the inference results
            for (int ib = 0; ib < num_image; ++ib)
                postprocess_gpu(ib, cu_stream);

            // Copy the post-NMS boxes from GPU memory to CPU memory
            checkRuntime(cudaMemcpyAsync(output_boxarray_.cpu(), output_boxarray_.gpu(),
                                         output_boxarray_.gpu_bytes(), cudaMemcpyDeviceToHost, cu_stream));
            checkRuntime(cudaStreamSynchronize(cu_stream)); // block until every operation queued on the stream has finished

            return parser_box(num_image);
        }
    }
}
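
// Usage sketch (illustrative only: the engine path, the cv::Mat -> Image adapter `cvimg`,
// and the drawing code are assumptions, not part of this file):
//
//     tensorrt_infer::yolov8_cuda::YOLOv8Pose pose;
//     pose.initParameters("yolov8s-pose.engine", 0.25f, 0.5f);
//     cv::Mat frame = cv::imread("demo.jpg");
//     auto boxes = pose.forward(cvimg(frame));
//     for (const auto &box : boxes)
//         for (const auto &kpt : box.pose->pose_data) // each kpt: (x, y, score)
//             if (kpt[2] > 0.5f)
//                 cv::circle(frame, cv::Point((int)kpt[0], (int)kpt[1]), 3, {0, 255, 0}, -1);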