diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc
index 55d692a8e3a..b0b90fab54d 100644
--- a/onnx/defs/schema.cc
+++ b/onnx/defs/schema.cc
@@ -252,7 +252,7 @@ OpSchema& OpSchema::NumOutputs(std::set<int> allowed_output_nums) {
   return *this;
 }
 
-OpSchema& OpSchema::ShapeInferenceFunction(InferenceFunction inferenceFunction) {
+OpSchema& OpSchema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction) {
   tensor_inference_function_ = inferenceFunction;
   return *this;
 }
diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h
index 4b1b317ac1b..5d34daac559 100644
--- a/onnx/defs/schema.h
+++ b/onnx/defs/schema.h
@@ -200,8 +200,8 @@ class OpSchema final {
   //
   // Note that signatures are defined to allow for forward-declaring
   // any structs used from ir.h
-  OpSchema& ShapeInferenceFunction(InferenceFunction inferenceFunction);
-  InferenceFunction GetShapeInferenceFunction() const {
+  OpSchema& TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
+  InferenceFunction GetTypeAndShapeInferenceFunction() const {
     return tensor_inference_function_;
   }
 
diff --git a/onnx/defs/shape_inference.h b/onnx/defs/shape_inference.h
index 03e1308f918..904f087d515 100644
--- a/onnx/defs/shape_inference.h
+++ b/onnx/defs/shape_inference.h
@@ -1,25 +1,26 @@
 #pragma once
 
-#include "onnx/proto_utils.h"
 #include "onnx/defs/data_type_utils.h"
+#include "onnx/proto_utils.h"
 
 namespace ONNX_NAMESPACE {
 
 struct InferenceContext {
   virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
   virtual size_t getNumInputs() const = 0;
-  virtual const TypeProto_Tensor* getInputType(size_t index) const = 0;
+  virtual const TypeProto* getInputType(size_t index) const = 0;
   virtual size_t getNumOutputs() const = 0;
-  virtual TypeProto_Tensor* getOutputType(size_t index) = 0;
+  virtual TypeProto* getOutputType(size_t index) = 0;
   virtual ~InferenceContext() {}
 };
 
 typedef void (*InferenceFunction)(InferenceContext&);
 
-template <typename T>
-inline bool getRepeatedAttribute(InferenceContext& ctx,
-                                 std::string attr_name,
-                                 std::vector<T>& values) {
+template <typename T>
+inline bool getRepeatedAttribute(
+    InferenceContext& ctx,
+    std::string attr_name,
+    std::vector<T>& values) {
   const auto* attr = ctx.getAttribute(attr_name);
   if (attr) {
     values = RetrieveValues<T>(*attr);
@@ -27,10 +28,10 @@ inline bool getRepeatedAttribute(InferenceContext& ctx,
   } else {
     return false;
   }
-
 }
 
-inline bool hasExactlyNInputTypes(InferenceContext& ctx, int n, const std::string& opname) {
+inline bool
+hasExactlyNInputTypes(InferenceContext& ctx, int n, const std::string& opname) {
   if (static_cast<int>(ctx.getNumInputs()) != n) {
     throw std::runtime_error(opname + " has wrong number of inputs");
   }
@@ -42,31 +43,72 @@ inline bool hasExactlyNInputTypes(InferenceContext& ctx, int n, const std::strin
   return true;
 }
 
+inline void propagateElemTypeFromInputToOutput(
+    InferenceContext& ctx,
+    size_t inputIndex,
+    size_t outputIndex) {
+  auto input_type = ctx.getInputType(inputIndex);
+  if (nullptr == input_type ||
+      input_type->value_case() != TypeProto::kTensorType) {
+    return;
+  }
+  auto output_type = ctx.getOutputType(outputIndex);
+  if (output_type->value_case() == TypeProto::kTensorType ||
+      output_type->value_case() == TypeProto::VALUE_NOT_SET) {
+    output_type->mutable_tensor_type()->set_elem_type(
+        input_type->tensor_type().elem_type());
+  }
+}
+
 inline bool hasNInputShapes(InferenceContext& ctx, int n) {
   if (static_cast<int>(ctx.getNumInputs()) < n) {
     throw std::runtime_error("operator has too few inputs");
   }
   for (int i = 0; i < n; i++) {
-    if (!ctx.getInputType(i) || !ctx.getInputType(i)->has_shape()) {
+    auto input_type = ctx.getInputType(i);
+    if (nullptr == input_type || !input_type->has_tensor_type() ||
+        !input_type->tensor_type().has_shape()) {
       return false;
     }
   }
   return true;
 }
 
-inline void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
-  if (ctx.getInputType(inputIndex)) {
-    ctx.getOutputType(outputIndex)->set_elem_type(ctx.getInputType(inputIndex)->elem_type());
+inline void appendSingleDimCopiedFromInputTypeToOutputType(
+    InferenceContext& ctx,
+    size_t inputIndex,
+    size_t outputIndex,
+    size_t fromDimIndex) {
+  auto output_type = ctx.getOutputType(outputIndex);
+  auto input_type = ctx.getInputType(inputIndex);
+  if (TypeProto::kTensorType != output_type->value_case() ||
+      TypeProto::kTensorType != input_type->value_case()) {
+    return;
   }
+  auto* dim = ctx.getOutputType(outputIndex)
+                  ->mutable_tensor_type()
+                  ->mutable_shape()
+                  ->add_dim();
+  *dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
 }
 
-inline void appendSingleDimCopiedFromInputTypeToOutputType(InferenceContext& ctx, size_t inputIndex, size_t outputIndex, size_t fromDimIndex) {
-  auto* dim = ctx.getOutputType(outputIndex)->mutable_shape()->add_dim();
-  *dim = ctx.getInputType(inputIndex)->shape().dim(static_cast<int>(fromDimIndex));
-}
+inline void propagateShapeFromInputToOutput(
+    InferenceContext& ctx,
+    size_t inputIndex,
+    size_t outputIndex) {
+  auto output_type = ctx.getOutputType(outputIndex);
+  auto input_type = ctx.getInputType(inputIndex);
+  if (TypeProto::kTensorType != input_type->value_case() ||
+      TypeProto::kTensorType != output_type->value_case()) {
+    return;
+  }
 
-inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
-  *ctx.getOutputType(outputIndex)->mutable_shape() = ctx.getInputType(inputIndex)->shape();
+  *ctx.getOutputType(outputIndex)->mutable_tensor_type()->mutable_shape() =
+      ctx.getInputType(inputIndex)->tensor_type().shape();
 }
 
 } // namespace ONNX_NAMESPACE
diff --git a/onnx/defs/tensor/defs.cc b/onnx/defs/tensor/defs.cc
index d29572a9a77..fd55dbf1273 100644
--- a/onnx/defs/tensor/defs.cc
+++ b/onnx/defs/tensor/defs.cc
@@ -57,14 +57,13 @@ NOTE: Casting to and from strings is not supported yet.
         "tensor(uint64)",
         "tensor(bool)"},
        "Constrain output types. Casting to strings and complex are not supported.")
-    .ShapeInferenceFunction([](InferenceContext& ctx) {
-      ctx.getOutputType(0)->set_elem_type(
+    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+      ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(
           static_cast<TensorProto_DataType>(ctx.getAttribute("to")->i()));
-
-      if (!hasNInputShapes(ctx, 1)) {
-        return;
-      }
-      propagateShapeFromInputToOutput(ctx, 0, 0);
+      if (!hasNInputShapes(ctx, 1)) {
+        return;
+      }
+      propagateShapeFromInputToOutput(ctx, 0, 0);
     });
 
 ONNX_OPERATOR_SCHEMA(Reshape)
@@ -84,8 +83,8 @@ from the input tensor).)DOC")
         "T",
         {"tensor(float16)", "tensor(float)", "tensor(double)"},
         "Constrain input and output types to float tensors.")
-    .ShapeInferenceFunction([](InferenceContext& ctx) {
-      propagateElemTypeFromInputToOutput(ctx, 0, 0);
+    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+      propagateElemTypeFromInputToOutput(ctx, 0, 0);
     });
 
 ONNX_OPERATOR_SCHEMA(Shape)
@@ -111,18 +110,23 @@ Takes a tensor as input and outputs an 1D int64 tensor containing the shape of t
         "T1",
         {"tensor(int64)"},
         "Constrains output to int64 tensor.")
-    .ShapeInferenceFunction([](InferenceContext& ctx) {
-      ctx.getOutputType(0)->set_elem_type(TensorProto::INT64);
-
-      if (!hasNInputShapes(ctx, 1)) {
-        return;
-      }
-
-      if (ctx.getInputType(0)->has_shape()) {
-        ctx.getOutputType(0)->mutable_shape()->add_dim()->set_dim_value(
-            ctx.getInputType(0)->shape().dim_size());
-      }
-    });
+    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+      ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(
+          TensorProto::INT64);
+
+      if (!hasNInputShapes(ctx, 1)) {
+        return;
+      }
+
+      if (ctx.getInputType(0)->tensor_type().has_shape()) {
+        ctx.getOutputType(0)
+            ->mutable_tensor_type()
+            ->mutable_shape()
+            ->add_dim()
+            ->set_dim_value(
+                ctx.getInputType(0)->tensor_type().shape().dim_size());
+      }
+    });
 
 ONNX_OPERATOR_SCHEMA(Size)
     .SetDoc(R"DOC(
@@ -147,9 +151,10 @@ Takes a tensor as input and outputs a int64 scalar that equals to the total numb
         "T1",
         {"tensor(int64)"},
         "Constrains output to int64 tensor, which should be a scalar though.")
-    .ShapeInferenceFunction([](InferenceContext& ctx) {
-      ctx.getOutputType(0)->set_elem_type(TensorProto::INT64);
-      ctx.getOutputType(0)->mutable_shape();
+    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+      ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(
+          TensorProto::INT64);
+      ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
     });
 
 ONNX_OPERATOR_SCHEMA(Concat)
@@ -167,66 +172,68 @@ ONNX_OPERATOR_SCHEMA(Concat)
         "T",
         {"tensor(float16)", "tensor(float)", "tensor(double)"},
         "Constrain output types to float tensors.")
-    .ShapeInferenceFunction([](InferenceContext& ctx) {
-      propagateElemTypeFromInputToOutput(ctx, 0, 0);
-      if (!hasNInputShapes(ctx, 1)) {
+    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+      propagateElemTypeFromInputToOutput(ctx, 0, 0);
+      if (!hasNInputShapes(ctx, 1)) {
+        return;
+      }
+
+      auto axisAttr = ctx.getAttribute("axis");
+      if (!axisAttr) {
+        return;
+      }
+      int axis = static_cast<int>(axisAttr->i());
+
+      bool found_exemplar = false;
+      TensorShapeProto shape_exemplar;
+      bool all_lengths_known = true;
+      int total_length = 0;
+
+      for (size_t i = 0; i < ctx.getNumInputs(); i++) {
+        if (!ctx.getInputType(i)->tensor_type().has_shape()) {
          return;
        }
-
-        auto axisAttr = ctx.getAttribute("axis");
-        if (!axisAttr) {
-          return;
-        }
-        int axis = static_cast<int>(axisAttr->i());
-
-        bool found_exemplar = false;
-        TensorShapeProto shape_exemplar;
-        bool all_lengths_known = true;
-        int total_length = 0;
-
-        for (size_t i = 0; i < ctx.getNumInputs(); i++) {
-          if (!ctx.getInputType(i)->has_shape()) {
-            return;
-          }
-          const auto& shape = ctx.getInputType(i)->shape();
-          if (found_exemplar) {
-            for (int j = 0; j < shape.dim_size(); j++) {
-              if (j == axis) {
-                if (shape.dim(j).has_dim_value()) {
-                  total_length += static_cast<int>(shape.dim(j).dim_value());
-                } else {
-                  all_lengths_known = false;
-                }
+        const auto& shape = ctx.getInputType(i)->tensor_type().shape();
+        if (found_exemplar) {
+          for (int j = 0; j < shape.dim_size(); j++) {
+            if (j == axis) {
+              if (shape.dim(j).has_dim_value()) {
+                total_length += static_cast<int>(shape.dim(j).dim_value());
              } else {
-                if (shape.dim(j).has_dim_value() &&
-                    shape_exemplar.dim(j).has_dim_value() &&
-                    shape.dim(j).dim_value() !=
-                        shape_exemplar.dim(j).dim_value()) {
-                  return;
-                }
+                all_lengths_known = false;
+              }
+            } else {
+              if (shape.dim(j).has_dim_value() &&
+                  shape_exemplar.dim(j).has_dim_value() &&
+                  shape.dim(j).dim_value() !=
+                      shape_exemplar.dim(j).dim_value()) {
+                return;
              }
            }
-          } else {
-            shape_exemplar = shape;
-            found_exemplar = true;
          }
-        }
-
-        if (!found_exemplar) {
-          return;
-        }
-
-        if (all_lengths_known) {
-          shape_exemplar.mutable_dim(axis)->set_dim_value(total_length);
        } else {
-          shape_exemplar.mutable_dim(axis)->set_dim_param("");
-        }
-
-        for (int i = 0; i < shape_exemplar.dim_size(); i++) {
-          *ctx.getOutputType(0)->mutable_shape()->add_dim() =
-              shape_exemplar.dim(i);
+          shape_exemplar = shape;
+          found_exemplar = true;
        }
-      });
+      }
+
+      if (!found_exemplar) {
+        return;
+      }
+
+      if (all_lengths_known) {
+        shape_exemplar.mutable_dim(axis)->set_dim_value(total_length);
+      } else {
+        shape_exemplar.mutable_dim(axis)->set_dim_param("");
+      }
+
+      for (int i = 0; i < shape_exemplar.dim_size(); i++) {
+        *ctx.getOutputType(0)
+            ->mutable_tensor_type()
+            ->mutable_shape()
+            ->add_dim() = shape_exemplar.dim(i);
+      }
+    });
 
 ONNX_OPERATOR_SCHEMA(Split)
     .SinceVersion(2)
@@ -251,36 +258,44 @@ ONNX_OPERATOR_SCHEMA(Split)
 'axis'. Lengths of the parts can be specified using argument 'split'.
 Otherwise, the tensor is split to equal sized parts.
 )DOC")
-    .ShapeInferenceFunction([](InferenceContext& ctx) {
-      propagateElemTypeFromInputToOutput(ctx, 0, 0);
-      if (!hasNInputShapes(ctx, 1)) {
+    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+      propagateElemTypeFromInputToOutput(ctx, 0, 0);
+
+      if (!hasNInputShapes(ctx, 1)) {
+        return;
+      }
+
+      auto axisAttr = ctx.getAttribute("axis");
+      int axis = axisAttr ? static_cast<int>(axisAttr->i()) : 0;
+      std::vector<int64_t> split;
+      if (!getRepeatedAttribute(ctx, "split", split)) {
+        if (!ctx.getInputType(0)->tensor_type().has_shape()) {
          return;
        }
-
-        auto axisAttr = ctx.getAttribute("axis");
-        int axis = axisAttr ? static_cast<int>(axisAttr->i()) : 0;
-        std::vector<int64_t> split;
-        if (!getRepeatedAttribute(ctx, "split", split)) {
-          if (!ctx.getInputType(0)->has_shape()) {
-            return;
-          }
-          const auto& splitDim = ctx.getInputType(0)->shape().dim(axis);
-          if (!splitDim.has_dim_value()) {
-            return;
-          }
-          int splitDimValue = static_cast<int>(splitDim.dim_value());
-          int chunkSize = splitDimValue / static_cast<int>(ctx.getNumOutputs());
-          int leftOver = splitDimValue - (chunkSize * static_cast<int>(ctx.getNumOutputs()));
-          for (int i = 0; i < static_cast<int>(ctx.getNumOutputs()); i++) {
-            split.push_back(i < leftOver ? chunkSize + 1 : chunkSize);
-          }
+        const auto& splitDim =
+            ctx.getInputType(0)->tensor_type().shape().dim(axis);
+        if (!splitDim.has_dim_value()) {
+          return;
+        }
+        int splitDimValue = static_cast<int>(splitDim.dim_value());
+        int chunkSize = splitDimValue / static_cast<int>(ctx.getNumOutputs());
+        int leftOver =
+            splitDimValue - (chunkSize * static_cast<int>(ctx.getNumOutputs()));
+        for (int i = 0; i < static_cast<int>(ctx.getNumOutputs()); i++) {
+          split.push_back(i < leftOver ? chunkSize + 1 : chunkSize);
        }
        for (size_t i = 0; i < ctx.getNumOutputs(); i++) {
-          *ctx.getOutputType(i)->mutable_shape() = ctx.getInputType(0)->shape();
-          ctx.getOutputType(i)->mutable_shape()->mutable_dim(axis)->set_dim_value(split[i]);
+          *ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape() =
+              ctx.getInputType(0)->tensor_type().shape();
+          ctx.getOutputType(i)
+              ->mutable_tensor_type()
+              ->mutable_shape()
+              ->mutable_dim(axis)
+              ->set_dim_value(split[i]);
        }
-      });
+      }
+    });
 
 ONNX_OPERATOR_SCHEMA(Slice)
     .SetDoc(R"DOC(
@@ -356,24 +371,27 @@ will be (2, 1, 3).
         "T",
         {"tensor(float16)", "tensor(float)", "tensor(double)"},
         "Constrain input and output types to float tensors.")
-    .ShapeInferenceFunction([](InferenceContext& ctx) {
-      propagateElemTypeFromInputToOutput(ctx, 0, 0);
-      if (!hasNInputShapes(ctx, 1)) {
-        return;
+    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+      propagateElemTypeFromInputToOutput(ctx, 0, 0);
+      if (!hasNInputShapes(ctx, 1)) {
+        return;
+      }
+
+      std::vector<int64_t> perm;
+      if (!getRepeatedAttribute(ctx, "perm", perm)) {
+        for (int i = ctx.getInputType(0)->tensor_type().shape().dim_size() - 1;
+             i >= 0;
+             --i) {
+          perm.push_back(i);
        }
+      }
 
-        std::vector<int64_t> perm;
-        if (!getRepeatedAttribute(ctx, "perm", perm)) {
-          for (int i = ctx.getInputType(0)->shape().dim_size() - 1; i >= 0; --i) {
-            perm.push_back(i);
-          }
-        }
-
-        propagateElemTypeFromInputToOutput(ctx, 0, 0);
-        for (size_t i = 0; i < perm.size(); ++i) {
-          appendSingleDimCopiedFromInputTypeToOutputType(ctx, 0, 0, static_cast<size_t>(perm[i]));
-        }
-      });
+      propagateElemTypeFromInputToOutput(ctx, 0, 0);
+      for (size_t i = 0; i < perm.size(); ++i) {
+        appendSingleDimCopiedFromInputTypeToOutputType(
+            ctx, 0, 0, static_cast<size_t>(perm[i]));
+      }
+    });
 
 ONNX_OPERATOR_SCHEMA(Gather)
     .SetDoc(R"DOC(
@@ -439,24 +457,24 @@ Example 2:
         "Tind",
         {"tensor(int32)", "tensor(int64)"},
         "Constrain indices to integer types")
-    .ShapeInferenceFunction([](InferenceContext& ctx) {
-      propagateElemTypeFromInputToOutput(ctx, 0, 0);
-      if (!hasNInputShapes(ctx, 2)) {
-        return;
-      }
-
-      int r = ctx.getInputType(0)->shape().dim_size();
-      int q = ctx.getInputType(1)->shape().dim_size();
-
-      int out_rank = q + r - 1;
-
-      if (out_rank == 0) {
-        ctx.getOutputType(0)->mutable_shape();
-      }
-      for (int i = 0; i < out_rank; ++i) {
-        ctx.getOutputType(0)->mutable_shape()->add_dim();
-      }
-    });
+    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
+      propagateElemTypeFromInputToOutput(ctx, 0, 0);
+      if (!hasNInputShapes(ctx, 2)) {
+        return;
+      }
+
+      int r = ctx.getInputType(0)->tensor_type().shape().dim_size();
+      int q = ctx.getInputType(1)->tensor_type().shape().dim_size();
+
+      int out_rank = q + r - 1;
+
+      if (out_rank == 0) {
+        ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
+      }
+      for (int i = 0; i < out_rank; ++i) {
+        ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim();
+      }
+    });
 
 ONNX_OPERATOR_SCHEMA(Squeeze)
     .Attr(
@@ -473,31 +491,36 @@ Takes a parameter `axes` with a list of axes to squeeze.
"T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors.") - .ShapeInferenceFunction([](InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (!hasNInputShapes(ctx, 1)) { - return; - } - - std::vector axes; - if (!getRepeatedAttribute(ctx, "axes", axes)) { - return; - } - - if (!ctx.getInputType(0)->has_shape()) { - return; - } - - ctx.getOutputType(0)->mutable_shape(); - - for (int i = 0, j = 0; i < ctx.getInputType(0)->shape().dim_size(); ++i) { - if (static_cast(j) < axes.size() && axes[j] == i) { - ++j; - } else { - *ctx.getOutputType(0)->mutable_shape()->add_dim() = ctx.getInputType(0)->shape().dim(i); - } + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 1)) { + return; + } + + std::vector axes; + if (!getRepeatedAttribute(ctx, "axes", axes)) { + return; + } + + if (!ctx.getInputType(0)->tensor_type().has_shape()) { + return; + } + + ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + for (int i = 0, j = 0; + i < ctx.getInputType(0)->tensor_type().shape().dim_size(); + ++i) { + if (static_cast(j) < axes.size() && axes[j] == i) { + ++j; + } else { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = ctx.getInputType(0)->tensor_type().shape().dim(i); } - }); + } + }); ONNX_OPERATOR_SCHEMA(Unsqueeze) .Attr( @@ -517,37 +540,53 @@ Dimension indices in `axes` are as seen in the output tensor. For example: "T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors.") - .ShapeInferenceFunction([](InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (!hasNInputShapes(ctx, 1)) { - return; - } - - std::vector axes; - if (!getRepeatedAttribute(ctx, "axes", axes)) { - return; - } - std::sort(axes.begin(), axes.end()); - - if (!ctx.getInputType(0)->has_shape()) { - return; - } - - ctx.getOutputType(0)->mutable_shape(); - - int j = 0; - for (int i = 0; i < ctx.getInputType(0)->shape().dim_size(); ++i) { - while (static_cast(j) < axes.size() && axes[j] == ctx.getOutputType(0)->shape().dim_size()) { - ctx.getOutputType(0)->mutable_shape()->add_dim()->set_dim_value(1); - ++j; - } - *ctx.getOutputType(0)->mutable_shape()->add_dim() = ctx.getInputType(0)->shape().dim(i); - } - while (static_cast(j) < axes.size() && axes[j] == ctx.getOutputType(0)->shape().dim_size()) { - ctx.getOutputType(0)->mutable_shape()->add_dim()->set_dim_value(1); + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 1)) { + return; + } + + std::vector axes; + if (!getRepeatedAttribute(ctx, "axes", axes)) { + return; + } + std::sort(axes.begin(), axes.end()); + + if (!ctx.getInputType(0)->tensor_type().has_shape()) { + return; + } + + ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); + + int j = 0; + for (int i = 0; i < ctx.getInputType(0)->tensor_type().shape().dim_size(); + ++i) { + while (static_cast(j) < axes.size() && + axes[j] == + ctx.getOutputType(0)->tensor_type().shape().dim_size()) { + ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() + ->set_dim_value(1); ++j; } - }); + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = ctx.getInputType(0)->tensor_type().shape().dim(i); + } + while (static_cast(j) < axes.size() && + axes[j] == + 
+                 ctx.getOutputType(0)->tensor_type().shape().dim_size()) {
+        ctx.getOutputType(0)
+            ->mutable_tensor_type()
+            ->mutable_shape()
+            ->add_dim()
+            ->set_dim_value(1);
+        ++j;
+      }
+    });
 
 ONNX_OPERATOR_SCHEMA(Pad)
     .SinceVersion(2)
diff --git a/onnx/shape_inference/implementation.h b/onnx/shape_inference/implementation.h
index 4b7c8279a97..e84db4dd24c 100644
--- a/onnx/shape_inference/implementation.h
+++ b/onnx/shape_inference/implementation.h
@@ -1,13 +1,15 @@
 #pragma once
 
-#include "onnx/proto_utils.h"
 #include "onnx/defs/schema.h"
+#include "onnx/proto_utils.h"
 
-namespace ONNX_NAMESPACE { namespace shape_inference {
+namespace ONNX_NAMESPACE {
+namespace shape_inference {
 
 struct InferenceContextImpl : public InferenceContext {
-  InferenceContextImpl(const NodeProto & n,
-                       const std::unordered_map<std::string, TypeProto_Tensor*>& valueTypesByName) {
+  InferenceContextImpl(
+      const NodeProto& n,
+      const std::unordered_map<std::string, TypeProto*>& valueTypesByName) {
     for (const auto& attr : n.attribute()) {
       attributesByName_[attr.name()] = &attr;
     }
@@ -35,46 +37,54 @@ struct InferenceContextImpl : public InferenceContext {
   size_t getNumInputs() const override {
     return allInputTypes_.size();
   }
-  const TypeProto_Tensor* getInputType(size_t index) const override {
+
+  const TypeProto* getInputType(size_t index) const override {
     if (index >= allInputTypes_.size()) {
-      throw std::runtime_error("input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds");
+      throw std::runtime_error(
+          "input " + ONNX_NAMESPACE::to_string(index) + " is out of bounds");
     }
     return allInputTypes_[index];
   }
+
   size_t getNumOutputs() const override {
     return allOutputTypes_.size();
   }
-  TypeProto_Tensor* getOutputType(size_t index) override {
+
+  TypeProto* getOutputType(size_t index) override {
     if (index >= allOutputTypes_.size()) {
-      throw std::runtime_error("output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds");
+      throw std::runtime_error(
+          "output " + ONNX_NAMESPACE::to_string(index) + " is out of bounds");
     }
     return &allOutputTypes_[index];
   }
-  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
-  std::vector<const TypeProto_Tensor*> allInputTypes_;
-  std::vector<TypeProto_Tensor> allOutputTypes_;
+  std::unordered_map<std::string, const AttributeProto*> attributesByName_;
+  std::vector<const TypeProto*> allInputTypes_;
+  std::vector<TypeProto> allOutputTypes_;
 };
 
- void mergeShapesAndTypes(const TypeProto_Tensor& inferredType, TypeProto_Tensor* existingType, const std::string& output) {
-}
+void mergeShapesAndTypes(
+    const TypeProto_Tensor& inferredType,
+    TypeProto_Tensor* existingType,
+    const std::string& output) {}
 
 void InferShapes(ModelProto& m) {
   std::unordered_map<std::string, int> opset_imports;
   for (const auto& opset_import : m.opset_import()) {
-    opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
+    opset_imports[opset_import.domain()] =
+        static_cast<int>(opset_import.version());
   }
 
   auto* g = m.mutable_graph();
 
-  std::unordered_map<std::string, TypeProto_Tensor*> valueTypesByName;
+  std::unordered_map<std::string, TypeProto*> valueTypesByName;
   for (auto& vi : *g->mutable_value_info()) {
-    valueTypesByName[vi.name()] = vi.mutable_type()->mutable_tensor_type();
+    valueTypesByName[vi.name()] = vi.mutable_type();
   }
   for (auto& vi : *g->mutable_input()) {
-    valueTypesByName[vi.name()] = vi.mutable_type()->mutable_tensor_type();
+    valueTypesByName[vi.name()] = vi.mutable_type();
   }
   for (auto& vi : *g->mutable_output()) {
-    valueTypesByName[vi.name()] = vi.mutable_type()->mutable_tensor_type();
+    valueTypesByName[vi.name()] = vi.mutable_type();
   }
 
   for (const auto& n : g->node()) {
@@ -85,18 +95,18 @@ void InferShapes(ModelProto& m) {
    }
    auto domain_version = dit->second;
 
-    const auto schema = OpSchemaRegistry::Schema(n.op_type(), domain_version, n.domain());
+    const auto schema =
+        OpSchemaRegistry::Schema(n.op_type(), domain_version, n.domain());
    if (!schema) {
      continue;
    }
 
    InferenceContextImpl ctx(n, valueTypesByName);
-
-    schema->GetShapeInferenceFunction()(ctx);
+    schema->GetTypeAndShapeInferenceFunction()(ctx);
 
    for (int i = 0; i < n.output_size(); ++i) {
      const auto& output = n.output(i);
-      const auto& inferredType = *ctx.getOutputType(i);
+      const auto& inferredType = ctx.getOutputType(i)->tensor_type();
 
      // In this case, we have no new information, so don't bother
      // to add a contentless ValueInfo.
@@ -108,13 +118,16 @@ void InferShapes(ModelProto& m) {
      // If there is already a ValueInfo associated with this
      // output, reuse it. Otherwise add a new one.
      auto iter = valueTypesByName.find(output);
+      TypeProto* type_proto = nullptr;
      TypeProto_Tensor* existingType = nullptr;
      if (iter != valueTypesByName.end()) {
-        existingType = iter->second;
+        type_proto = iter->second;
+        existingType = type_proto->mutable_tensor_type();
      } else {
        auto vi = g->add_value_info();
        vi->set_name(output);
-        existingType = vi->mutable_type()->mutable_tensor_type();
+        type_proto = vi->mutable_type();
+        existingType = type_proto->mutable_tensor_type();
      }
 
      // Incorporate the inferred information.
@@ -129,8 +142,10 @@ void InferShapes(ModelProto& m) {
 
      if (inferredType.has_shape()) {
        if (existingType->has_shape()) {
-          if (inferredType.shape().dim_size() != existingType->shape().dim_size()) {
-            throw std::runtime_error("inferred type and existing type are of different rank");
+          if (inferredType.shape().dim_size() !=
+              existingType->shape().dim_size()) {
+            throw std::runtime_error(
+                "inferred type and existing type are of different rank");
          }
        } else {
          // make sure has_shape() == True for scalars
@@ -146,8 +161,10 @@ void InferShapes(ModelProto& m) {
          auto* existingDim = existingType->mutable_shape()->mutable_dim(j);
          if (inferredDim.has_dim_value()) {
            auto inferredDimValue = inferredDim.dim_value();
-            if (existingDim->has_dim_value() && existingDim->dim_value() != inferredDimValue) {
-              throw std::runtime_error("inferred dimension differs from existing dimension");
+            if (existingDim->has_dim_value() &&
+                existingDim->dim_value() != inferredDimValue) {
+              throw std::runtime_error(
+                  "inferred dimension differs from existing dimension");
            }
            *existingDim = inferredDim;
          }
@@ -155,9 +172,10 @@ void InferShapes(ModelProto& m) {
      }
 
      // Make it available to futher inference.
-      valueTypesByName[output] = existingType;
+      valueTypesByName[output] = type_proto;
    }
  }
 }
-}}
+} // namespace shape_inference
+} // namespace ONNX_NAMESPACE
\ No newline at end of file
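Illustrative usage (not part of the patch above): after this change, a schema registers inference through the renamed TypeAndShapeInferenceFunction entry point and the helpers now operate on TypeProto rather than TypeProto_Tensor. The sketch below simply mirrors the pattern the Cast and Reshape schemas use in this diff; the "Identity" operator name, doc string, and Input/Output descriptions are placeholder choices for illustration, not something this diff adds.

// Sketch only: wiring up type-and-shape inference in an operator schema,
// assuming onnx/defs/schema.h and onnx/defs/shape_inference.h as modified above.
ONNX_OPERATOR_SCHEMA(Identity)
    .SetDoc("Illustrative pass-through operator.")
    .Input(0, "input", "Input tensor.", "T")
    .Output(0, "output", "Output tensor.", "T")
    .TypeConstraint(
        "T",
        {"tensor(float16)", "tensor(float)", "tensor(double)"},
        "Constrain input and output types to float tensors.")
    .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
      // Copies elem_type to the output when the input carries a tensor type;
      // a no-op for other or unset TypeProto variants.
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      // Only propagate the shape when the input actually has one.
      if (!hasNInputShapes(ctx, 1)) {
        return;
      }
      propagateShapeFromInputToOutput(ctx, 0, 0);
    });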