[Serving] add ocr serving example (PaddlePaddle#627)
* add ocr serving example

* Add files via upload

* Update README.md

* Delete ocr_pipeline.png

* Add files via upload

* Delete ocr_pipeline.png

* Add files via upload

* Update README.md

* fix codestyle

* fix codestyle

Co-authored-by: Jason <[email protected]>
Co-authored-by: heliqi <[email protected]>
3 people authored Nov 28, 2022
1 parent c721773 commit 9503639
Showing 17 changed files with 1,188 additions and 0 deletions.
88 changes: 88 additions & 0 deletions examples/vision/ocr/PP-OCRv3/serving/README.md
@@ -0,0 +1,88 @@
# PP-OCR Serving Deployment Example

## Introduction
This document describes how to build an OCR text recognition service with FastDeploy.

The server must be launched inside docker, while the client does not have to run inside a docker container.

**The `models` directory under this document's path (`$PWD`) contains the model configurations and code (the server loads these to start the service), and it must be mapped into docker.**

OCR consists of three models: det (detection), cls (classification), and rec (recognition).

The serving pipeline is shown in the figure below: `pp_ocr` chains `det_preprocess`, `det_runtime`, and `det_postprocess`; `cls_pp` chains `cls_runtime` and `cls_postprocess`; `rec_pp` chains `rec_runtime` and `rec_postprocess`.

Notably, `det_postprocess` calls the `cls_pp` and `rec_pp` services multiple times to classify and recognize each detected box, and finally returns the assembled recognition results to the user. A minimal sketch of such a call is shown after the figure.

<p align="center">
<br>
<img src='./ppocr.png'>
<br>
</p>
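
The calls from `det_postprocess` into `cls_pp` and `rec_pp` use Triton's BLS (Business Logic Scripting) mechanism. The `det_postprocess` model code is part of this commit but not shown in this excerpt; the following is only a minimal sketch of what such a BLS call looks like inside a Triton Python backend model, where the helper name `classify_crops` and the input shapes are illustrative assumptions.

```python
import numpy as np
# Only available inside a Triton Python backend model.
import triton_python_backend_utils as pb_utils


def classify_crops(crop_batch):
    """Send a batch of preprocessed crops to the cls_pp ensemble (hypothetical helper)."""
    # crop_batch: float32 ndarray of shape [N, 3, H, W], matching cls_pp's input "x".
    infer_request = pb_utils.InferenceRequest(
        model_name="cls_pp",
        requested_output_names=["cls_labels", "cls_scores"],
        inputs=[pb_utils.Tensor("x", crop_batch.astype(np.float32))])
    infer_response = infer_request.exec()
    if infer_response.has_error():
        raise pb_utils.TritonModelException(infer_response.error().message())
    labels = pb_utils.get_output_tensor_by_name(
        infer_response, "cls_labels").as_numpy()
    scores = pb_utils.get_output_tensor_by_name(
        infer_response, "cls_scores").as_numpy()
    return labels, scores
```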

## Usage
### 1. Server
#### 1.1 Docker
```bash
# Clone the repository code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/ocr/PP-OCRv3/serving/

# Download the models, test image, and dictionary file
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
tar xvf ch_PP-OCRv3_det_infer.tar && mv ch_PP-OCRv3_det_infer 1
mv 1/inference.pdiparams 1/model.pdiparams && mv 1/inference.pdmodel 1/model.pdmodel
mv 1 models/det_runtime/ && rm -rf ch_PP-OCRv3_det_infer.tar

wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
tar xvf ch_ppocr_mobile_v2.0_cls_infer.tar && mv ch_ppocr_mobile_v2.0_cls_infer 1
mv 1/inference.pdiparams 1/model.pdiparams && mv 1/inference.pdmodel 1/model.pdmodel
mv 1 models/cls_runtime/ && rm -rf ch_ppocr_mobile_v2.0_cls_infer.tar

wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar
tar xvf ch_PP-OCRv3_rec_infer.tar && mv ch_PP-OCRv3_rec_infer 1
mv 1/inference.pdiparams 1/model.pdiparams && mv 1/inference.pdmodel 1/model.pdmodel
mv 1 models/rec_runtime/ && rm -rf ch_PP-OCRv3_rec_infer.tar

wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/ppocr/utils/ppocr_keys_v1.txt
mv ppocr_keys_v1.txt models/rec_postprocess/1/

wget https://gitee.com/paddlepaddle/PaddleOCR/raw/release/2.6/doc/imgs/12.jpg


docker pull paddlepaddle/fastdeploy:0.6.0-gpu-cuda11.4-trt8.4-21.10
docker run -dit --net=host --name fastdeploy --shm-size="1g" -v $PWD:/ocr_serving paddlepaddle/fastdeploy:0.6.0-gpu-cuda11.4-trt8.4-21.10 bash
docker exec -it -u root fastdeploy bash
```

#### 1.2 Installation (inside docker)
```bash
ldconfig
apt-get install libgl1
```

#### 1.3 Launch the server (inside docker)
```bash
fastdeployserver --model-repository=/ocr_serving/models
```

Parameters:
- `model-repository` (required): path to the directory holding the whole model repository.
- `http-port` (optional): port for the HTTP service. Default: `8000`. Not used in this example.
- `grpc-port` (optional): port for the gRPC service. Default: `8001`.
- `metrics-port` (optional): port for server metrics. Default: `8002`. Not used in this example.
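
Once the server is running, readiness can be checked from the client side. This optional sketch uses the same `tritonclient` API as the client below and assumes the default gRPC port `8001`:

```python
from tritonclient.grpc import InferenceServerClient

# Assumes fastdeployserver was started with the default gRPC port 8001.
client = InferenceServerClient("localhost:8001")
if client.is_server_live() and client.is_model_ready("pp_ocr", "1"):
    print("pp_ocr is ready to serve requests")
```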


### 2. Client
#### 2.1 Installation
```bash
pip3 install tritonclient[all]
```

#### 2.2 Send a request
```bash
python3 client.py
```

## Modifying the Configuration

The default configuration runs on GPU. To run on CPU or with another inference engine, modify the `config.pbtxt` in the corresponding runtime directory (`models/det_runtime`, `models/cls_runtime`, or `models/rec_runtime`); see the [configuration document](../../../../../serving/docs/zh_CN/model_configuration.md) for details.
107 changes: 107 additions & 0 deletions examples/vision/ocr/PP-OCRv3/serving/client.py
@@ -0,0 +1,107 @@
import logging
from typing import Optional

import cv2
import numpy as np

from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput

LOGGER = logging.getLogger("run_inference_on_triton")


class SyncGRPCTritonRunner:
DEFAULT_MAX_RESP_WAIT_S = 120

def __init__(
self,
server_url: str,
model_name: str,
model_version: str,
*,
verbose=False,
resp_wait_s: Optional[float]=None, ):
self._server_url = server_url
self._model_name = model_name
self._model_version = model_version
self._verbose = verbose
self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s

self._client = InferenceServerClient(
self._server_url, verbose=self._verbose)
error = self._verify_triton_state(self._client)
if error:
raise RuntimeError(
f"Could not communicate to Triton Server: {error}")

LOGGER.debug(
f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
f"are up and ready!")

model_config = self._client.get_model_config(self._model_name,
self._model_version)
model_metadata = self._client.get_model_metadata(self._model_name,
self._model_version)
LOGGER.info(f"Model config {model_config}")
LOGGER.info(f"Model metadata {model_metadata}")

self._inputs = {tm.name: tm for tm in model_metadata.inputs}
self._input_names = list(self._inputs)
self._outputs = {tm.name: tm for tm in model_metadata.outputs}
self._output_names = list(self._outputs)
self._outputs_req = [
InferRequestedOutput(name) for name in self._outputs
]

def Run(self, inputs):
"""
Args:
inputs: list, Each value corresponds to an input name of self._input_names
Returns:
results: dict, {name : numpy.array}
"""
infer_inputs = []
for idx, data in enumerate(inputs):
infer_input = InferInput(self._input_names[idx], data.shape,
"UINT8")
infer_input.set_data_from_numpy(data)
infer_inputs.append(infer_input)

results = self._client.infer(
model_name=self._model_name,
model_version=self._model_version,
inputs=infer_inputs,
outputs=self._outputs_req,
client_timeout=self._response_wait_t, )
results = {name: results.as_numpy(name) for name in self._output_names}
return results

def _verify_triton_state(self, triton_client):
if not triton_client.is_server_live():
return f"Triton server {self._server_url} is not live"
elif not triton_client.is_server_ready():
return f"Triton server {self._server_url} is not ready"
elif not triton_client.is_model_ready(self._model_name,
self._model_version):
return f"Model {self._model_name}:{self._model_version} is not ready"
return None


if __name__ == "__main__":
model_name = "pp_ocr"
model_version = "1"
# fastdeployserver's default gRPC port is 8001 (see the README above);
# adjust this if the server was launched with a different --grpc-port.
url = "localhost:8001"
runner = SyncGRPCTritonRunner(url, model_name, model_version)
im = cv2.imread("12.jpg")
im = np.array([im, ])
for i in range(1):
result = runner.Run([im, ])
batch_texts = result['rec_texts']
batch_scores = result['rec_scores']
for i_batch in range(len(batch_texts)):
texts = batch_texts[i_batch]
scores = batch_scores[i_batch]
for i_box in range(len(texts)):
print('text=', texts[i_box].decode('utf-8'), ' score=',
scores[i_box])
105 changes: 105 additions & 0 deletions examples/vision/ocr/PP-OCRv3/serving/models/cls_postprocess/1/model.py
@@ -0,0 +1,105 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import numpy as np
import time

import fastdeploy as fd

# triton_python_backend_utils is available in every Triton Python model. You
# need to use this module to create inference requests and responses. It also
# contains some utility functions for extracting information from model_config
# and converting Triton input/output types to numpy types.
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
that is created must have "TritonPythonModel" as the class name.
"""

def initialize(self, args):
"""`initialize` is called only once when the model is being loaded.
Implementing `initialize` function is optional. This function allows
the model to initialize any state associated with this model.
Parameters
----------
args : dict
Both keys and values are strings. The dictionary keys and values are:
* model_config: A JSON string containing the model configuration
* model_instance_kind: A string containing model instance kind
* model_instance_device_id: A string containing model instance device ID
* model_repository: Model repository path
* model_version: Model version
* model_name: Model name
"""
# You must parse model_config; it is passed in as a JSON string.
self.model_config = json.loads(args['model_config'])
print("model_config:", self.model_config)

self.input_names = []
for input_config in self.model_config["input"]:
self.input_names.append(input_config["name"])
print("postprocess input names:", self.input_names)

self.output_names = []
self.output_dtype = []
for output_config in self.model_config["output"]:
self.output_names.append(output_config["name"])
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
self.output_dtype.append(dtype)
print("postprocess output names:", self.output_names)
self.postprocessor = fd.vision.ocr.ClassifierPostprocessor()

def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`
function receives a list of pb_utils.InferenceRequest as the only
argument. This function is called when an inference is requested
for this model. Depending on the batching configuration (e.g. Dynamic
Batching) used, `requests` may contain multiple requests. Every
Python model must create one pb_utils.InferenceResponse for every
pb_utils.InferenceRequest in `requests`. If there is an error, you can
set the error argument when creating a pb_utils.InferenceResponse.
Parameters
----------
requests : list
A list of pb_utils.InferenceRequest
Returns
-------
list
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
responses = []
for request in requests:
infer_outputs = pb_utils.get_input_tensor_by_name(
request, self.input_names[0])
infer_outputs = infer_outputs.as_numpy()
results = self.postprocessor.run([infer_outputs])
out_tensor_0 = pb_utils.Tensor(self.output_names[0],
np.array(results[0]))
out_tensor_1 = pb_utils.Tensor(self.output_names[1],
np.array(results[1]))
inference_response = pb_utils.InferenceResponse(
output_tensors=[out_tensor_0, out_tensor_1])
responses.append(inference_response)
return responses

def finalize(self):
"""`finalize` is called only once when the model is being unloaded.
Implementing `finalize` function is optional. This function allows
the model to perform any necessary clean ups before exit.
"""
print('Cleaning up...')
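
For quick experimentation, the same FastDeploy postprocessor used in `execute` above can be exercised outside Triton. This is an illustrative sketch with fabricated softmax values, not part of the committed example:

```python
import numpy as np
import fastdeploy as fd

# Fabricated softmax output for 3 crops over 2 rotation classes (illustrative only).
fake_softmax = np.array(
    [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], dtype=np.float32)

postprocessor = fd.vision.ocr.ClassifierPostprocessor()
results = postprocessor.run([fake_softmax])
print("labels:", results[0])  # rotation label per crop
print("scores:", results[1])  # confidence per crop
```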
30 changes: 30 additions & 0 deletions examples/vision/ocr/PP-OCRv3/serving/models/cls_postprocess/config.pbtxt
@@ -0,0 +1,30 @@
name: "cls_postprocess"
backend: "python"
max_batch_size: 128
input [
{
name: "POST_INPUT_0"
data_type: TYPE_FP32
dims: [ 2 ]
}
]

output [
{
name: "POST_OUTPUT_0"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "POST_OUTPUT_1"
data_type: TYPE_FP32
dims: [ 1 ]
}
]

instance_group [
{
count: 1
kind: KIND_CPU
}
]
54 changes: 54 additions & 0 deletions examples/vision/ocr/PP-OCRv3/serving/models/cls_pp/config.pbtxt
@@ -0,0 +1,54 @@
name: "cls_pp"
platform: "ensemble"
max_batch_size: 128
input [
{
name: "x"
data_type: TYPE_FP32
dims: [ 3, -1, -1 ]
}
]
output [
{
name: "cls_labels"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "cls_scores"
data_type: TYPE_FP32
dims: [ 1 ]
}
]
ensemble_scheduling {
step [
{
model_name: "cls_runtime"
model_version: 1
input_map {
key: "x"
value: "x"
}
output_map {
key: "softmax_0.tmp_0"
value: "infer_output"
}
},
{
model_name: "cls_postprocess"
model_version: 1
input_map {
key: "POST_INPUT_0"
value: "infer_output"
}
output_map {
key: "POST_OUTPUT_0"
value: "cls_labels"
}
output_map {
key: "POST_OUTPUT_1"
value: "cls_scores"
}
}
]
}
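
Clients normally call the top-level `pp_ocr` ensemble, but the sub-ensembles can also be queried directly, which is handy for debugging. Below is a minimal sketch against `cls_pp`, assuming the default gRPC port `8001` and an illustrative preprocessed crop of shape `1x3x48x192` (real inputs come from the det stage):

```python
import numpy as np
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput

# Illustrative preprocessed crop: float32, NCHW layout.
crop = np.random.rand(1, 3, 48, 192).astype(np.float32)

client = InferenceServerClient("localhost:8001")  # assumes the default gRPC port
inp = InferInput("x", list(crop.shape), "FP32")
inp.set_data_from_numpy(crop)
outputs = [InferRequestedOutput("cls_labels"), InferRequestedOutput("cls_scores")]
result = client.infer("cls_pp", inputs=[inp], outputs=outputs, model_version="1")
print(result.as_numpy("cls_labels"), result.as_numpy("cls_scores"))
```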
