[Serving] Support XPU encrypt & auth server (PaddlePaddle#2007)
* [patchelf] fix patchelf error for inference xpu

* [serving] add xpu dockerfile and support fd server

* [Serving] support XPU + Triton

* [Dockerfile] update xpu triton docker file -> paddle 0.0.0

* [Dockerfile] add comments for xpu triton dockerfile

* [Runtime] fix xpu infer error

* [XPU] update xpu dockerfile

* add xpu triton server docs

* update xpu triton server docs

* [XPU] Update XPU L3 Cache setting docs

* [XPU] Add Encryption and AUTH support for XPU Server

* [Bug Fix] fix paddle reader error

* [Serving] Support XPU encrypt & auth server

* [Triton] switch TAG 22.12 -> TAG 21.10

* update xpu auth server script
DefTruth authored Jun 1, 2023
1 parent 387c569 commit 284b1b4
Showing 13 changed files with 214 additions and 34 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
@@ -515,6 +515,14 @@ if(MSVC)
endif()

target_link_libraries(${LIBRARY_NAME} ${DEPEND_LIBS})
# Note(qiuyanjun): Currently, we need to manually link the whole
# leveldb static lib into the fastdeploy lib if PADDLEINFERENCE_WITH_ENCRYPT_AUTH
# is 'ON'. This policy will be removed once the bug in the paddle inference
# lib with auth & encrypt is fixed.
if(ENABLE_PADDLE_BACKEND)
enable_paddle_encrypt_auth_link_policy(${LIBRARY_NAME})
endif()

if(ANDROID)
set_android_extra_libraries_target()
endif()
13 changes: 13 additions & 0 deletions FastDeploy.cmake.in
@@ -62,6 +62,9 @@ set(WITH_ANDROID_OPENMP @WITH_ANDROID_OPENMP@)
set(WITH_ANDROID_JAVA @WITH_ANDROID_JAVA@)
set(WITH_ANDROID_TENSOR_FUNCS @WITH_ANDROID_TENSOR_FUNCS@)

# encryption and auth
set(PADDLEINFERENCE_WITH_ENCRYPT_AUTH @PADDLEINFERENCE_WITH_ENCRYPT_AUTH@)

set(FASTDEPLOY_LIBS "")
set(FASTDEPLOY_INCS "")
list(APPEND FASTDEPLOY_INCS ${CMAKE_CURRENT_LIST_DIR}/include)
@@ -149,6 +152,16 @@ if(ENABLE_PADDLE_BACKEND)
else()
set(DNNL_LIB "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/third_party/install/mkldnn/lib/libmkldnn.so.0")
set(IOMP_LIB "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/third_party/install/mklml/lib/libiomp5.so")
if(PADDLEINFERENCE_WITH_ENCRYPT_AUTH)
set(FDMODEL_LIB "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/third_party/install/fdmodel/lib/libfastdeploy_wenxin.so")
set(FDMODEL_AUTH_LIB "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/third_party/install/fdmodel/lib/libfastdeploy_auth.so")
set(FDMODEL_MODEL_LIB "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/third_party/install/fdmodel/lib/libfastdeploy_model.so.2.0.0")
set(LEVELDB_LIB_DIR "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/third_party/install/leveldb/lib/")
list(APPEND FASTDEPLOY_LIBS ${FDMODEL_LIB} ${FDMODEL_AUTH_LIB} ${FDMODEL_MODEL_LIB})
# link_directories(LEVELDB_LIB_DIR)
# list(APPEND FASTDEPLOY_LIBS -lssl -lcrypto -lleveldb)
list(APPEND FASTDEPLOY_LIBS -lssl -lcrypto)
endif()
endif()
list(APPEND FASTDEPLOY_LIBS ${PADDLE_LIB})
if(EXISTS "${DNNL_LIB}")
1 change: 0 additions & 1 deletion benchmark/cpp/CMakeLists.txt
@@ -4,7 +4,6 @@ CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/utils/gflags.cmake)
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)

include_directories(${FASTDEPLOY_INCS})

add_executable(benchmark ${PROJECT_SOURCE_DIR}/benchmark.cc)
43 changes: 36 additions & 7 deletions cmake/paddle_inference.cmake
@@ -20,18 +20,16 @@ if(WITH_GPU AND WITH_IPU)
message(FATAL_ERROR "Cannot build with WITH_GPU=ON and WITH_IPU=ON at the same time.")
endif()

# Custom options for Paddle Inference backend
option(PADDLEINFERENCE_DIRECTORY "Directory of custom Paddle Inference library" OFF)
option(PADDLEINFERENCE_WITH_ENCRYPT_AUTH "Whether the Paddle Inference is built with FD encryption and auth" OFF)

set(PADDLEINFERENCE_PROJECT "extern_paddle_inference")
set(PADDLEINFERENCE_PREFIX_DIR ${THIRD_PARTY_PATH}/paddle_inference)
set(PADDLEINFERENCE_SOURCE_DIR
${THIRD_PARTY_PATH}/paddle_inference/src/${PADDLEINFERENCE_PROJECT})
set(PADDLEINFERENCE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/paddle_inference)
# set(PADDLEINFERENCE_INC_DIR
# "${PADDLEINFERENCE_INSTALL_DIR}/paddle/include"
# CACHE PATH "paddle_inference include directory." FORCE)
# NOTE: The header path needed by paddle inference is xxx/paddle_inference,
# not xxx/paddle_inference/paddle/include

set(PADDLEINFERENCE_INC_DIR "${PADDLEINFERENCE_INSTALL_DIR}"
CACHE PATH "paddle_inference include directory." FORCE)
set(PADDLEINFERENCE_LIB_DIR
@@ -41,7 +39,6 @@ set(CMAKE_BUILD_RPATH "${CMAKE_BUILD_RPATH}"
"${PADDLEINFERENCE_LIB_DIR}")

if(PADDLEINFERENCE_DIRECTORY)
# set(PADDLEINFERENCE_INC_DIR ${PADDLEINFERENCE_DIRECTORY}/paddle/include)
set(PADDLEINFERENCE_INC_DIR ${PADDLEINFERENCE_DIRECTORY})
endif()

@@ -70,9 +67,14 @@ else()
set(OMP_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/mklml/lib/libiomp5.so")
set(P2O_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/paddle2onnx/lib/libpaddle2onnx.so")
set(ORT_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/onnxruntime/lib/libonnxruntime.so")
if(PADDLEINFERENCE_WITH_ENCRYPT_AUTH)
set(FDMODEL_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/fdmodel/lib/libfastdeploy_wenxin.so")
set(FDMODEL_AUTH_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/fdmodel/lib/libfastdeploy_auth.so")
set(FDMODEL_MODEL_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/fdmodel/lib/libfastdeploy_model.so.2.0.0")
set(LEVELDB_LIB_DIR "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/leveldb/lib")
endif()
endif(WIN32)


if(PADDLEINFERENCE_DIRECTORY)
# Use custom Paddle Inference libs.
if(EXISTS "${THIRD_PARTY_PATH}/install/paddle_inference")
@@ -194,3 +196,30 @@ add_library(external_omp STATIC IMPORTED GLOBAL)
set_property(TARGET external_omp PROPERTY IMPORTED_LOCATION
${OMP_LIB})
add_dependencies(external_omp ${PADDLEINFERENCE_PROJECT})

set(ENCRYPT_AUTH_LIBS )
if(PADDLEINFERENCE_WITH_ENCRYPT_AUTH)
add_library(external_fdmodel STATIC IMPORTED GLOBAL)
set_property(TARGET external_fdmodel PROPERTY IMPORTED_LOCATION
${FDMODEL_LIB})
add_library(external_fdmodel_auth STATIC IMPORTED GLOBAL)
set_property(TARGET external_fdmodel_auth PROPERTY IMPORTED_LOCATION
${FDMODEL_AUTH_LIB})
add_library(external_fdmodel_model STATIC IMPORTED GLOBAL)
set_property(TARGET external_fdmodel_model PROPERTY IMPORTED_LOCATION
${FDMODEL_MODEL_LIB})
add_dependencies(external_fdmodel ${PADDLEINFERENCE_PROJECT})
add_dependencies(external_fdmodel_auth ${PADDLEINFERENCE_PROJECT})
add_dependencies(external_fdmodel_model ${PADDLEINFERENCE_PROJECT})
list(APPEND ENCRYPT_AUTH_LIBS external_fdmodel external_fdmodel_auth external_fdmodel_model)
endif()

function(enable_paddle_encrypt_auth_link_policy LIBRARY_NAME)
if(ENABLE_PADDLE_BACKEND AND PADDLEINFERENCE_WITH_ENCRYPT_AUTH)
link_directories(${LEVELDB_LIB_DIR})
target_link_libraries(${LIBRARY_NAME} ${ENCRYPT_AUTH_LIBS} -lssl -lcrypto)
target_link_libraries(${LIBRARY_NAME} ${LEVELDB_LIB_DIR}/libleveldb.a)
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS
"-Wl,--whole-archive ${LEVELDB_LIB_DIR}/libleveldb.a -Wl,-no-whole-archive")
endif()
endfunction()
2 changes: 2 additions & 0 deletions fastdeploy/runtime/backends/paddle/option.h
@@ -87,6 +87,8 @@ struct PaddleBackendOption {
bool switch_ir_debug = false;
/// Whether enable ir optimize, default true
bool switch_ir_optimize = true;
/// Whether the loaded model is a quantized model
bool is_quantize_model = false;

/*
* @brief IPU option, this will configure the IPU hardware, if inference model in IPU
2 changes: 2 additions & 0 deletions fastdeploy/runtime/backends/paddle/option_pybind.cc
100755 → 100644
@@ -51,6 +51,8 @@ void BindPaddleOption(pybind11::module& m) {
&PaddleBackendOption::mkldnn_cache_size)
.def_readwrite("gpu_mem_init_size",
&PaddleBackendOption::gpu_mem_init_size)
.def_readwrite("is_quantize_model",
&PaddleBackendOption::is_quantize_model)
.def("disable_trt_ops", &PaddleBackendOption::DisableTrtOps)
.def("delete_pass", &PaddleBackendOption::DeletePass)
.def("set_ipu_config", &PaddleBackendOption::SetIpuConfig);
60 changes: 36 additions & 24 deletions fastdeploy/runtime/backends/paddle/paddle_backend.cc
@@ -179,11 +179,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model,
FDASSERT(ReadBinaryFromFile(model, &model_content),
"Failed to read file %s.", model.c_str());
}
auto reader =
paddle2onnx::PaddleReader(model_content.c_str(), model_content.size());
// If it's a quantized model and CPU with MKLDNN is in use, automatically
// switch to int8 mode
if (reader.is_quantize_model) {

if (option.is_quantize_model) {
if (option.device == Device::GPU) {
FDWARNING << "The loaded model is a quantized model, while inference on "
"GPU, please use TensorRT backend to get better performance."
@@ -215,25 +212,6 @@
}
}

inputs_desc_.resize(reader.num_inputs);
for (int i = 0; i < reader.num_inputs; ++i) {
std::string name(reader.inputs[i].name);
std::vector<int64_t> shape(reader.inputs[i].shape,
reader.inputs[i].shape + reader.inputs[i].rank);
inputs_desc_[i].name = name;
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
inputs_desc_[i].dtype = ReaderDataTypeToFD(reader.inputs[i].dtype);
}
outputs_desc_.resize(reader.num_outputs);
for (int i = 0; i < reader.num_outputs; ++i) {
std::string name(reader.outputs[i].name);
std::vector<int64_t> shape(
reader.outputs[i].shape,
reader.outputs[i].shape + reader.outputs[i].rank);
outputs_desc_[i].name = name;
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
}
if (option.collect_trt_shape) {
// Set the shape info file.
std::string curr_model_dir = "./";
@@ -284,6 +262,40 @@ bool PaddleBackend::InitFromPaddle(const std::string& model,
}
}
predictor_ = paddle_infer::CreatePredictor(config_);

auto input_names = predictor_->GetInputNames();
auto output_names = predictor_->GetOutputNames();
auto input_dtypes = predictor_->GetInputTypes();
auto output_dtypes = predictor_->GetOutputTypes();
auto input_shapes = predictor_->GetInputTensorShape();
auto output_shapes = predictor_->GetOutputTensorShape();

inputs_desc_.resize(input_names.size());
for (int i = 0; i < input_names.size(); ++i) {
inputs_desc_[i].name = input_names[i];
auto iter = input_shapes.find(inputs_desc_[i].name);
FDASSERT(iter != input_shapes.end(), "Cannot find shape for input %s.",
inputs_desc_[i].name.c_str());
inputs_desc_[i].shape.assign(iter->second.begin(), iter->second.end());
auto iter1 = input_dtypes.find(inputs_desc_[i].name);
FDASSERT(iter1 != input_dtypes.end(), "Cannot find data type for input %s.",
inputs_desc_[i].name.c_str());
inputs_desc_[i].dtype = PaddleDataTypeToFD(iter1->second);
}
outputs_desc_.resize(output_names.size());
for (int i = 0; i < output_names.size(); ++i) {
outputs_desc_[i].name = output_names[i];
auto iter = output_shapes.find(outputs_desc_[i].name);
FDASSERT(iter != output_shapes.end(), "Cannot find shape for output %s.",
outputs_desc_[i].name.c_str());
outputs_desc_[i].shape.assign(iter->second.begin(), iter->second.end());
auto iter1 = output_dtypes.find(outputs_desc_[i].name);
FDASSERT(iter1 != output_dtypes.end(),
"Cannot find data type for output %s.",
outputs_desc_[i].name.c_str());
outputs_desc_[i].dtype = PaddleDataTypeToFD(iter1->second);
}

initialized_ = true;
return true;
}
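
The I/O descriptions are now queried from the created predictor (names, dtypes, and tensor shapes) instead of being parsed by paddle2onnx's PaddleReader. A hedged sketch of inspecting the resulting metadata from Python, assuming Runtime.num_inputs()/get_input_info() are bound as in recent FastDeploy releases:

```python
import fastdeploy as fd

option = fd.RuntimeOption()
option.use_paddle_infer_backend()
option.set_model_path("model.pdmodel", "model.pdiparams")  # placeholder paths

runtime = fd.Runtime(option)
# These descriptions are now filled from predictor_->GetInputNames(),
# GetInputTypes() and GetInputTensorShape() inside InitFromPaddle.
for i in range(runtime.num_inputs()):
    info = runtime.get_input_info(i)
    print(info.name, info.shape, info.dtype)
```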
1 change: 1 addition & 0 deletions scripts/patch_paddle_inference.py
@@ -27,6 +27,7 @@ def process_paddle_inference(paddle_inference_so_file):
"$ORIGIN", "$ORIGIN/../../third_party/install/mkldnn/lib/",
"$ORIGIN/../../third_party/install/mklml/lib/",
"$ORIGIN/../../third_party/install/xpu/lib/",
"$ORIGIN/../../third_party/install/fdmodel/lib/",
"$ORIGIN/../../../tensorrt/lib/"
]
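
For reference, a minimal sketch of how an rpath list like this is typically applied with patchelf; the library filename is a placeholder, and the real script operates on the paddle inference .so passed to it:

```python
import subprocess

rpaths = [
    "$ORIGIN",
    "$ORIGIN/../../third_party/install/fdmodel/lib/",  # the new entry above
]
# --force-rpath writes RPATH rather than RUNPATH, the usual choice when
# relocating bundled third-party libraries.
subprocess.run(
    ["patchelf", "--force-rpath", "--set-rpath", ":".join(rpaths),
     "libpaddle_inference.so"],  # placeholder .so path
    check=True,
)
```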

2 changes: 1 addition & 1 deletion serving/Dockerfile
@@ -33,7 +33,7 @@ ENV PATH=/home/cmake-3.18.6-Linux-x86_64/bin:$PATH


#install triton
ENV TAG=r22.12
ENV TAG=r21.10
RUN git clone https://github.com/triton-inference-server/server.git -b $TAG && \
cd server && \
mkdir -p build/tritonserver/install && \
48 changes: 48 additions & 0 deletions serving/Dockerfile_xpu_encrypt_auth
@@ -0,0 +1,48 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG http_proxy
ARG https_proxy
ARG no_proxy

FROM paddlepaddle/fastdeploy:21.10-cpu-only-min

ENV TZ=Asia/Shanghai \
DEBIAN_FRONTEND=noninteractive \
http_proxy=$http_proxy \
https_proxy=$https_proxy \
no_proxy=$no_proxy

# Note: Here, use the nightly build of paddle for the xpu triton server image
# to avoid .so conflicts between paddle and fastdeploy-python.
RUN apt-get update && apt-get install -y --no-install-recommends apt-utils libgomp1 ffmpeg libsm6 libxext6 vim wget \
&& python3 -m pip install -U pip \
&& python3 -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html \
&& python3 -m pip install paddlenlp fast-tokenizer-python

COPY python/dist/*.whl /opt/fastdeploy/
RUN python3 -m pip install /opt/fastdeploy/*.whl \
&& rm -rf /opt/fastdeploy/*.whl

COPY serving/build/libtriton_fastdeploy.so /opt/tritonserver/backends/fastdeploy/
COPY build/fastdeploy_install /opt/fastdeploy/
# Fix the link error of libbkcl.so
COPY build/third_libs/install/paddle_inference/third_party/install/xpu/lib/libbkcl.so /home/users/yanzikui/wenxin/baidu/xpu/bkcl/output/so/libbkcl.so

RUN mv /opt/tritonserver/bin/tritonserver /opt/tritonserver/bin/fastdeployserver
ENV LD_LIBRARY_PATH="/opt/fastdeploy/lib:/opt/fastdeploy/third_libs/install/opencv/lib64:/opt/fastdeploy/third_libs/install/paddle2onnx/lib:/opt/fastdeploy/third_libs/install/paddle_inference/paddle/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mkldnn/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/mklml/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/xpu/lib:/opt/fastdeploy/third_libs/install/paddle_inference/third_party/install/fdmodel/lib:$LD_LIBRARY_PATH"
# unset proxy
ENV http_proxy=
ENV https_proxy=
ENV no_proxy=
2 changes: 1 addition & 1 deletion serving/scripts/build.sh
@@ -117,7 +117,7 @@ nvidia-docker run -i --rm --name ${docker_name} \
rm -rf build; mkdir build; cd build;
export https_proxy=${https_proxy_tmp}
export http_proxy=${http_proxy_tmp}
cmake .. -DFASTDEPLOY_DIR=/workspace/fastdeploy/build/fastdeploy_install -DTRITON_COMMON_REPO_TAG=r22.12 -DTRITON_CORE_REPO_TAG=r22.12 -DTRITON_BACKEND_REPO_TAG=r22.12;
cmake .. -DFASTDEPLOY_DIR=/workspace/fastdeploy/build/fastdeploy_install -DTRITON_COMMON_REPO_TAG=r21.10 -DTRITON_CORE_REPO_TAG=r21.10 -DTRITON_BACKEND_REPO_TAG=r21.10;
make -j`nproc`'

echo "build FD GPU library done"
60 changes: 60 additions & 0 deletions serving/scripts/build_fd_xpu_encrypt_auth.sh
@@ -0,0 +1,60 @@
#!/usr/bin/env bash
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


echo "start build FD XPU AUTH library"

docker run -i --rm --name build_fd_xpu_auth_dev \
-v `pwd`/..:/workspace/fastdeploy \
-e "http_proxy=${http_proxy}" \
-e "https_proxy=${https_proxy}" \
-e "no_proxy=${no_proxy}" \
-e "PADDLEINFERENCE_URL=${PADDLEINFERENCE_URL}" \
--network=host --privileged \
paddlepaddle/fastdeploy:21.10-cpu-only-buildbase \
bash -c \
'export https_proxy_tmp=${https_proxy}
export http_proxy_tmp=${http_proxy}
cd /workspace/fastdeploy/python;
rm -rf .setuptools-cmake-build dist build fastdeploy/libs/third_libs;
ln -s /usr/bin/python3 /usr/bin/python;
export WITH_GPU=OFF;
export ENABLE_ORT_BACKEND=OFF;
export ENABLE_PADDLE_BACKEND=OFF;
export ENABLE_OPENVINO_BACKEND=OFF;
export ENABLE_VISION=ON;
export ENABLE_TEXT=ON;
unset http_proxy
unset https_proxy
python setup.py build;
python setup.py bdist_wheel;
cd /workspace/fastdeploy;
rm -rf build; mkdir build; cd build;
cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${PWD}/fastdeploy_install -DWITH_KUNLUNXIN=ON -DENABLE_PADDLE_BACKEND=ON -DPADDLEINFERENCE_URL=${PADDLEINFERENCE_URL} -DPADDLEINFERENCE_WITH_ENCRYPT_AUTH=ON -DENABLE_VISION=ON -DENABLE_BENCHMARK=ON -DLIBRARY_NAME=fastdeploy_runtime;
make -j`nproc`;
make install;
# fix the link error of libbkcl.so
mkdir -p /home/users/yanzikui/wenxin/baidu/xpu/bkcl/output/so;
cp /workspace/fastdeploy/build/fastdeploy_install/third_libs/install/paddle_inference/third_party/install/xpu/lib/libbkcl.so /home/users/yanzikui/wenxin/baidu/xpu/bkcl/output/so;
cd /workspace/fastdeploy/serving;
rm -rf build; mkdir build; cd build;
export https_proxy=${https_proxy_tmp}
export http_proxy=${http_proxy_tmp}
cmake .. -DTRITON_ENABLE_GPU=OFF -DFASTDEPLOY_DIR=/workspace/fastdeploy/build/fastdeploy_install -DTRITON_COMMON_REPO_TAG=r21.10 -DTRITON_CORE_REPO_TAG=r21.10 -DTRITON_BACKEND_REPO_TAG=r21.10;
make -j`nproc`;
echo $PADDLEINFERENCE_URL;
'

echo "build FD XPU AUTH library done"
6 changes: 6 additions & 0 deletions serving/src/fastdeploy_runtime.cc
@@ -252,6 +252,12 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
} else if (param_key == "is_clone") {
THROW_IF_BACKEND_MODEL_ERROR(
ParseBoolValue(value_string, &is_clone_));
} else if (param_key == "delete_passes") {
std::vector<std::string> delete_passes;
SplitStringByDelimiter(value_string, ' ', &delete_passes);
for (auto&& pass : delete_passes) {
runtime_options_->paddle_infer_option.DeletePass(pass);
}
} else if (param_key == "encryption_key") {
runtime_options_->SetEncryptionKey(value_string);
// parse common settings for xpu device.
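
The new `delete_passes` parameter takes a space-separated list of IR pass names from the model's config.pbtxt and forwards each one to `DeletePass`, alongside the existing `encryption_key` handling. A rough Python equivalent of what these two parameters do, with a placeholder pass name and key (and assuming `set_encryption_key` is bound in Python the way `SetEncryptionKey` is in C++):

```python
import fastdeploy as fd

option = fd.RuntimeOption()
option.use_paddle_infer_backend()
# Equivalent of the Triton "delete_passes" parameter.
for p in "simplify_with_basic_ops_pass".split(" "):  # placeholder pass name
    option.paddle_infer_option.delete_pass(p)
# Equivalent of the Triton "encryption_key" parameter.
option.set_encryption_key("my-secret-key")  # placeholder key
```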
