
Commit b1fea92

add tensorrt support.

1 parent 3c1dc1d

File tree: 4 files changed, +313 −0 lines
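The four files form a conversion pipeline: torch_to_onnx.py exports the trained SiameseNetwork checkpoint to ONNX, onnx_to_trt.py builds a TensorRT engine from the ONNX model, and infer_tensorrt.py / eval_tensorrt.py run single-pair inference and full validation-set evaluation against that engine.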


eval_tensorrt.py

Lines changed: 125 additions & 0 deletions
import os
import argparse

import numpy as np
import matplotlib.pyplot as plt

import tensorrt as trt
import pycuda.autoinit
import pycuda.driver as cuda

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from libs.dataset import Dataset

# logger to capture errors, warnings, and other information during the build and inference phases
TRT_LOGGER = trt.Logger()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '-v',
        '--val_path',
        type=str,
        help="Path to directory containing validation dataset.",
        required=True
    )
    parser.add_argument(
        '-o',
        '--out_path',
        type=str,
        help="Path for saving prediction images.",
        required=True
    )
    parser.add_argument(
        '--engine',
        type=str,
        help="Path to tensorrt engine generated by 'onnx_to_trt.py'.",
        required=True
    )

    args = parser.parse_args()

    os.makedirs(args.out_path, exist_ok=True)

    val_dataset = Dataset(args.val_path, shuffle_pairs=False, augment=False)
    val_dataloader = DataLoader(val_dataset, batch_size=1)

    criterion = torch.nn.BCELoss()

    # deserialize the engine produced by 'onnx_to_trt.py'
    with open(args.engine, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
        context = engine.create_execution_context()

    # allocate device buffers: two inputs (one per image of the pair) and one output
    device_input1, device_input2 = [None] * 2
    for binding in engine:
        if engine.binding_is_input(binding):
            input_shape = engine.get_binding_shape(binding)
            input_size = trt.volume(input_shape) * engine.max_batch_size * np.dtype(np.float32).itemsize  # in bytes
            if device_input1 is None:
                device_input1 = cuda.mem_alloc(input_size)
            elif device_input2 is None:
                device_input2 = cuda.mem_alloc(input_size)
            else:
                raise Exception("Network expects more than 2 inputs.")
        else:
            output_shape = engine.get_binding_shape(binding)
            # create page-locked memory buffers (i.e. won't be swapped to disk)
            host_output = cuda.pagelocked_empty(trt.volume(output_shape) * engine.max_batch_size, dtype=np.float32)
            device_output = cuda.mem_alloc(host_output.nbytes)

    # create a stream in which to copy inputs/outputs and run inference
    stream = cuda.Stream()

    losses = []
    correct = 0
    total = 0

    # undo the ImageNet normalization applied by the dataset so the images can be plotted
    inv_transform = transforms.Compose([
        transforms.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]),
        transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
    ])

    for i, ((img1, img2), y, (class1, class2)) in enumerate(val_dataloader):
        print("[{} / {}]".format(i, len(val_dataloader)))

        class1 = class1[0]
        class2 = class2[0]

        # copy the image pair to device memory
        cuda.memcpy_htod_async(device_input1, img1.numpy().astype(np.float32), stream)
        cuda.memcpy_htod_async(device_input2, img2.numpy().astype(np.float32), stream)

        # run inference
        context.execute_async(bindings=[int(device_input1), int(device_input2), int(device_output)], stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(host_output, device_output, stream)
        stream.synchronize()

        # postprocess results
        prob = torch.Tensor(host_output).reshape(engine.max_batch_size, output_shape[0])

        loss = criterion(prob, y)

        losses.append(loss.item())
        correct += torch.count_nonzero(y == (prob > 0.5)).item()
        total += len(y)

        fig = plt.figure("class1={}\tclass2={}".format(class1, class2), figsize=(4, 2))
        plt.suptitle("cls1={} conf={:.2f} cls2={}".format(class1, prob[0][0].item(), class2))

        img1 = inv_transform(img1).cpu().numpy()[0]
        img2 = inv_transform(img2).cpu().numpy()[0]

        # show the first image
        ax = fig.add_subplot(1, 2, 1)
        plt.imshow(img1[0], cmap=plt.cm.gray)
        plt.axis("off")

        # show the second image
        ax = fig.add_subplot(1, 2, 2)
        plt.imshow(img2[0], cmap=plt.cm.gray)
        plt.axis("off")

        # save the plot
        plt.savefig(os.path.join(args.out_path, '{}.png'.format(i)))

    print("Validation: Loss={:.2f}\t Accuracy={:.2f}".format(sum(losses) / len(losses), correct / total))
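Example invocation (dataset, output, and engine paths are hypothetical):

python eval_tensorrt.py --val_path data/val --out_path outputs/ --engine siamese.engine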

infer_tensorrt.py

Lines changed: 93 additions & 0 deletions
import argparse

import torch
import numpy as np
from PIL import Image
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver as cuda
from torchvision import transforms

# logger to capture errors, warnings, and other information during the build and inference phases
TRT_LOGGER = trt.Logger()

feed_shape = (224, 224)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize(feed_shape)
])

def preprocess(filename1, filename2):
    # load both images, apply the same normalization used during training, and convert to float32 arrays
    image1 = Image.open(filename1).convert("RGB")
    image2 = Image.open(filename2).convert("RGB")

    image1 = transform(image1).float()
    image2 = transform(image2).float()

    return image1.numpy().astype(np.float32), image2.numpy().astype(np.float32)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--image1',
        type=str,
        help="Path to first image of the pair.",
        required=True
    )
    parser.add_argument(
        '--image2',
        type=str,
        help="Path to second image of the pair.",
        required=True
    )
    parser.add_argument(
        '--engine',
        type=str,
        help="Path to tensorrt engine generated by 'onnx_to_trt.py'.",
        required=True
    )

    args = parser.parse_args()

    with open(args.engine, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
        context = engine.create_execution_context()

    device_input1, device_input2 = [None] * 2
    for binding in engine:
        if engine.binding_is_input(binding):  # we expect two inputs, one per image of the pair
            input_shape = engine.get_binding_shape(binding)
            input_size = trt.volume(input_shape) * engine.max_batch_size * np.dtype(np.float32).itemsize  # in bytes
            if device_input1 is None:
                device_input1 = cuda.mem_alloc(input_size)
            elif device_input2 is None:
                device_input2 = cuda.mem_alloc(input_size)
            else:
                raise Exception("Network expects more than 2 inputs.")
        else:  # and one output
            output_shape = engine.get_binding_shape(binding)
            # create page-locked memory buffers (i.e. won't be swapped to disk)
            host_output = cuda.pagelocked_empty(trt.volume(output_shape) * engine.max_batch_size, dtype=np.float32)
            device_output = cuda.mem_alloc(host_output.nbytes)

    # create a stream in which to copy inputs/outputs and run inference
    stream = cuda.Stream()

    # preprocess input data
    host_input = preprocess(args.image1, args.image2)
    cuda.memcpy_htod_async(device_input1, host_input[0], stream)
    cuda.memcpy_htod_async(device_input2, host_input[1], stream)

    # run inference
    context.execute_async(bindings=[int(device_input1), int(device_input2), int(device_output)], stream_handle=stream.handle)
    cuda.memcpy_dtoh_async(host_output, device_output, stream)
    stream.synchronize()

    # postprocess results
    output_data = torch.Tensor(host_output).reshape(engine.max_batch_size, output_shape[0])
    print(f"Similarity between the two images = {round(output_data[0][0].item(), 2)}")
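Example invocation (image and engine paths are hypothetical):

python infer_tensorrt.py --image1 a.png --image2 b.png --engine siamese.engine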

onnx_to_trt.py

Lines changed: 54 additions & 0 deletions
import argparse
import tensorrt as trt

# logger to capture errors, warnings, and other information during the build and inference phases
TRT_LOGGER = trt.Logger()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--onnx',
        type=str,
        help="Path of onnx model generated by 'torch_to_onnx.py'.",
        required=True
    )
    parser.add_argument(
        '--engine',
        type=str,
        help="Path for saving tensorrt engine.",
        required=True
    )

    args = parser.parse_args()

    onnx_file_path = args.onnx
    # initialize TensorRT engine and parse ONNX model
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    onnx_parser = trt.OnnxParser(network, TRT_LOGGER)

    # parse ONNX
    with open(onnx_file_path, 'rb') as model:
        print('Beginning ONNX file parsing')
        if not onnx_parser.parse(model.read()):
            for i in range(onnx_parser.num_errors):
                print(onnx_parser.get_error(i))
            raise RuntimeError("Failed to parse the ONNX file.")
    print('Completed parsing of ONNX file')

    # allow TensorRT to use up to 1GB of GPU memory for tactic selection
    builder.max_workspace_size = 1 << 30
    # we have only one image pair in a batch
    builder.max_batch_size = 1
    # use FP16 mode if possible
    if builder.platform_has_fast_fp16:
        builder.fp16_mode = True

    # generate TensorRT engine optimized for the target platform
    print('Building an engine...')
    engine = builder.build_cuda_engine(network)
    if engine is None:
        raise RuntimeError("Engine build failed.")
    print("Completed creating Engine")

    with open(args.engine, 'wb') as f:
        f.write(engine.serialize())
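Example invocation (file names are hypothetical):

python onnx_to_trt.py --onnx siamese.onnx --engine siamese.engine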

torch_to_onnx.py

Lines changed: 41 additions & 0 deletions
import argparse

import onnx
import torch

from siamese import SiameseNetwork

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        help="Path of model checkpoint to be used for inference.",
        required=True
    )
    parser.add_argument(
        '-o',
        '--out_path',
        type=str,
        help="Path for saving the onnx model.",
        required=True
    )

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # restore the trained siamese network from the checkpoint
    checkpoint = torch.load(args.checkpoint)
    model = SiameseNetwork(backbone=checkpoint['backbone'])
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # export with a dummy image pair; the network takes two (1, 3, 224, 224) inputs
    torch.onnx.export(
        model,
        (torch.rand(1, 3, 224, 224).to(device), torch.rand(1, 3, 224, 224).to(device)),
        args.out_path,
        input_names=['input1', 'input2'],
        output_names=['output'],
        export_params=True
    )

    # verify that the exported model is well-formed
    onnx_model = onnx.load(args.out_path)
    onnx.checker.check_model(onnx_model)
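Taken together with the scripts above, an end-to-end conversion might look like this (checkpoint and file names are hypothetical):

python torch_to_onnx.py -c checkpoints/best.pth -o siamese.onnx
python onnx_to_trt.py --onnx siamese.onnx --engine siamese.engine
python infer_tensorrt.py --image1 a.png --image2 b.png --engine siamese.engine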
