change the inference context api to use TypeProto (onnx#779)
* change the inference context api to use TypeProto instead of tensor type proto.

* refine the change.

* debug version

* fix test failure.
linkerzhang authored Apr 20, 2018
1 parent 6953eff commit 7d1e102
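
For operator authors, the visible effect of this change is that InferenceContext::getInputType and getOutputType now hand back a full TypeProto instead of a TypeProto_Tensor, so an inference function must check the type's value case and go through tensor_type() itself. A minimal before/after sketch; the function name InferExample is illustrative and not part of the commit:

#include "onnx/defs/shape_inference.h"

using namespace ONNX_NAMESPACE;

// Before this commit (hypothetical usage): the context exposed
// TypeProto_Tensor directly, so element types were copied without
// any case check.
//
//   void InferExample(InferenceContext& ctx) {
//     if (ctx.getInputType(0)) {
//       ctx.getOutputType(0)->set_elem_type(ctx.getInputType(0)->elem_type());
//     }
//   }
//
// After this commit: the context exposes TypeProto, so the same function
// has to select the tensor case explicitly.
void InferExample(InferenceContext& ctx) {
  const auto* input_type = ctx.getInputType(0);
  if (input_type == nullptr ||
      input_type->value_case() != TypeProto::kTensorType) {
    return;  // not a tensor input; nothing to infer
  }
  ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(
      input_type->tensor_type().elem_type());
}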
Showing 5 changed files with 338 additions and 239 deletions.
2 changes: 1 addition & 1 deletion onnx/defs/schema.cc
@@ -252,7 +252,7 @@ OpSchema& OpSchema::NumOutputs(std::set<int> allowed_output_nums) {
   return *this;
 }
 
-OpSchema& OpSchema::ShapeInferenceFunction(InferenceFunction inferenceFunction) {
+OpSchema& OpSchema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction) {
   tensor_inference_function_ = inferenceFunction;
   return *this;
 }
4 changes: 2 additions & 2 deletions onnx/defs/schema.h
@@ -200,8 +200,8 @@ class OpSchema final {
   //
   // Note that signatures are defined to allow for forward-declaring
   // any structs used from ir.h
-  OpSchema& ShapeInferenceFunction(InferenceFunction inferenceFunction);
-  InferenceFunction GetShapeInferenceFunction() const {
+  OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
+  InferenceFunction GetTypeAndShapeInferenceFunction() const {
     return tensor_inference_function_;
   }
 
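
On the schema side the hook is only renamed; registration looks the same as before. A sketch of how a definition might wire it up, assuming the ONNX_OPERATOR_SCHEMA registration macro from onnx/defs/schema.h and the helpers from onnx/defs/shape_inference.h shown below; the MyPassThrough op and its inference body are placeholders, not part of this commit:

#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"

using namespace ONNX_NAMESPACE;

// Hypothetical single-input, single-output pass-through operator.
ONNX_OPERATOR_SCHEMA(MyPassThrough)
    .NumInputs(1)
    .NumOutputs(1)
    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
      // Copy the element type, then the shape, from input 0 to output 0.
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      if (hasNInputShapes(ctx, 1)) {
        propagateShapeFromInputToOutput(ctx, 0, 0);
      }
    });

A capture-less lambda converts to the InferenceFunction function pointer, so no other glue is needed.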
80 changes: 61 additions & 19 deletions onnx/defs/shape_inference.h
@@ -1,36 +1,37 @@
 #pragma once
 
-#include "onnx/proto_utils.h"
+#include "onnx/defs/data_type_utils.h"
+#include "onnx/proto_utils.h"
 
 namespace ONNX_NAMESPACE {
 
 struct InferenceContext {
   virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
   virtual size_t getNumInputs() const = 0;
-  virtual const TypeProto_Tensor* getInputType(size_t index) const = 0;
+  virtual const TypeProto* getInputType(size_t index) const = 0;
   virtual size_t getNumOutputs() const = 0;
-  virtual TypeProto_Tensor* getOutputType(size_t index) = 0;
+  virtual TypeProto* getOutputType(size_t index) = 0;
   virtual ~InferenceContext() {}
 };
 
 typedef void (*InferenceFunction)(InferenceContext&);
 
-template<typename T>
-inline bool getRepeatedAttribute(InferenceContext& ctx,
-                                 std::string attr_name,
-                                 std::vector<T>& values) {
+template <typename T>
+inline bool getRepeatedAttribute(
+    InferenceContext& ctx,
+    std::string attr_name,
+    std::vector<T>& values) {
   const auto* attr = ctx.getAttribute(attr_name);
   if (attr) {
     values = RetrieveValues<T>(*attr);
     return true;
   } else {
     return false;
   }
-
 }
 
-inline bool hasExactlyNInputTypes(InferenceContext& ctx, int n, const std::string& opname) {
+inline bool
+hasExactlyNInputTypes(InferenceContext& ctx, int n, const std::string& opname) {
   if (static_cast<int>(ctx.getNumInputs()) != n) {
     throw std::runtime_error(opname + " has wrong number of inputs");
   }
@@ -42,31 +43,72 @@ inline bool hasExactlyNInputTypes(InferenceContext& ctx, int n, const std::string& opname) {
   return true;
 }
 
+inline void propagateElemTypeFromInputToOutput(
+    InferenceContext& ctx,
+    size_t inputIndex,
+    size_t outputIndex) {
+  auto input_type = ctx.getInputType(inputIndex);
+  if (nullptr == input_type ||
+      input_type->value_case() != TypeProto::kTensorType) {
+    return;
+  }
+  auto output_type = ctx.getOutputType(outputIndex);
+  if (output_type->value_case() == TypeProto::kTensorType ||
+      output_type->value_case() == TypeProto::VALUE_NOT_SET) {
+    output_type->mutable_tensor_type()->set_elem_type(
+        input_type->tensor_type().elem_type());
+  }
+}
+
 inline bool hasNInputShapes(InferenceContext& ctx, int n) {
   if (static_cast<int>(ctx.getNumInputs()) < n) {
     throw std::runtime_error("operator has too few inputs");
   }
   for (int i = 0; i < n; i++) {
-    if (!ctx.getInputType(i) || !ctx.getInputType(i)->has_shape()) {
+    auto input_type = ctx.getInputType(i);
+    if (nullptr == input_type || !input_type->has_tensor_type() ||
+        !input_type->tensor_type().has_shape()) {
       return false;
     }
   }
   return true;
 }
 
-inline void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
-  if (ctx.getInputType(inputIndex)) {
-    ctx.getOutputType(outputIndex)->set_elem_type(ctx.getInputType(inputIndex)->elem_type());
-  }
-}
-
-inline void appendSingleDimCopiedFromInputTypeToOutputType(InferenceContext& ctx, size_t inputIndex, size_t outputIndex, size_t fromDimIndex) {
-  auto* dim = ctx.getOutputType(outputIndex)->mutable_shape()->add_dim();
-  *dim = ctx.getInputType(inputIndex)->shape().dim(static_cast<int>(fromDimIndex));
-}
+inline void appendSingleDimCopiedFromInputTypeToOutputType(
+    InferenceContext& ctx,
+    size_t inputIndex,
+    size_t outputIndex,
+    size_t fromDimIndex) {
+  auto output_type = ctx.getOutputType(outputIndex);
+  auto input_type = ctx.getInputType(inputIndex);
+  if (TypeProto::kTensorType != output_type->value_case() ||
+      TypeProto::kTensorType != input_type->value_case()) {
+    return;
+  }
+  auto* dim = ctx.getOutputType(outputIndex)
+                  ->mutable_tensor_type()
+                  ->mutable_shape()
+                  ->add_dim();
+  *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
+}
 
-inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
-  *ctx.getOutputType(outputIndex)->mutable_shape() = ctx.getInputType(inputIndex)->shape();
+inline void propagateShapeFromInputToOutput(
+    InferenceContext& ctx,
+    size_t inputIndex,
+    size_t outputIndex) {
+  auto output_type = ctx.getOutputType(outputIndex);
+  auto input_type = ctx.getInputType(inputIndex);
+  if (TypeProto::kTensorType != input_type->value_case() ||
+      TypeProto::kTensorType != output_type->value_case()) {
+    throw std::runtime_error(
+        "zhangke: " +
+        std::to_string(
+            ctx.getInputType(inputIndex)->tensor_type().shape().dim_size()));
+    return;
+  }
+
+  *ctx.getOutputType(outputIndex)->mutable_tensor_type()->mutable_shape() =
+      ctx.getInputType(inputIndex)->tensor_type().shape();
 }
 
 } // namespace ONNX_NAMESPACE
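
To see how the TypeProto-based helpers compose, here is a small sketch of an inference function for a hypothetical strictly 2-D matmul-style op (output shape [M, N] from inputs [M, K] and [K, N]); the op is illustrative only and not part of this commit:

#include "onnx/defs/shape_inference.h"

using namespace ONNX_NAMESPACE;

// Hypothetical type-and-shape inference for a 2-D matmul-like op.
void InferMatMul2D(InferenceContext& ctx) {
  // Output element type follows input 0.
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  // Bail out unless both inputs carry tensor types with known shapes.
  if (!hasNInputShapes(ctx, 2)) {
    return;
  }
  // Output dim 0 is copied from input 0 dim 0, output dim 1 from input 1 dim 1.
  appendSingleDimCopiedFromInputTypeToOutputType(ctx, 0, 0, 0);
  appendSingleDimCopiedFromInputTypeToOutputType(ctx, 1, 0, 1);
}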