Commit 9503639

Authored by HexToString, jiangjiajun, heliqi

[Serving] add ocr serving example (#627)

* add ocr serving example
* 1 1
* Add files via upload
* Update README.md
* Delete ocr_pipeline.png
* Add files via upload
* Delete ocr_pipeline.png
* Add files via upload
* 1 1
* 1 1
* Update README.md
* 1 1
* fix codestyle
* fix codestyle

Co-authored-by: Jason <jiangjiajun@baidu.com>
Co-authored-by: heliqi <1101791222@qq.com>

1 parent c721773 commit 9503639

File tree

17 files changed: +1188 −0 lines changed
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# PP-OCR Serving Deployment Example

## Introduction
This document describes how to deploy an OCR text recognition service with FastDeploy.

The server must be started inside Docker, while the client does not have to run inside a Docker container.

**The `models` directory under the path of this document ($PWD) contains the model configurations and code (the server loads these models and code to start the service), and it must be mounted into the Docker container.**

OCR consists of three models: det (detection), cls (classification), and rec (recognition).

The diagram below shows how the serving pipeline is chained together: `pp_ocr` chains `det_preprocess`, `det_runtime`, and `det_postprocess`; `cls_pp` chains `cls_runtime` and `cls_postprocess`; `rec_pp` chains `rec_runtime` and `rec_postprocess`.

Notably, `det_postprocess` calls the `cls_pp` and `rec_pp` services multiple times to classify and recognize the detection results (multiple boxes), and finally returns the recognition result to the user (a sketch of this call pattern follows the figure below).

<p align="center">
  <br>
<img src='./ppocr.png'>
  <br>
</p>
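Those calls from one Python-backend model into the `cls_pp` and `rec_pp` ensembles are presumably made through Triton's Business Logic Scripting (BLS) API. The snippet below is only a minimal sketch under that assumption, using the input/output names declared in the `cls_pp` config of this commit (`x`, `cls_labels`, `cls_scores`); it is not the actual `det_postprocess` model code.

```python
# Minimal sketch (assumption): how a Python-backend model such as det_postprocess
# can call the cls_pp ensemble through Triton's BLS API.
import numpy as np
import triton_python_backend_utils as pb_utils


def classify_crops(crops_fp32):
    """crops_fp32: numpy array shaped [batch, 3, H, W], matching cls_pp input "x"."""
    infer_request = pb_utils.InferenceRequest(
        model_name="cls_pp",
        requested_output_names=["cls_labels", "cls_scores"],
        inputs=[pb_utils.Tensor("x", crops_fp32.astype(np.float32))])
    response = infer_request.exec()
    if response.has_error():
        raise pb_utils.TritonModelException(response.error().message())
    labels = pb_utils.get_output_tensor_by_name(response, "cls_labels").as_numpy()
    scores = pb_utils.get_output_tensor_by_name(response, "cls_scores").as_numpy()
    return labels, scores
```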

## Usage
### 1. Server
#### 1.1 Docker
```bash
# Download the repository code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/ocr/PP-OCRv3/serving/

# Download the models, test image, and dictionary file
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar xvf ch_PP-OCRv3_det_infer.tar && mv ch_PP-OCRv3_det_infer 1
mv 1/inference.pdiparams 1/model.pdiparams && mv 1/inference.pdmodel 1/model.pdmodel
mv 1 models/det_runtime/ && rm -rf ch_PP-OCRv3_det_infer.tar

wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar && mv ch_ppocr_mobile_v2.0_cls_infer 1
mv 1/inference.pdiparams 1/model.pdiparams && mv 1/inference.pdmodel 1/model.pdmodel
mv 1 models/cls_runtime/ && rm -rf ch_ppocr_mobile_v2.0_cls_infer.tar

wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar xvf ch_PP-OCRv3_rec_infer.tar && mv ch_PP-OCRv3_rec_infer 1
mv 1/inference.pdiparams 1/model.pdiparams && mv 1/inference.pdmodel 1/model.pdmodel
mv 1 models/rec_runtime/ && rm -rf ch_PP-OCRv3_rec_infer.tar

wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
mv ppocr_keys_v1.txt models/rec_postprocess/1/

wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg

docker pull paddlepaddle/fastdeploy:0.6.0-gpu-cuda11.4-trt8.4-21.10
docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/ocr_serving paddlepaddle/fastdeploy:0.6.0-gpu-cuda11.4-trt8.4-21.10 bash
docker exec -it -u root fastdeploy bash
```
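
After the commands above, each runtime model directory should contain a Triton-style version folder named `1` holding the renamed Paddle inference files, for example `models/det_runtime/1/model.pdmodel` and `models/det_runtime/1/model.pdiparams`, and the recognition dictionary ends up at `models/rec_postprocess/1/ppocr_keys_v1.txt`. If the server later fails to load a model, these paths are worth checking first.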

#### 1.2 Installation (inside Docker)
```bash
ldconfig
apt-get install libgl1
```

#### 1.3 Start the server (inside Docker)
```bash
fastdeployserver --model-repository=/ocr_serving/models
```

Parameters:
- `model-repository` (required): path of the model repository to load (the `models` directory mounted at `/ocr_serving/models` in this example).
- `http-port` (optional): port of the HTTP service. Default: `8000`. Not used in this example.
- `grpc-port` (optional): port of the GRPC service. Default: `8001`.
- `metrics-port` (optional): port for server metrics. Default: `8002`. Not used in this example.
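
Note that `client.py` in this example connects to `localhost:9001`, while the server's default GRPC port is `8001`. Either start the server with `--grpc-port=9001` or change the `url` variable in `client.py` so that the two match.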
### 2. Client
#### 2.1 Installation
```bash
pip3 install tritonclient[all]
```

#### 2.2 Send a request
```bash
python3 client.py
```

## Modifying the configuration
The default configuration runs on GPU. To run on CPU or with another inference engine, modify the configuration in `models/runtime/config.pbtxt` (in this example, the `config.pbtxt` under each `*_runtime` model directory, e.g. `models/det_runtime/config.pbtxt`); see the [configuration document](../../../../../serving/docs/zh_CN/model_configuration.md) for details.
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
```python
import logging
import numpy as np
import time
from typing import Optional
import cv2
import json

from tritonclient import utils as client_utils
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2

LOGGER = logging.getLogger("run_inference_on_triton")


class SyncGRPCTritonRunner:
    DEFAULT_MAX_RESP_WAIT_S = 120

    def __init__(
            self,
            server_url: str,
            model_name: str,
            model_version: str,
            *,
            verbose=False,
            resp_wait_s: Optional[float]=None, ):
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._verbose = verbose
        self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s

        self._client = InferenceServerClient(
            self._server_url, verbose=self._verbose)
        error = self._verify_triton_state(self._client)
        if error:
            raise RuntimeError(
                f"Could not communicate to Triton Server: {error}")

        LOGGER.debug(
            f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
            f"are up and ready!")

        model_config = self._client.get_model_config(self._model_name,
                                                     self._model_version)
        model_metadata = self._client.get_model_metadata(self._model_name,
                                                         self._model_version)
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        self._inputs = {tm.name: tm for tm in model_metadata.inputs}
        self._input_names = list(self._inputs)
        self._outputs = {tm.name: tm for tm in model_metadata.outputs}
        self._output_names = list(self._outputs)
        self._outputs_req = [
            InferRequestedOutput(name) for name in self._outputs
        ]

    def Run(self, inputs):
        """
        Args:
            inputs: list, Each value corresponds to an input name of self._input_names
        Returns:
            results: dict, {name : numpy.array}
        """
        infer_inputs = []
        for idx, data in enumerate(inputs):
            infer_input = InferInput(self._input_names[idx], data.shape,
                                     "UINT8")
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        results = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t, )
        results = {name: results.as_numpy(name) for name in self._output_names}
        return results

    def _verify_triton_state(self, triton_client):
        if not triton_client.is_server_live():
            return f"Triton server {self._server_url} is not live"
        elif not triton_client.is_server_ready():
            return f"Triton server {self._server_url} is not ready"
        elif not triton_client.is_model_ready(self._model_name,
                                              self._model_version):
            return f"Model {self._model_name}:{self._model_version} is not ready"
        return None


if __name__ == "__main__":
    model_name = "pp_ocr"
    model_version = "1"
    url = "localhost:9001"
    runner = SyncGRPCTritonRunner(url, model_name, model_version)
    im = cv2.imread("12.jpg")
    im = np.array([im, ])
    for i in range(1):
        result = runner.Run([im, ])
        batch_texts = result['rec_texts']
        batch_scores = result['rec_scores']
        for i_batch in range(len(batch_texts)):
            texts = batch_texts[i_batch]
            scores = batch_scores[i_batch]
            for i_box in range(len(texts)):
                print('text=', texts[i_box].decode('utf-8'), ' score=',
                      scores[i_box])
```
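
The script above sends a single image per request. As a hedged sketch (file names and the common resize shape are assumptions, and the pipeline's batch limit must allow it), the same runner could be reused to send several images in one batched request, since `np.array([...])` must produce a dense UINT8 batch of identically sized images:

```python
# Hedged sketch: batching several images into one pp_ocr request with the
# SyncGRPCTritonRunner defined above. Paths and the common size are assumptions.
import cv2
import numpy as np

runner = SyncGRPCTritonRunner("localhost:9001", "pp_ocr", "1")
paths = ["12.jpg", "12.jpg"]  # replace with your own images
imgs = [cv2.resize(cv2.imread(p), (960, 960)) for p in paths]  # common shape required
batch = np.array(imgs, dtype=np.uint8)  # shape: [batch, H, W, 3]

result = runner.Run([batch])
for texts, scores in zip(result["rec_texts"], result["rec_scores"]):
    for text, score in zip(texts, scores):
        print(text.decode("utf-8"), float(score))
```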
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import numpy as np
import time

import fastdeploy as fd

# triton_python_backend_utils is available in every Triton Python model. You
# need to use this module to create inference requests and responses. It also
# contains some utility functions for extracting information from model_config
# and converting Triton input/output types to numpy types.
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.
        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        # You must parse model_config. JSON string is not parsed here
        self.model_config = json.loads(args['model_config'])
        print("model_config:", self.model_config)

        self.input_names = []
        for input_config in self.model_config["input"]:
            self.input_names.append(input_config["name"])
        print("postprocess input names:", self.input_names)

        self.output_names = []
        self.output_dtype = []
        for output_config in self.model_config["output"]:
            self.output_names.append(output_config["name"])
            dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
            self.output_dtype.append(dtype)
        print("postprocess output names:", self.output_names)
        self.postprocessor = fd.vision.ocr.ClassifierPostprocessor()

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model, must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse.
        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest
        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        responses = []
        for request in requests:
            infer_outputs = pb_utils.get_input_tensor_by_name(
                request, self.input_names[0])
            infer_outputs = infer_outputs.as_numpy()
            results = self.postprocessor.run([infer_outputs])
            out_tensor_0 = pb_utils.Tensor(self.output_names[0],
                                           np.array(results[0]))
            out_tensor_1 = pb_utils.Tensor(self.output_names[1],
                                           np.array(results[1]))
            inference_response = pb_utils.InferenceResponse(
                output_tensors=[out_tensor_0, out_tensor_1])
            responses.append(inference_response)
        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')
```
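
The `execute` loop above boils down to one call into FastDeploy's classifier postprocessor. The following is a minimal standalone sketch of that call outside Triton; the softmax values are made up, and the exact return structure may differ between FastDeploy versions:

```python
# Hedged sketch: calling fd.vision.ocr.ClassifierPostprocessor directly,
# mirroring execute() above. The softmax scores below are made-up example values.
import numpy as np
import fastdeploy as fd

postprocessor = fd.vision.ocr.ClassifierPostprocessor()
# Two text crops, two classes (typically 0 = keep as-is, 1 = rotate 180 degrees).
softmax_output = np.array([[0.9, 0.1],
                           [0.2, 0.8]], dtype=np.float32)
results = postprocessor.run([softmax_output])
labels, scores = results[0], results[1]
print("labels:", labels)
print("scores:", scores)
```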
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
```
name: "cls_postprocess"
backend: "python"
max_batch_size: 128
input [
  {
    name: "POST_INPUT_0"
    data_type: TYPE_FP32
    dims: [ 2 ]
  }
]

output [
  {
    name: "POST_OUTPUT_0"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "POST_OUTPUT_1"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
```
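
In this configuration, `POST_INPUT_0` with `dims: [ 2 ]` carries the two-class softmax output of the classifier (wired to `softmax_0.tmp_0` of `cls_runtime` in the `cls_pp` ensemble below), and `POST_OUTPUT_0` / `POST_OUTPUT_1` carry the predicted label and its score, matching `results[0]` and `results[1]` produced by `ClassifierPostprocessor.run` in the `model.py` above. Because `max_batch_size` is 128, the leading batch dimension is implicit in all of these shapes.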
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
```
name: "cls_pp"
platform: "ensemble"
max_batch_size: 128
input [
  {
    name: "x"
    data_type: TYPE_FP32
    dims: [ 3, -1, -1 ]
  }
]
output [
  {
    name: "cls_labels"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "cls_scores"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "cls_runtime"
      model_version: 1
      input_map {
        key: "x"
        value: "x"
      }
      output_map {
        key: "softmax_0.tmp_0"
        value: "infer_output"
      }
    },
    {
      model_name: "cls_postprocess"
      model_version: 1
      input_map {
        key: "POST_INPUT_0"
        value: "infer_output"
      }
      output_map {
        key: "POST_OUTPUT_0"
        value: "cls_labels"
      }
      output_map {
        key: "POST_OUTPUT_1"
        value: "cls_scores"
      }
    }
  ]
}
```
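
Since `cls_pp` is itself a servable ensemble, it can also be queried on its own for debugging, independently of the full `pp_ocr` pipeline. Below is a minimal sketch using the same `tritonclient` gRPC API as `client.py`; the random tensor merely stands in for a preprocessed text crop, and the crop size is an assumption:

```python
# Hedged sketch: querying the cls_pp ensemble directly over gRPC for debugging.
# The random tensor stands in for a preprocessed text crop (CHW, float32).
import numpy as np
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput

client = InferenceServerClient("localhost:9001")  # adjust to your --grpc-port

crop = np.random.rand(1, 3, 48, 192).astype(np.float32)  # [batch, 3, H, W], sizes assumed
infer_input = InferInput("x", list(crop.shape), "FP32")
infer_input.set_data_from_numpy(crop)

result = client.infer(
    model_name="cls_pp",
    inputs=[infer_input],
    outputs=[InferRequestedOutput("cls_labels"), InferRequestedOutput("cls_scores")])
print("labels:", result.as_numpy("cls_labels"))
print("scores:", result.as_numpy("cls_scores"))
```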
