[Serving] Support FastDeploy XPU Triton Server (PaddlePaddle#1994)
* [patchelf] fix patchelf error for inference xpu

* [serving] add xpu dockerfile and support fd server

* [serving] add xpu dockerfile and support fd server

* [Serving] support XPU + Triton

* [Serving] support XPU + Triton

* [Dockerfile] update xpu triton docker file -> paddle 0.0.0

* [Dockerfile] update xpu triton docker file -> paddle 0.0.0

* [Dockerfile] update xpu triton docker file -> paddle 0.0.0

* [Dockerfile] add comments for xpu triton dockerfile

* [Doruntime] fix xpu infer error

* [Doruntime] fix xpu infer error

* [XPU] update xpu dockerfile

* add xpu triton server docs

* add xpu triton server docs

* add xpu triton server docs

* add xpu triton server docs

* update xpu triton server docs

* update xpu triton server docs

* update xpu triton server docs

* update xpu triton server docs

* update xpu triton server docs

* update xpu triton server docs

* update xpu triton server docs

* update xpu triton server docs
DefTruth authored May 29, 2023
1 parent 3a99044 commit 434b48d
Showing 10 changed files with 517 additions and 49 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -43,4 +43,8 @@ examples/vision/tests_quantize
fastdeploy/LICENSE
fastdeploy/ThirdPartyNotices.txt
FastDeployCSharp.cmake
python/fastdeploy/code_version.py
python/fastdeploy/code_version.py
*.pdmodel
*.pdiparams
*.pdiparams.info
log.txt
@@ -10,7 +10,7 @@ input [
name: "inputs"
# input type such as TYPE_FP32, TYPE_UINT8, TYPE_INT8, TYPE_INT16, TYPE_INT32, TYPE_INT64, TYPE_FP16, TYPE_STRING
data_type: TYPE_FP32
# input shape The batch dimension is omitted and the actual shape is [batch, c, h, w]
# input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w]
dims: [ 3, 224, 224 ]
}
]
@@ -31,6 +31,7 @@ instance_group [
count: 1
# Use GPU, CPU inference option is:KIND_CPU
kind: KIND_GPU
# kind: KIND_CPU
# The instance is deployed on the 0th GPU card
gpus: [0]
}
@@ -58,3 +59,32 @@ optimization {
}
]
}}
# instance_group [
# {
# # The number of instances is 1
# count: 1
# # Use GPU, CPU inference option is:KIND_CPU
# # kind: KIND_GPU
# kind: KIND_CPU
# # The instance is deployed on the 0th GPU card
# # gpus: [0]
# }
# ]
# optimization {
# execution_accelerators {
# cpu_execution_accelerator: [{
# name: "paddle_xpu",
# parameters { key: "cpu_threads" value: "4" }
# parameters { key: "use_paddle_log" value: "1" }
# parameters { key: "kunlunxin_id" value: "0" }
# parameters { key: "l3_workspace_size" value: "62914560" }
# parameters { key: "locked" value: "0" }
# parameters { key: "autotune" value: "1" }
# parameters { key: "precision" value: "int16" }
# parameters { key: "adaptive_seqlen" value: "0" }
# parameters { key: "enable_multi_stream" value: "0" }
# parameters { key: "gm_default_size" value: "0" }
# }]
# }}
10 changes: 9 additions & 1 deletion fastdeploy/runtime/runtime.cc
@@ -204,7 +204,15 @@ bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
}

bool Runtime::Infer() {
bool result = backend_->Infer(input_tensors_, &output_tensors_, false);
bool result = false;
if (option.device == Device::KUNLUNXIN) {
// FDTensor::SetExternalData is not supported for Device::KUNLUNXIN yet,
// so copy_to_fd must be set to 'true'.
result = backend_->Infer(input_tensors_, &output_tensors_, true);
} else {
result = backend_->Infer(input_tensors_, &output_tensors_, false);
}

for (auto& tensor : output_tensors_) {
tensor.device_id = option.device_id;
}
1 change: 1 addition & 0 deletions scripts/patch_paddle_inference.py
@@ -26,6 +26,7 @@ def process_paddle_inference(paddle_inference_so_file):
rpaths = [
"$ORIGIN", "$ORIGIN/../../third_party/install/mkldnn/lib/",
"$ORIGIN/../../third_party/install/mklml/lib/",
"$ORIGIN/../../third_party/install/xpu/lib/",
"$ORIGIN/../../../tensorrt/lib/"
]

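For context, a minimal sketch of how such an rpath list can be applied with patchelf, assuming the script shells out to the patchelf CLI (the helper name and `.so` path below are illustrative only, not taken from the actual script):

```python
import subprocess

def apply_rpaths(so_file, rpaths):
    # Join the search paths with ':' and rewrite the shared library's rpath.
    # Requires the patchelf binary to be available on PATH.
    subprocess.run(["patchelf", "--set-rpath", ":".join(rpaths), so_file],
                   check=True)

# Hypothetical usage with the XPU path added in this commit:
# apply_rpaths("libpaddle_inference.so",
#              ["$ORIGIN", "$ORIGIN/../../third_party/install/xpu/lib/"])
```
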
47 changes: 47 additions & 0 deletions serving/Dockerfile_xpu
@@ -0,0 +1,47 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG http_proxy
ARG https_proxy
ARG no_proxy

FROM paddlepaddle/fastdeploy:21.10-cpu-only-min

ENV TZ=Asia/Shanghai \
DEBIAN_FRONTEND=noninteractive \
http_proxy=$http_proxy \
https_proxy=$https_proxy \
no_proxy=$no_proxy

# Note: use a nightly build of Paddle for the XPU Triton server image
# to avoid .so conflicts between paddle and fastdeploy-python.
RUN apt-get update && apt-get install -y --no-install-recommends apt-utils libgomp1 ffmpeg libsm6 libxext6 vim wget \
&& python3 -m pip install -U pip \
&& python3 -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html \
&& python3 -m pip install paddlenlp fast-tokenizer-python

COPY python/dist/*.whl /opt/fastdeploy/
RUN python3 -m pip install /opt/fastdeploy/*.whl \
&& rm -rf /opt/fastdeploy/*.whl

COPY serving/build/libtriton_fastdeploy.so /opt/tritonserver/backends/fastdeploy/
COPY build/fastdeploy_install /opt/fastdeploy/
COPY benchmark/cpp /opt/fastdeploy/benchmark/cpp

RUN mv /opt/tritonserver/bin/tritonserver /opt/tritonserver/bin/fastdeployserver
ENV LD_LIBRARY_PATH="/opt/fastdeploy/lib:/opt/fastdeploy/third_libs/install/opencv/lib64:/opt/fastdeploy/third_libs/install/paddle2onnx/lib:/opt/fastdeploy/third_libs/install/paddle_inference/paddle/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mkldnn/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mklml/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/xpu/lib:$LD_LIBRARY_PATH"
# unset proxy
ENV http_proxy=
ENV https_proxy=
ENV no_proxy=
13 changes: 13 additions & 0 deletions serving/docs/zh_CN/compile.md
@@ -65,6 +65,19 @@ cd ../
docker build -t paddlepaddle/fastdeploy:x.y.z-ipu-only-21.10 -f serving/Dockerfile_ipu .
```

### Building the XPU image

```
# Enter the serving directory and run the script to build FastDeploy and the serving backend
cd serving
bash scripts/build_fd_xpu.sh
# Go back to the FastDeploy root directory and build the image
# x.y.z is the FastDeploy version number; set it as appropriate, e.g. 1.0.6
cd ../
docker build -t paddlepaddle/fastdeploy:x.y.z-xpu-21.10 -f serving/Dockerfile_xpu .
```

## Compiling without a Docker image

- [FastDeploy Serving CentOS build tutorial](./compile_without_docker_centos.md)
31 changes: 31 additions & 0 deletions serving/docs/zh_CN/model_configuration.md
@@ -112,6 +112,37 @@ optimization {
}
```

#### Configuring the Paddle + XPU engine
```
optimization {
execution_accelerators {
# The XPU inference configuration is loaded through the CPU execution accelerator and used together with KIND_CPU
cpu_execution_accelerator: [
{
name: "paddle_xpu",
# CPU-related settings
# cpu_threads: number of CPU compute threads
# use_paddle_log: enable Paddle log output
parameters { key: "cpu_threads" value: "4" }
parameters { key: "use_paddle_log" value: "0" }
# XPU-related settings
# kunlunxin_id: index of the XPU card to use
# l3_workspace_size: L3 cache workspace size in bytes (0xfffc00 ≈ 16 MB)
parameters { key: "kunlunxin_id" value: "0" }
parameters { key: "l3_workspace_size" value: "0xfffc00" }
parameters { key: "locked" value: "0" }
parameters { key: "autotune" value: "1" }
parameters { key: "precision" value: "int16" }
parameters { key: "adaptive_seqlen" value: "0" }
parameters { key: "enable_multi_stream" value: "0" }
parameters { key: "gm_default_size" value: "0" }
}
]
}
}
```
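
For comparison, similar XPU knobs are exposed on FastDeploy's standalone Python API. The sketch below is illustrative only and assumes a recent FastDeploy build where `RuntimeOption.use_kunlunxin` is available:

```python
import fastdeploy as fd

option = fd.RuntimeOption()
option.use_paddle_infer_backend()  # Paddle Inference backend, as in the config above
option.use_kunlunxin(0)            # kunlunxin_id: run on XPU card 0
option.set_cpu_thread_num(4)       # cpu_threads
# Further XPU parameters (l3_workspace_size, locked, autotune, precision,
# adaptive_seqlen, enable_multi_stream, ...) are accepted as additional
# arguments of use_kunlunxin in recent FastDeploy versions.
```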


### Configuring the ONNXRuntime engine
Besides configuring *Instance Groups* to decide whether the model runs on CPU or GPU, the ONNXRuntime engine also supports the following settings; for a concrete example see the [YOLOv5 runtime configuration](../../../examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt):

190 changes: 190 additions & 0 deletions serving/docs/zh_CN/xpu.md
@@ -0,0 +1,190 @@
# FastDeploy XPU Triton Server User Guide
FastDeploy XPU Triton Server runs inference on XPU through Paddle Inference and is already integrated into Triton Server. In FastDeploy XPU Triton Server, XPU inference is configured and invoked through a CPU instance_group together with a cpu_execution_accelerator. Using PaddleClas as an example, this document describes how to turn a CPU/GPU Triton service into an XPU Triton service.

## 1. Prepare the serving image

- Download the FastDeploy XPU Triton Server image
```bash
docker pull registry.baidubce.com/paddlepaddle/fastdeploy:1.0.7-xpu-21.10 # stable release
docker pull registry.baidubce.com/paddlepaddle/fastdeploy:0.0.0-xpu-21.10 # develop build
```

- Download the deployment example code
```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/classification/paddleclas/serving

# Download the ResNet50_vd model files
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
tar -xvf ResNet50_vd_infer.tgz

# Move the config file into the preprocess directory
mv ResNet50_vd_infer/inference_cls.yaml models/preprocess/1/inference_cls.yaml

# Move the model into the models/runtime/1 directory and rename it to model.pdmodel and model.pdiparams
mv ResNet50_vd_infer/inference.pdmodel models/runtime/1/model.pdmodel
mv ResNet50_vd_infer/inference.pdiparams models/runtime/1/model.pdiparams
```

## 2. Start the container
```bash
docker run -itd --name fd_xpu_server -v `pwd`/:/serving --net=host --privileged registry.baidubce.com/paddlepaddle/fastdeploy:1.0.7-xpu-21.10 /bin/bash
```

## 3. Verify XPU availability
```bash
docker exec -it fd_xpu_server /bin/bash
cd /opt/fastdeploy/benchmark/cpp/build
./benchmark --model ResNet50_infer --config_path ../config/config.xpu.paddle.fp32.txt --enable_log_info
cd /serving
```
The output is:
```
I0529 11:07:46.860354 222 memory_optimize_pass.cc:222] Cluster name : batch_norm_46.tmp_2_max size: 1
--- Running analysis [ir_graph_to_program_pass]
I0529 11:07:46.889616 222 analysis_predictor.cc:1705] ======= optimize end =======
I0529 11:07:46.890262 222 naive_executor.cc:160] --- skip [feed], feed -> inputs
I0529 11:07:46.890703 222 naive_executor.cc:160] --- skip [save_infer_model/scale_0.tmp_1], fetch -> fetch
[INFO] fastdeploy/runtime/runtime.cc(286)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::KUNLUNXIN.
[INFO] fastdeploy/runtime/backends/paddle/paddle_backend.cc(341)::Infer Running profiling for Runtime without H2D and D2H, Repeats: 1000, Warmup: 200
Runtime(ms): 0.706382ms.
```
The log shows that the device in use is Device::KUNLUNXIN. For documentation on the FastDeploy benchmark tool, see [benchmark](https://github.com/PaddlePaddle/FastDeploy/tree/develop/benchmark/cpp).

## 4. Configure the Triton model config
```protobuf
# XPU serving example: examples/vision/classification/serving/models/runtime/config.pbtxt
# Uncomment the XPU section and comment out the original GPU settings, so that the config becomes:
# # Number of instances of the model
# instance_group [
# {
# # The number of instances is 1
# count: 1
# # Use GPU, CPU inference option is:KIND_CPU
# kind: KIND_GPU
# # kind: KIND_CPU
# # The instance is deployed on the 0th GPU card
# gpus: [0]
# }
# ]
# optimization {
# execution_accelerators {
# gpu_execution_accelerator : [ {
# # use TRT engine
# name: "tensorrt",
# # use fp16 on TRT engine
# parameters { key: "precision" value: "trt_fp16" }
# },
# {
# name: "min_shape"
# parameters { key: "inputs" value: "1 3 224 224" }
# },
# {
# name: "opt_shape"
# parameters { key: "inputs" value: "1 3 224 224" }
# },
# {
# name: "max_shape"
# parameters { key: "inputs" value: "16 3 224 224" }
# }
# ]
# }}
instance_group [
{
# The number of instances is 1
count: 1
# Use GPU, CPU inference option is:KIND_CPU
# kind: KIND_GPU
kind: KIND_CPU
# The instance is deployed on the 0th GPU card
# gpus: [0]
}
]
optimization {
execution_accelerators {
cpu_execution_accelerator: [{
name: "paddle_xpu",
parameters { key: "cpu_threads" value: "4" }
parameters { key: "use_paddle_log" value: "1" }
parameters { key: "kunlunxin_id" value: "0" }
parameters { key: "l3_workspace_size" value: "62914560" }
parameters { key: "locked" value: "0" }
parameters { key: "autotune" value: "1" }
parameters { key: "precision" value: "int16" }
parameters { key: "adaptive_seqlen" value: "0" }
parameters { key: "enable_multi_stream" value: "0" }
parameters { key: "gm_default_size" value: "0" }
}]
}}
```
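
To sanity-check that the edited file still parses as a valid Triton model config, one option (assuming the `tritonclient` Python package, which ships the model config protobuf definition) is:

```python
from google.protobuf import text_format
from tritonclient.grpc import model_config_pb2

# Parse the text-format config; an exception here means the file is malformed.
config = model_config_pb2.ModelConfig()
with open("models/runtime/config.pbtxt") as f:
    text_format.Parse(f.read(), config)
print(config.instance_group)
```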

## 5. Start the Triton service
```bash
fastdeployserver --model-repository=/serving/models --backend-config=python,shm-default-byte-size=10485760
```
Output:
```
[INFO] fastdeploy/runtime/runtime.cc(286)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::KUNLUNXIN.
.....
I0529 03:54:40.585326 385 server.cc:592]
+-------------+---------+--------+
| Model | Version | Status |
+-------------+---------+--------+
| paddlecls | 1 | READY |
| postprocess | 1 | READY |
| preprocess | 1 | READY |
| runtime | 1 | READY |
+-------------+---------+--------+
......
I0529 03:54:40.586430 385 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001
I0529 03:54:40.586657 385 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000
I0529 03:54:40.627382 385 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002
```
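
Readiness can also be checked programmatically via the Triton client library; a minimal sketch against the HTTP endpoint started above (model name `runtime` taken from the table):

```python
import tritonclient.http as httpclient

# Connect to the HTTP endpoint started above (port 8000).
client = httpclient.InferenceServerClient(url="localhost:8000")
print("server ready:", client.is_server_ready())
print("runtime model ready:", client.is_model_ready("runtime"))
```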

## 6. Client request
Run the following commands on the host machine to send a gRPC request and print the result:
```bash
# Download the test image
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg

# Install client dependencies
python3 -m pip install tritonclient\[all\]

# Send the request
python3 paddlecls_grpc_client.py
```

After the request succeeds, the classification result is returned in JSON format and printed:
```bash
output_name: CLAS_RESULT
{'label_ids': [153], 'scores': [0.6858349442481995]}
```
The above result was produced with the Paddle Inference backend on XPU R200.
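
The shipped `paddlecls_grpc_client.py` already implements this request; for orientation, a stripped-down sketch of such a gRPC call is shown below. The input tensor name `INPUT` and the uint8 HWC layout are assumptions about the example's ensemble config, not verified here:

```python
import cv2
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

# Read the test image and add a batch dimension.
img = cv2.imread("ILSVRC2012_val_00000010.jpeg")
data = np.expand_dims(img, axis=0).astype(np.uint8)

inputs = [grpcclient.InferInput("INPUT", list(data.shape), "UINT8")]
inputs[0].set_data_from_numpy(data)
outputs = [grpcclient.InferRequestedOutput("CLAS_RESULT")]

result = client.infer(model_name="paddlecls", inputs=inputs, outputs=outputs)
print(result.as_numpy("CLAS_RESULT"))
```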

## 7. Self-test inside the container
To run the test inside the container instead, execute the following commands:
```bash
cd /serving
# Run the server in the background
nohup fastdeployserver --model-repository=/serving/models --backend-config=python,shm-default-byte-size=10485760 > log.txt 2>&1 &
# Install client dependencies
python3 -m pip install tritonclient\[all\]
# Send the request
unset http_proxy
unset https_proxy
python3 paddlecls_grpc_client.py
```

## 8. Modifying the configuration

The current default configuration runs the Paddle Inference engine on XPU. To run on CPU/GPU or with another inference engine, modify the configuration in `models/runtime/config.pbtxt`; see the [configuration document](./model_configuration.md) for details.

## 9. FAQ
- [How to write client HTTP/gRPC requests](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/client.md)
- [How to build the serving deployment image](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/compile.md)
- [Serving deployment principles and dynamic batching](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/demo.md)
- [Model repository introduction](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/model_repository.md)
