Skip to content

Commit

Permalink
ORT backend: support UINT8 and INT8 input and output
Browse files (browse the repository at this point in the history)
  • Loading branch information
yunyaoXYY committed Mar 21, 2023
1 parent: 3cc7276; commit: b3e16e9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
12 changes: 10 additions & 2 deletions fastdeploy/runtime/backends/ort/ort_backend.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -239,7 +239,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
inputs_desc_.emplace_back(OrtValueInfo{input_name_ptr.get(), shape, data_type});
inputs_desc_.emplace_back(
OrtValueInfo{input_name_ptr.get(), shape, data_type});
}

size_t n_outputs = session_.GetOutputCount();
Expand All @@ -250,7 +251,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
outputs_desc_.emplace_back(OrtValueInfo{output_name_ptr.get(), shape, data_type});
outputs_desc_.emplace_back(
OrtValueInfo{output_name_ptr.get(), shape, data_type});

Ort::MemoryInfo out_memory_info("Cpu", OrtDeviceAllocator, 0,
OrtMemTypeDefault);
Expand Down Expand Up @@ -283,6 +285,12 @@ void OrtBackend::OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
dtype = FDDataType::FP16;
numel *= sizeof(float16);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
dtype = FDDataType::UINT8;
numel *= sizeof(uint8_t);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
dtype = FDDataType::INT8;
numel *= sizeof(int8_t);
} else {
FDASSERT(
false,
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/runtime/backends/ort/utils.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,6 +50,10 @@ FDDataType GetFdDtype(const ONNXTensorElementDataType& ort_dtype) {
return FDDataType::INT64;
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
return FDDataType::FP16;
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
return FDDataType::UINT8;
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
return FDDataType::INT8;
}
FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl;
return FDDataType::FP32;
Expand Down

0 comments on commit b3e16e9

Please sign in to comment.