From 434b48dda5381e9bd5176d054c629aaecac59658 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Mon, 29 May 2023 14:38:25 +0800
Subject: [PATCH] [Serving] Support FastDeploy XPU Triton Server (#1994)

* [patchelf] fix patchelf error for inference xpu
* [serving] add xpu dockerfile and support fd server
* [serving] add xpu dockerfile and support fd server
* [Serving] support XPU + Triton
* [Serving] support XPU + Triton
* [Dockerfile] update xpu triton docker file -> paddle 0.0.0
* [Dockerfile] update xpu triton docker file -> paddle 0.0.0
* [Dockerfile] update xpu triton docker file -> paddle 0.0.0
* [Dockerfile] add comments for xpu triton dockerfile
* [Runtime] fix xpu infer error
* [Runtime] fix xpu infer error
* [XPU] update xpu dockerfile
* add xpu triton server docs
* add xpu triton server docs
* add xpu triton server docs
* add xpu triton server docs
* update xpu triton server docs
* update xpu triton server docs
* update xpu triton server docs
* update xpu triton server docs
* update xpu triton server docs
* update xpu triton server docs
* update xpu triton server docs
* update xpu triton server docs
---
 .gitignore                                |   6 +-
 .../serving/models/runtime/config.pbtxt   |  32 ++-
 fastdeploy/runtime/runtime.cc             |  10 +-
 scripts/patch_paddle_inference.py         |   1 +
 serving/Dockerfile_xpu                    |  47 +++++
 serving/docs/zh_CN/compile.md             |  13 ++
 serving/docs/zh_CN/model_configuration.md |  31 +++
 serving/docs/zh_CN/xpu.md                 | 190 ++++++++++++++++++
 serving/scripts/build_fd_xpu.sh           |  65 ++++++
 serving/src/fastdeploy_runtime.cc         | 171 +++++++++++-----
 10 files changed, 517 insertions(+), 49 deletions(-)
 create mode 100644 serving/Dockerfile_xpu
 create mode 100644 serving/docs/zh_CN/xpu.md
 create mode 100755 serving/scripts/build_fd_xpu.sh
 mode change 100755 => 100644 serving/src/fastdeploy_runtime.cc

diff --git a/.gitignore b/.gitignore
index 14b5cd3c4e..a34c0d5753 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,4 +43,8 @@ examples/vision/tests_quantize
 fastdeploy/LICENSE
 fastdeploy/ThirdPartyNotices.txt
 FastDeployCSharp.cmake
-python/fastdeploy/code_version.py
\ No newline at end of file
+python/fastdeploy/code_version.py
+*.pdmodel
+*.pdiparams
+*.pdiparams.info
+log.txt
\ No newline at end of file
diff --git a/examples/vision/classification/paddleclas/serving/models/runtime/config.pbtxt b/examples/vision/classification/paddleclas/serving/models/runtime/config.pbtxt
index db80445bc3..f3c56b785f 100644
--- a/examples/vision/classification/paddleclas/serving/models/runtime/config.pbtxt
+++ b/examples/vision/classification/paddleclas/serving/models/runtime/config.pbtxt
@@ -10,7 +10,7 @@ input [
     name: "inputs"
     # input type such as TYPE_FP32、TYPE_UINT8、TYPE_INT8、TYPE_INT16、TYPE_INT32、TYPE_INT64、TYPE_FP16、TYPE_STRING
     data_type: TYPE_FP32
-    # input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w]
+    # input shape, The batch dimension is omitted and the actual shape is [batch, c, h, w]
     dims: [ 3, 224, 224 ]
   }
 ]
@@ -31,6 +31,7 @@ instance_group [
     count: 1
     # Use GPU, CPU inference option is:KIND_CPU
     kind: KIND_GPU
+    # kind: KIND_CPU
    # The instance is deployed on the 0th GPU card
     gpus: [0]
   }
@@ -58,3 +59,32 @@ optimization {
   }
   ]
 }}
+
+# instance_group [
+#   {
+#     # The number of instances is 1
+#     count: 1
+#     # Use GPU, CPU inference option is:KIND_CPU
+#     # kind: KIND_GPU
+#     kind: KIND_CPU
+#     # The instance is deployed on the 0th GPU card
+#     # gpus: [0]
+#   }
+# ]
+
+# optimization {
+#   execution_accelerators {
+#   cpu_execution_accelerator: [{
+#     name: "paddle_xpu",
+#     parameters { key: "cpu_threads" value: "4" }
+#     parameters { key: "use_paddle_log" value: "1" }
+#     parameters { key: "kunlunxin_id" value: "0" }
+#     parameters { key: "l3_workspace_size" value: "62914560" }
+#     parameters { key: "locked" value: "0" }
+#     parameters { key: "autotune" value: "1" }
+#     parameters { key: "precision" value: "int16" }
+#     parameters { key: "adaptive_seqlen" value: "0" }
+#     parameters { key: "enable_multi_stream" value: "0" }
+#     parameters { key: "gm_default_size" value: "0" }
+#   }]
+# }}
\ No newline at end of file
diff --git a/fastdeploy/runtime/runtime.cc b/fastdeploy/runtime/runtime.cc
index 72004d0daf..c597e06d16 100644
--- a/fastdeploy/runtime/runtime.cc
+++ b/fastdeploy/runtime/runtime.cc
@@ -204,7 +204,15 @@ bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
 }
 
 bool Runtime::Infer() {
-  bool result = backend_->Infer(input_tensors_, &output_tensors_, false);
+  bool result = false;
+  if (option.device == Device::KUNLUNXIN) {
+    // FDTensor SetExternalData is not supported for Device::KUNLUNXIN
+    // yet, so copy_to_fd must be set to 'true' here.
+    result = backend_->Infer(input_tensors_, &output_tensors_, true);
+  } else {
+    result = backend_->Infer(input_tensors_, &output_tensors_, false);
+  }
+
   for (auto& tensor : output_tensors_) {
     tensor.device_id = option.device_id;
   }
diff --git a/scripts/patch_paddle_inference.py b/scripts/patch_paddle_inference.py
index e85071ffde..d0b2647b9b 100644
--- a/scripts/patch_paddle_inference.py
+++ b/scripts/patch_paddle_inference.py
@@ -26,6 +26,7 @@ def process_paddle_inference(paddle_inference_so_file):
     rpaths = [
         "$ORIGIN", "$ORIGIN/../../third_party/install/mkldnn/lib/",
         "$ORIGIN/../../third_party/install/mklml/lib/",
+        "$ORIGIN/../../third_party/install/xpu/lib/",
         "$ORIGIN/../../../tensorrt/lib/"
     ]
 
diff --git a/serving/Dockerfile_xpu b/serving/Dockerfile_xpu
new file mode 100644
index 0000000000..b7fdc75b0c
--- /dev/null
+++ b/serving/Dockerfile_xpu
@@ -0,0 +1,47 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ARG http_proxy
+ARG https_proxy
+ARG no_proxy
+
+FROM paddlepaddle/fastdeploy:21.10-cpu-only-min
+
+ENV TZ=Asia/Shanghai \
+    DEBIAN_FRONTEND=noninteractive \
+    http_proxy=$http_proxy \
+    https_proxy=$https_proxy \
+    no_proxy=$no_proxy
+
+# Note: Here, use the nightly build of Paddle for the XPU Triton server image
+# to avoid the .so conflicts between paddle and fastdeploy-python.
+RUN apt-get update && apt-get install -y --no-install-recommends apt-utils libgomp1 ffmpeg libsm6 libxext6 vim wget \ + && python3 -m pip install -U pip \ + && python3 -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html \ + && python3 -m pip install paddlenlp fast-tokenizer-python + +COPY python/dist/*.whl /opt/fastdeploy/ +RUN python3 -m pip install /opt/fastdeploy/*.whl \ + && rm -rf /opt/fastdeploy/*.whl + +COPY serving/build/libtriton_fastdeploy.so /opt/tritonserver/backends/fastdeploy/ +COPY build/fastdeploy_install /opt/fastdeploy/ +COPY benchmark/cpp /opt/fastdeploy/benchmark/cpp + +RUN mv /opt/tritonserver/bin/tritonserver /opt/tritonserver/bin/fastdeployserver +ENV LD_LIBRARY_PATH="/opt/fastdeploy/lib:/opt/fastdeploy/third_libs/install/opencv/lib64:/opt/fastdeploy/third_libs/install/paddle2onnx/lib:/opt/fastdeploy/third_libs/install/paddle_inference/paddle/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mkldnn/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mklml/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/xpu/lib:$LD_LIBRARY_PATH" +# unset proxy +ENV http_proxy= +ENV https_proxy= +ENV no_proxy= diff --git a/serving/docs/zh_CN/compile.md b/serving/docs/zh_CN/compile.md index 3cfaa08800..f2eb9180a9 100644 --- a/serving/docs/zh_CN/compile.md +++ b/serving/docs/zh_CN/compile.md @@ -65,6 +65,19 @@ cd ../ docker build -t paddlepaddle/fastdeploy:x.y.z-ipu-only-21.10 -f serving/Dockerfile_ipu . ``` +### 制作XPU镜像 + +``` +# 进入serving目录执行脚本编译fastdeploy和服务化的backend +cd serving +bash scripts/build_fd_xpu.sh + +# 退出到FastDeploy主目录,制作镜像 +# x.y.z为FastDeploy版本号,可根据情况自己确定。比如: 1.0.6 +cd ../ +docker build -t paddlepaddle/fastdeploy:x.y.z-xpu-21.10 -f serving/Dockerfile_xpu . 
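+
+# (可选) 确认镜像已经生成
+docker images | grep fastdeploy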
+``` + ## 非镜像方式编译 - [FastDeploy Serving CentOS编译教程](./compile_without_docker_centos.md) diff --git a/serving/docs/zh_CN/model_configuration.md b/serving/docs/zh_CN/model_configuration.md index 03f8e09af0..56e4a6d321 100644 --- a/serving/docs/zh_CN/model_configuration.md +++ b/serving/docs/zh_CN/model_configuration.md @@ -112,6 +112,37 @@ optimization { } ``` +#### 配置使用Paddle+XPU引擎 +``` +optimization { + execution_accelerators { + # XPU推理配置通过CPU Execution启动, 配合KIND_CPU使用 + cpu_execution_accelerator: [ + { + name: "paddle_xpu", + # CPU相关配置 + # cpu_threads: CPU计算线程数 + # use_paddle_log: 开启paddle log信息 + parameters { key: "cpu_threads" value: "4" } + parameters { key: "use_paddle_log" value: "0" } + # XPU相关配置 + # kunlunxin_id: 使用的XPU卡号 + # l3_workspace_size: L3缓存size + parameters { key: "kunlunxin_id" value: "0" } + parameters { key: "l3_workspace_size" value: "0xfffc00" } + parameters { key: "locked" value: "0" } + parameters { key: "autotune" value: "1" } + parameters { key: "precision" value: "int16" } + parameters { key: "adaptive_seqlen" value: "0" } + parameters { key: "enable_multi_stream" value: "0" } + parameters { key: "gm_default_size" value: "0" } + } + ] + } +} +``` + + ### 配置使用ONNXRuntime引擎 除去配置 *Instance Groups*,决定模型运行在CPU还是GPU上。ONNXRuntime引擎中,还可以进行如下配置,具体例子可参照[YOLOv5的Runtime配置](../../../examples/vision/detection/yolov5/serving/models/runtime/config.pbtxt): diff --git a/serving/docs/zh_CN/xpu.md b/serving/docs/zh_CN/xpu.md new file mode 100644 index 0000000000..2f4b2982a0 --- /dev/null +++ b/serving/docs/zh_CN/xpu.md @@ -0,0 +1,190 @@ +# FastDeploy XPU Triton Server使用文档 +FastDeploy XPU Triton Server通过Paddle Inference调用XPU进行推理,并且已经接入到 Triton Server。在FastDeploy XPU Triton Server中,使用XPU推理需要通过CPU instance_group和cpu_execution_accelerator进行配置和调用。本文档以PaddleClas为例,讲述如何把一个CPU/GPU Triton服务,改造成XPU Triton服务。 + +## 1. 准备服务化镜像 + +- 下载FastDeploy XPU Triton Server镜像 +```bash +docker pull registry.baidubce.com/paddlepaddle/fastdeploy:1.0.7-xpu-21.10 # 稳定版 +docker pull registry.baidubce.com/paddlepaddle/fastdeploy:0.0.0-xpu-21.10 # develop版本 +``` + +- 下载部署示例代码 +```bash +# 下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/classification/paddleclas/serving + +# 下载ResNet50_vd模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz +tar -xvf ResNet50_vd_infer.tgz + +# 将配置文件放入预处理目录 +mv ResNet50_vd_infer/inference_cls.yaml models/preprocess/1/inference_cls.yaml + +# 将模型放入 models/runtime/1目录下, 并重命名为model.pdmodel和model.pdiparams +mv ResNet50_vd_infer/inference.pdmodel models/runtime/1/model.pdmodel +mv ResNet50_vd_infer/inference.pdiparams models/runtime/1/model.pdiparams +``` + +## 2. 启动容器 +```bash +docker run -itd --name fd_xpu_server -v `pwd`/:/serving --net=host --privileged registry.baidubce.com/paddlepaddle/fastdeploy:1.0.7-xpu-21.10 /bin/bash +``` + +## 3. 
验证XPU可用性 +```bash +docker exec -it fd_xpu_server /bin/bash +cd /opt/fastdeploy/benchmark/cpp/build +./benchmark --model ResNet50_infer --config_path ../config/config.xpu.paddle.fp32.txt --enable_log_info +cd /serving +``` +输出为: +``` +I0529 11:07:46.860354 222 memory_optimize_pass.cc:222] Cluster name : batch_norm_46.tmp_2_max size: 1 +--- Running analysis [ir_graph_to_program_pass] +I0529 11:07:46.889616 222 analysis_predictor.cc:1705] ======= optimize end ======= +I0529 11:07:46.890262 222 naive_executor.cc:160] --- skip [feed], feed -> inputs +I0529 11:07:46.890703 222 naive_executor.cc:160] --- skip [save_infer_model/scale_0.tmp_1], fetch -> fetch +[INFO] fastdeploy/runtime/runtime.cc(286)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::KUNLUNXIN. +[INFO] fastdeploy/runtime/backends/paddle/paddle_backend.cc(341)::Infer Running profiling for Runtime without H2D and D2H, Repeats: 1000, Warmup: 200 +Runtime(ms): 0.706382ms. +``` +显示启动的设备类型为:Device::KUNLUNXIN。FastDeploy Benchmark工具使用文档,请参考[benchmark](https://github.com/PaddlePaddle/FastDeploy/tree/develop/benchmark/cpp). + +## 4. 配置Triton Model Config +```protobuf +# XPU服务化案例: examples/vision/classification/serving/models/runtime/config.pbtxt +# 将XPU部分的注释撤销,并注释掉原来的GPU设置,修改为: +# # Number of instances of the model +# instance_group [ +# { +# # The number of instances is 1 +# count: 1 +# # Use GPU, CPU inference option is:KIND_CPU +# kind: KIND_GPU +# # kind: KIND_CPU +# # The instance is deployed on the 0th GPU card +# gpus: [0] +# } +# ] + +# optimization { +# execution_accelerators { +# gpu_execution_accelerator : [ { +# # use TRT engine +# name: "tensorrt", +# # use fp16 on TRT engine +# parameters { key: "precision" value: "trt_fp16" } +# }, +# { +# name: "min_shape" +# parameters { key: "inputs" value: "1 3 224 224" } +# }, +# { +# name: "opt_shape" +# parameters { key: "inputs" value: "1 3 224 224" } +# }, +# { +# name: "max_shape" +# parameters { key: "inputs" value: "16 3 224 224" } +# } +# ] +# }} + +instance_group [ + { + # The number of instances is 1 + count: 1 + # Use GPU, CPU inference option is:KIND_CPU + # kind: KIND_GPU + kind: KIND_CPU + # The instance is deployed on the 0th GPU card + # gpus: [0] + } +] + +optimization { + execution_accelerators { + cpu_execution_accelerator: [{ + name: "paddle_xpu", + parameters { key: "cpu_threads" value: "4" } + parameters { key: "use_paddle_log" value: "1" } + parameters { key: "kunlunxin_id" value: "0" } + parameters { key: "l3_workspace_size" value: "62914560" } + parameters { key: "locked" value: "0" } + parameters { key: "autotune" value: "1" } + parameters { key: "precision" value: "int16" } + parameters { key: "adaptive_seqlen" value: "0" } + parameters { key: "enable_multi_stream" value: "0" } + parameters { key: "gm_default_size" value: "0" } + }] +}} +``` + +## 5. 启动Triton服务 +```bash +fastdeployserver --model-repository=/serving/models --backend-config=python,shm-default-byte-size=10485760 +``` +输出: +``` +[INFO] fastdeploy/runtime/runtime.cc(286)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::KUNLUNXIN. +..... +I0529 03:54:40.585326 385 server.cc:592] ++-------------+---------+--------+ +| Model | Version | Status | ++-------------+---------+--------+ +| paddlecls | 1 | READY | +| postprocess | 1 | READY | +| preprocess | 1 | READY | +| runtime | 1 | READY | ++-------------+---------+--------+ +...... 
+I0529 03:54:40.586430 385 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001 +I0529 03:54:40.586657 385 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000 +I0529 03:54:40.627382 385 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002 +``` + +## 6. 客户端请求 +在物理机器中执行以下命令,发送grpc请求并输出结果: +```bash +# 下载测试图片 +wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + +# 安装客户端依赖 +python3 -m pip install tritonclient\[all\] + +# 发送请求 +python3 paddlecls_grpc_client.py +``` + +发送请求成功后,会返回json格式的检测结果并打印输出: +```bash +output_name: CLAS_RESULT +{'label_ids': [153], 'scores': [0.6858349442481995]} +``` +以上测试结果为Paddle Inference Backend + XPU R200下的输出。 + +## 7. 容器内自测 +如果是想在容器内自测,则运行以下命令: +```bash +cd /serving +# 后台挂载 +nohup fastdeployserver --model-repository=/serving/models --backend-config=python,shm-default-byte-size=10485760 > log.txt 2>&1 & +# 安装客户端依赖 +python3 -m pip install tritonclient\[all\] +# 发送请求 +unset http_proxy +unset https_proxy +python3 paddlecls_grpc_client.py +``` + +## 8. 配置修改 + +当前默认配置在XPU运行Paddle Inference引擎, 如果要在CPU/GPU其他推理引擎上运行。 需要修改`models/runtime/config.pbtxt`中配置,详情请参考[配置文档](./model_configuration.md). + +## 9. 常见问题 +- [如何编写客户端 HTTP/GRPC 请求](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/client.md) +- [如何编译服务化部署镜像](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/compile.md) +- [服务化部署原理及动态Batch介绍](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/demo.md) +- [模型仓库介绍](https://github.com/PaddlePaddle/FastDeploy/blob/develop/serving/docs/zh_CN/model_repository.md) diff --git a/serving/scripts/build_fd_xpu.sh b/serving/scripts/build_fd_xpu.sh new file mode 100755 index 0000000000..4573128d4b --- /dev/null +++ b/serving/scripts/build_fd_xpu.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +echo "start build FD XPU library" + +docker run -i --rm --name build_fd_xpu \ + -v `pwd`/..:/workspace/fastdeploy \ + -e "http_proxy=${http_proxy}" \ + -e "https_proxy=${https_proxy}" \ + -e "no_proxy=${no_proxy}" \ + --network=host --privileged \ + paddlepaddle/fastdeploy:21.10-cpu-only-buildbase \ + bash -c \ + 'export https_proxy_tmp=${https_proxy} + export http_proxy_tmp=${http_proxy} + cd /workspace/fastdeploy/python; + rm -rf .setuptools-cmake-build dist build fastdeploy/libs/third_libs; + ln -s /usr/bin/python3 /usr/bin/python; + export WITH_GPU=OFF; + export ENABLE_ORT_BACKEND=OFF; + export ENABLE_PADDLE_BACKEND=OFF; + export ENABLE_OPENVINO_BACKEND=OFF; + export ENABLE_VISION=ON; + export ENABLE_TEXT=ON; + unset http_proxy + unset https_proxy + python setup.py build; + python setup.py bdist_wheel; + cd /workspace/fastdeploy; + rm -rf build; mkdir build; cd build; + cmake .. 
-DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${PWD}/fastdeploy_install -DWITH_KUNLUNXIN=ON -DENABLE_PADDLE_BACKEND=ON -DENABLE_VISION=ON -DENABLE_BENCHMARK=ON -DLIBRARY_NAME=fastdeploy_runtime; + make -j`nproc`; + make install; + cd /workspace/fastdeploy/serving; + rm -rf build; mkdir build; cd build; + export https_proxy=${https_proxy_tmp} + export http_proxy=${http_proxy_tmp} + cmake .. -DTRITON_ENABLE_GPU=OFF -DFASTDEPLOY_DIR=/workspace/fastdeploy/build/fastdeploy_install -DTRITON_COMMON_REPO_TAG=r21.10 -DTRITON_CORE_REPO_TAG=r21.10 -DTRITON_BACKEND_REPO_TAG=r21.10; + make -j`nproc`; + cd /workspace/fastdeploy/benchmark/cpp; + rm -rf build; mkdir build; cd build; + unset http_proxy + unset https_proxy + cmake .. -DFASTDEPLOY_INSTALL_DIR=/workspace/fastdeploy/build/fastdeploy_install; + make -j`nproc`; + wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_infer.tgz && tar -zxvf ResNet50_infer.tgz; + wget https://bj.bcebos.com/paddlehub/fastdeploy/000000014439.jpg; + rm -f ResNet50_infer.tgz; + rm -rf CMakeFiles; + ' + +echo "build FD XPU library done" diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc old mode 100755 new mode 100644 index b5f20602b5..0ecf6196de --- a/serving/src/fastdeploy_runtime.cc +++ b/serving/src/fastdeploy_runtime.cc @@ -199,6 +199,9 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) runtime_options_->UseOrtBackend(); } else if (name == "paddle") { runtime_options_->UsePaddleBackend(); + } else if (name == "paddle_xpu") { + // Note(qiuyanjun): use XPU via paddle inference backend. + runtime_options_->UsePaddleInferBackend(); } else if (name == "openvino") { runtime_options_->UseOpenVINOBackend(); } else if (name != "") { @@ -212,44 +215,118 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) } triton::common::TritonJson::Value params; - if (ea.Find("parameters", ¶ms)) { - std::vector param_keys; - THROW_IF_BACKEND_MODEL_ERROR(params.Members(¶m_keys)); - for (const auto& param_key : param_keys) { - std::string value_string; - THROW_IF_BACKEND_MODEL_ERROR( - params.MemberAsString(param_key.c_str(), &value_string)); - if (param_key == "cpu_threads") { - int cpu_thread_num; - THROW_IF_BACKEND_MODEL_ERROR( - ParseIntValue(value_string, &cpu_thread_num)); - runtime_options_->SetCpuThreadNum(cpu_thread_num); - } else if (param_key == "use_mkldnn") { - bool pd_enable_mkldnn; - THROW_IF_BACKEND_MODEL_ERROR( - ParseBoolValue(value_string, &pd_enable_mkldnn)); - runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn); - } else if (param_key == "use_paddle_log") { - bool use_paddle_log; - THROW_IF_BACKEND_MODEL_ERROR( - ParseBoolValue(value_string, &use_paddle_log)); - runtime_options_->paddle_infer_option.enable_log_info = - use_paddle_log; - } else if (param_key == "num_streams") { - int num_streams; + if (name == "paddle_xpu") { + // parse parameters for cpu host + xpu device. + if (ea.Find("parameters", ¶ms)) { + std::vector param_keys; + THROW_IF_BACKEND_MODEL_ERROR(params.Members(¶m_keys)); + // default settings for XPU. + int kunlunxin_id = 0; + int l3_workspace_size = 0xfffc00; + bool locked = false; + bool autotune = true; + std::string autotune_file = ""; + std::string precision = "int16"; + bool adaptive_seqlen = false; + bool enable_multi_stream = false; + // for future use (only support lite backend now). + int gm_default_size = 0; + // common settings for cpu host. 
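+        // Note: unlike the XPU defaults above, which are always passed to
+        // UseKunlunXin() after the parameter loop, these host-side defaults
+        // only take effect when the matching key appears in config.pbtxt.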
+ int cpu_thread_num = -1; + bool use_paddle_log = false; + + for (const auto& param_key : param_keys) { + std::string value_string; THROW_IF_BACKEND_MODEL_ERROR( - ParseIntValue(value_string, &num_streams)); - runtime_options_->openvino_option.num_streams = num_streams; - } else if (param_key == "is_clone") { + params.MemberAsString(param_key.c_str(), &value_string)); + // parse common settings for cpu host. + if (param_key == "cpu_threads") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &cpu_thread_num)); + runtime_options_->SetCpuThreadNum(cpu_thread_num); + } else if (param_key == "use_paddle_log") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &use_paddle_log)); + runtime_options_->paddle_infer_option.enable_log_info = + use_paddle_log; + } else if (param_key == "is_clone") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &is_clone_)); + } else if (param_key == "encryption_key") { + runtime_options_->SetEncryptionKey(value_string); + // parse common settings for xpu device. + } else if (param_key == "kunlunxin_id") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &kunlunxin_id)); + } else if (param_key == "l3_workspace_size") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &l3_workspace_size)); + } else if (param_key == "locked") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &locked)); + } else if (param_key == "autotune") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &autotune)); + } else if (param_key == "precision") { + precision = value_string; + } else if (param_key == "adaptive_seqlen") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &adaptive_seqlen)); + } else if (param_key == "enable_multi_stream") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &enable_multi_stream)); + } else if (param_key == "gm_default_size") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &gm_default_size)); + } + } + // initialize xpu device settings + runtime_options_->UseKunlunXin( + kunlunxin_id, l3_workspace_size, locked, autotune, + autotune_file, precision, adaptive_seqlen, enable_multi_stream, + int64_t(gm_default_size)); + } + } else { + // parse parameters for cpu only + if (ea.Find("parameters", ¶ms)) { + std::vector param_keys; + THROW_IF_BACKEND_MODEL_ERROR(params.Members(¶m_keys)); + for (const auto& param_key : param_keys) { + std::string value_string; THROW_IF_BACKEND_MODEL_ERROR( - ParseBoolValue(value_string, &is_clone_)); - } else if (param_key == "use_ipu") { - // runtime_options_->UseIpu(); - } else if (param_key == "encryption_key") { - runtime_options_->SetEncryptionKey(value_string); + params.MemberAsString(param_key.c_str(), &value_string)); + if (param_key == "cpu_threads") { + int cpu_thread_num; + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &cpu_thread_num)); + runtime_options_->SetCpuThreadNum(cpu_thread_num); + } else if (param_key == "use_mkldnn") { + bool pd_enable_mkldnn; + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &pd_enable_mkldnn)); + runtime_options_->SetPaddleMKLDNN(pd_enable_mkldnn); + } else if (param_key == "use_paddle_log") { + bool use_paddle_log; + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &use_paddle_log)); + runtime_options_->paddle_infer_option.enable_log_info = + use_paddle_log; + } else if (param_key == "num_streams") { + int num_streams; + THROW_IF_BACKEND_MODEL_ERROR( + ParseIntValue(value_string, &num_streams)); + 
runtime_options_->openvino_option.num_streams = num_streams; + } else if (param_key == "is_clone") { + THROW_IF_BACKEND_MODEL_ERROR( + ParseBoolValue(value_string, &is_clone_)); + } else if (param_key == "use_ipu") { + // runtime_options_->UseIpu(); + } else if (param_key == "encryption_key") { + runtime_options_->SetEncryptionKey(value_string); + } } } - } + } // end 'name == "paddle_xpu"' } } } @@ -422,7 +499,7 @@ TRITONSERVER_Error* ModelState::LoadModel( } } - // GPU +// GPU #ifdef TRITON_ENABLE_GPU if ((instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) || (instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) { @@ -432,8 +509,9 @@ TRITONSERVER_Error* ModelState::LoadModel( runtime_options_->UseCpu(); } #else - if (runtime_options_->device != fastdeploy::Device::IPU) { - // If Device is set to IPU, just skip CPU setting. + if ((runtime_options_->device != fastdeploy::Device::IPU) && + (runtime_options_->device != fastdeploy::Device::KUNLUNXIN)) { + // If Device is set to IPU/XPU, just skip CPU setting. runtime_options_->UseCpu(); } #endif // TRITON_ENABLE_GPU @@ -972,7 +1050,7 @@ void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request** requests, SetInputTensors(total_batch_size, requests, request_count, &responses, &collector, &cuda_copy)); - // Wait for any in-flight input tensor copies to complete. +// Wait for any in-flight input tensor copies to complete. #ifdef TRITON_ENABLE_GPU if (cuda_copy) { cudaStreamSynchronize(CudaStream()); @@ -1146,15 +1224,16 @@ TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors( const uint32_t request_count, std::vector* responses) { // r22.12 - BackendOutputResponder responder( - requests, request_count, responses, - model_state_->TritonMemoryManager(), model_state_->MaxBatchSize() > 0, - model_state_->EnablePinnedOutput(), CudaStream()); - // r21.10 // BackendOutputResponder responder( - // requests, request_count, responses, StateForModel()->MaxBatchSize(), - // StateForModel()->TritonMemoryManager(), - // StateForModel()->EnablePinnedOutput(), CudaStream()); + // requests, request_count, responses, + // model_state_->TritonMemoryManager(), model_state_->MaxBatchSize() > 0, + // model_state_->EnablePinnedOutput(), CudaStream()); + + // r21.10 + BackendOutputResponder responder( + requests, request_count, responses, StateForModel()->MaxBatchSize(), + StateForModel()->TritonMemoryManager(), + StateForModel()->EnablePinnedOutput(), CudaStream()); // Use to hold string output contents bool cuda_copy = false;
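
The `paddle_xpu` branch above funnels every `config.pbtxt` parameter into a single `UseKunlunXin(...)` call on the FastDeploy `RuntimeOption`. For readers who want to exercise that configuration outside Triton, the sketch below shows the equivalent standalone setup through the FastDeploy C++ API. It is illustrative only: the model paths are placeholders, the argument order simply mirrors the nine-argument call in the backend code above, and it assumes the `fastdeploy_install` package produced by `serving/scripts/build_fd_xpu.sh`.

```cpp
// Minimal sketch: configure the FastDeploy runtime for KunlunXin XPU the same
// way the Triton backend does when name == "paddle_xpu". Paths are placeholders.
#include <iostream>

#include "fastdeploy/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UsePaddleInferBackend();  // Backend::PDINFER, as logged in xpu.md
  // kunlunxin_id, l3_workspace_size, locked, autotune, autotune_file,
  // precision, adaptive_seqlen, enable_multi_stream, gm_default_size
  option.UseKunlunXin(0, 0xfffc00, false, true, "", "int16", false, false, 0);
  option.SetModelPath("model.pdmodel", "model.pdiparams",
                      fastdeploy::ModelFormat::PADDLE);

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) {
    std::cerr << "Failed to initialize the runtime on Device::KUNLUNXIN"
              << std::endl;
    return -1;
  }
  std::cout << "Runtime ready: " << runtime.NumInputs() << " input(s), "
            << runtime.NumOutputs() << " output(s)" << std::endl;
  return 0;
}
```

Such a test program can be compiled against the installed package the same way `benchmark/cpp` is built inside `build_fd_xpu.sh`, i.e. with `-DFASTDEPLOY_INSTALL_DIR` pointing at the `fastdeploy_install` tree.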