
Commit 4657036

[NeuralChat] CUDA serving with Triton Inference Server (intel#1293)

1 parent c57c17e commit 4657036

File tree: intel_extension_for_transformers/neural_chat

3 files changed: +244 -0 lines changed
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Serving NeuralChat Text Generation with Triton Inference Server (CUDA)

NVIDIA Triton Inference Server is a widely adopted inference serving software. We also support serving and deploying NeuralChat models with Triton Inference Server on CUDA devices.

## Prepare serving scripts

```
cd <path to intel_extension_for_transformers>/neural_chat/examples/serving
mkdir -p models/text_generation/1/
cp ../../serving/triton/text_generation/cuda/model.py models/text_generation/1/model.py
cp ../../serving/triton/text_generation/cuda/config.pbtxt models/text_generation/config.pbtxt
```

Then your folder structure under the current `serving` folder should look like this:

```
serving/
├── models
│   └── text_generation
│       ├── 1
│       │   └── model.py
│       └── config.pbtxt
├── README.md
```

## Start Triton Inference Server

```
cd <path to intel_extension_for_transformers>/neural_chat/examples/serving
docker run -d --gpus all -e PYTHONPATH=/opt/tritonserver/intel-extension-for-transformers --net=host -v ${PWD}/models:/models spycsh/triton_neuralchat_gpu:v2 tritonserver --model-repository=/models --http-port 8021
```

The `-v` option maps the model repository on your host machine into the Docker container.
35+
36+
## Multi-card serving (optional)
37+
38+
You can also do multi-card serving to get better throughput by specifying a instance group provided by Triton Inference Server.
39+
40+
To do that, please edit the the field `instance_group` in your `config.pbtxt`.
41+
42+
One example would be like following:
43+
44+
```
45+
instance_group [
46+
{
47+
count: 1
48+
kind: KIND_GPU
49+
gpus: [0, 1]
50+
}
51+
]
52+
```
53+
54+
This means for every gpu device, we initialize an execution instance. Please check configuration details through this [link](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#multiple-model-instances).
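
If a single GPU has headroom, Triton also lets you run more than one execution instance per device by raising `count`. The snippet below is a variant of the configuration above, not shipped in this commit; note that every instance loads its own copy of the NeuralChat model, so GPU memory usage grows accordingly.

```
instance_group [
  {
    count: 2
    kind: KIND_GPU
    gpus: [0]
  }
]
```

This places two execution instances on GPU 0.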
55+
56+
## Quick check whether the server is up
57+
58+
To check whether the server is up:
59+
60+
```
61+
curl -v localhost:8021/v2/health/ready
62+
```
63+
64+
You will find a `HTTP/1.1 200 OK` if your server is up and ready for receiving requests.
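
The same readiness check can be done from Python. Below is a minimal sketch, assuming the `tritonclient` package (e.g. `pip install tritonclient[http]`) is installed on the client machine; it is not part of this commit.

```python
# Readiness check against the server started above (HTTP port 8021).
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8021")
print("server ready:", client.is_server_ready())
print("model ready:", client.is_model_ready("text_generation"))
```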
65+
66+
## Use Triton client to send inference request
67+
68+
Start the Triton client and enter into the container
69+
70+
```
71+
cd <path to intel_extension_for_transformers>/neural_chat/examples/serving
72+
docker run --gpus all --net=host -it --rm -v ${PWD}/../../serving/triton/text_generation/client.py:/workspace/text_generation/client.py nvcr.io/nvidia/tritonserver:23.11-py3-sdk
73+
```
74+
75+
Send a request
76+
77+
```
78+
python /workspace/text_generation/client.py --prompt="Tell me about Intel Xeon Scalable Processors." --url=localhost:8021
79+
```
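
For reference, the request that `client.py` sends boils down to building the `INPUT0` string tensor declared in `config.pbtxt` and reading back `OUTPUT0`. The sketch below illustrates that flow with the `tritonclient` HTTP API; it is an approximation for clarity, not a replacement for the shipped `client.py`.

```python
# Illustrative only: send one prompt to the text_generation model.
# Assumes tritonclient[http] and numpy are installed and the server
# from the steps above is listening on localhost:8021.
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8021")

prompt = "Tell me about Intel Xeon Scalable Processors."
# config.pbtxt declares INPUT0 as TYPE_STRING with dims [1]; on the client
# side that maps to a BYTES tensor backed by a 1-element numpy object array.
input0 = httpclient.InferInput("INPUT0", [1], "BYTES")
input0.set_data_from_numpy(np.array([prompt.encode("utf-8")], dtype=np.object_))

result = client.infer(
    model_name="text_generation",
    inputs=[input0],
    outputs=[httpclient.InferRequestedOutput("OUTPUT0")],
)
print(result.as_numpy("OUTPUT0"))
```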
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "text_generation"
backend: "python"

input [
  {
    name: "INPUT0"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
output [
  {
    name: "OUTPUT0"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]

instance_group [{ kind: KIND_GPU }]
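
Once the server is running, you can confirm that Triton picked up this configuration by querying it from the client side. A small sketch, again assuming the `tritonclient` HTTP package is available on the client machine:

```python
# Fetch the model configuration as the server parsed it
# (name, backend, input/output specs, instance_group).
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8021")
print(client.get_model_config("text_generation"))
```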
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import numpy as np

# triton_python_backend_utils is available in every Triton Python model. You
# need to use this module to create inference requests and responses. It also
# contains some utility functions for extracting information from model_config
# and converting Triton input/output types to numpy types.
import triton_python_backend_utils as pb_utils

from intel_extension_for_transformers.neural_chat import build_chatbot, PipelineConfig


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        # You must parse model_config. JSON string is not parsed here
        self.model_config = model_config = json.loads(args["model_config"])
        self.model_instance_device_id = json.loads(args['model_instance_device_id'])
        import numba.cuda as cuda
        cuda.select_device(self.model_instance_device_id)

        # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0")

        # Convert Triton types to numpy types
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config["data_type"]
        )
        self.config = PipelineConfig()
        self.chatbot = build_chatbot(self.config)

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """

        output0_dtype = self.output0_dtype
        chatbot = self.chatbot

        responses = []

        # Every Python backend must iterate over everyone of the requests
        # and create a pb_utils.InferenceResponse for each of them.
        for request in requests:
            # Get INPUT0
            in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
            in_0 = in_0.as_numpy()
            text = in_0[0].decode("utf-8")
            print(f"input prompt: {text}")

            out_0 = chatbot.predict(query=text)

            # Create output tensors. You need pb_utils.Tensor
            # objects to create pb_utils.InferenceResponse.
            out_0 = np.array(out_0)

            out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0.astype(output0_dtype))

            # Create InferenceResponse. You can set an error here in case
            # there was a problem with handling this inference request.
            # Below is an example of how you can set errors in inference
            # response:
            #
            # pb_utils.InferenceResponse(
            #    output_tensors=..., TritonError("An error occurred"))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor_0]
            )
            responses.append(inference_response)

        # You should return a list of pb_utils.InferenceResponse. Length
        # of this list must match the length of `requests` list.
        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is OPTIONAL. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print("Cleaning up...")
