
Kezhan/update size op output type (onnx#759)
* change size output type from int64 to tensor(int64)

* fix size output type to be tensor type.
linkerzhang authored Apr 14, 2018
1 parent 5355440 commit 97d3ae6
Showing 3 changed files with 147 additions and 137 deletions.
docs/Changelog.md (4 changes: 2 additions & 2 deletions)
@@ -4368,8 +4368,8 @@ This version of the operator has been available since version 1 of the default ONNX operator set.
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(bool)</dt>
<dd>Input tensor can be of arbitrary type.</dd>
-<dt><tt>T1</tt> : int64</dt>
-<dd>Constrains output to int64 scalar.</dd>
+<dt><tt>T1</tt> : tensor(int64)</dt>
+<dd>Constrains output to int64 tensor, which should be a scalar though.</dd>
</dl>

### <a name="Slice-1"></a>**Slice-1**</a>
docs/Operators.md (4 changes: 2 additions & 2 deletions)
@@ -6362,8 +6362,8 @@ This version of the operator has been available since version 1 of the default ONNX operator set.
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(bool)</dt>
<dd>Input tensor can be of arbitrary type.</dd>
-<dt><tt>T1</tt> : int64</dt>
-<dd>Constrains output to int64 scalar.</dd>
+<dt><tt>T1</tt> : tensor(int64)</dt>
+<dd>Constrains output to int64 tensor, which should be a scalar though.</dd>
</dl>


onnx/defs/tensor/defs.cc (276 changes: 143 additions & 133 deletions)
@@ -57,17 +57,17 @@ NOTE: Casting to and from strings is not supported yet.
"tensor(bool)"},
"Constrain output types. Casting to strings and complex are not supported.")
.ShapeInferenceFunction([](InferenceContext& ctx) {
if (!hasExactlyNInputTypes(ctx, 1, "Cast")) {
return;
}
if (!hasExactlyNInputTypes(ctx, 1, "Cast")) {
return;
}

propagateShapeFromInputToOutput(ctx, 0, 0);

auto type = ctx.getAttribute("to");
if (type) {
ctx.getOutputType(0)->set_elem_type(datatypeFromString(type->s()));
}
});
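Worth noting for the Cast hunk above: at this opset the "to" attribute is a string (the code reads type->s()), so inference resolves the target element type by name. Below is a hedged sketch of what a datatypeFromString-style lookup might do, assuming the onnx protobuf header; the function name is illustrative and the real helper lives elsewhere in the tree.

// [sketch, not part of this commit]
#include <string>
#include "onnx/onnx_pb.h"

// Hypothetical string -> TensorProto::DataType lookup, as a helper like
// datatypeFromString presumably behaves; only a few cases are shown.
onnx::TensorProto::DataType DatatypeFromStringSketch(const std::string& s) {
  if (s == "FLOAT") return onnx::TensorProto::FLOAT;
  if (s == "INT64") return onnx::TensorProto::INT64;
  if (s == "BOOL") return onnx::TensorProto::BOOL;
  return onnx::TensorProto::UNDEFINED;
}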

ONNX_OPERATOR_SCHEMA(Reshape)
.SinceVersion(6)
@@ -89,12 +89,12 @@ from the input tensor).)DOC")
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.")
.ShapeInferenceFunction([](InferenceContext& ctx) {
if (!hasExactlyNInputTypes(ctx, 2, "Reshape")) {
return;
}

propagateElemTypeFromInputToOutput(ctx, 0, 0);
});

ONNX_OPERATOR_SCHEMA(Shape)
.SetDoc(R"DOC(
@@ -120,17 +120,17 @@ Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor.
{"tensor(int64)"},
"Constrains output to int64 tensor.")
.ShapeInferenceFunction([](InferenceContext& ctx) {
if (!hasExactlyNInputTypes(ctx, 1, "Shape")) {
return;
}
if (!hasExactlyNInputTypes(ctx, 1, "Shape")) {
return;
}

ctx.getOutputType(0)->set_elem_type(TensorProto::INT64);

if (ctx.getInputType(0)->has_shape()) {
ctx.getOutputType(0)->mutable_shape()->add_dim()->set_dim_value(
ctx.getInputType(0)->shape().dim_size());
}
});
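As a worked note on the Shape inference above: the output is always a 1-D tensor(int64), and when the input's shape is known its single dimension equals the input's rank. A minimal sketch (hypothetical helper, not part of this commit):

// [sketch, not part of this commit]
#include <cstdint>
#include <vector>

// Shape returns the input's shape as data, so the single output dim is
// the input's rank, e.g. an input of shape [2, 3, 5] yields shape [3].
int64_t ShapeOutputDimSketch(const std::vector<int64_t>& input_shape) {
  return static_cast<int64_t>(input_shape.size());
}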

ONNX_OPERATOR_SCHEMA(Size)
.SetDoc(R"DOC(
@@ -151,11 +151,14 @@ Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor.
"tensor(uint16)",
"tensor(bool)"},
"Input tensor can be of arbitrary type.")
.TypeConstraint("T1", {"int64"}, "Constrains output to int64 scalar.")
.TypeConstraint(
"T1",
{"tensor(int64)"},
"Constrains output to int64 tensor, which should be a scalar though.")
.ShapeInferenceFunction([](InferenceContext& ctx) {
ctx.getOutputType(0)->set_elem_type(TensorProto::INT64);
ctx.getOutputType(0)->mutable_shape();
});
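This Size hunk is the substantive change of the commit: T1 becomes the tensor type tensor(int64) rather than the bare elem type int64, and the inference function marks the output as a scalar by giving it an empty shape. A minimal sketch of that encoding, assuming the onnx protobuf header; the function name is illustrative:

// [sketch, not part of this commit]
#include "onnx/onnx_pb.h"

// Builds the output type Size now declares: tensor(int64) with an empty
// (rank-0) shape, i.e. a scalar. Calling mutable_shape() without adding
// any dim records the shape as present but empty.
onnx::TypeProto MakeScalarInt64TypeSketch() {
  onnx::TypeProto t;
  t.mutable_tensor_type()->set_elem_type(onnx::TensorProto::INT64);
  t.mutable_tensor_type()->mutable_shape();  // no dims added => scalar
  return t;
}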

ONNX_OPERATOR_SCHEMA(Concat)
.SinceVersion(4)
@@ -173,64 +176,66 @@ ONNX_OPERATOR_SCHEMA(Concat)
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.ShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getNumInputs() == 0) {
return;
}

propagateElemTypeFromInputToOutput(ctx, 0, 0);

auto axisAttr = ctx.getAttribute("axis");
if (!axisAttr) {
return;
}
int axis = axisAttr->i();

bool found_exemplar = false;
TensorShapeProto shape_exemplar;
bool all_lengths_known = true;
int total_length = 0;

for (int i = 0; i < ctx.getNumInputs(); i++) {
  if (!ctx.getInputType(i)->has_shape()) {
    return;
  }
  const auto& shape = ctx.getInputType(i)->shape();
  if (found_exemplar) {
    for (int j = 0; j < shape.dim_size(); j++) {
      if (j == axis) {
        if (shape.dim(j).has_dim_value()) {
          total_length += shape.dim(j).dim_value();
        } else {
          all_lengths_known = false;
        }
      } else {
        if (shape.dim(j).has_dim_value() &&
            shape_exemplar.dim(j).has_dim_value() &&
            shape.dim(j).dim_value() !=
                shape_exemplar.dim(j).dim_value()) {
          return;
        }
      }
    }
  } else {
    shape_exemplar = shape;
    found_exemplar = true;
  }
}

if (!found_exemplar) {
return;
}

if (all_lengths_known) {
shape_exemplar.mutable_dim(axis)->set_dim_value(total_length);
} else {
shape_exemplar.mutable_dim(axis)->set_dim_param("");
}

for (int i = 0; i < shape_exemplar.dim_size(); i++) {
*ctx.getOutputType(0)->mutable_shape()->add_dim() =
shape_exemplar.dim(i);
}
});
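The Concat inference above sums the axis dimension when every input's length is known and requires the other dimensions to match the first input's shape. A condensed sketch of that rule, with unknown dims modeled as std::nullopt and the mismatch early-returns omitted for brevity (hypothetical helper):

// [sketch, not part of this commit]
#include <cstdint>
#include <optional>
#include <vector>

using Dim = std::optional<int64_t>;  // nullopt models an unknown dim

// Concatenating {2,3} and {4,3} on axis 0 yields {6,3}; if any axis
// dim is unknown, the output's axis dim stays unknown.
std::vector<Dim> ConcatShapeSketch(
    const std::vector<std::vector<Dim>>& shapes, int axis) {
  std::vector<Dim> out = shapes.front();
  int64_t total = 0;
  bool all_known = true;
  for (const auto& s : shapes) {
    if (s[axis].has_value()) {
      total += *s[axis];
    } else {
      all_known = false;
    }
  }
  out[axis] = all_known ? Dim(total) : std::nullopt;
  return out;
}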

ONNX_OPERATOR_SCHEMA(Split)
.SinceVersion(2)
@@ -245,43 +250,48 @@ ONNX_OPERATOR_SCHEMA(Split)
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input types to float tensors.")
.Attr("axis", "Which axis to split on (defaults to 0)", AttributeProto::INT, static_cast<int64_t>(0))
.Attr(
"axis",
"Which axis to split on (defaults to 0)",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr("split", "length of each output", AttributeProto::INTS, OPTIONAL)
.SetDoc(R"DOC(Split a tensor into a list of tensors, along the specified
'axis'. Lengths of the parts can be specified using argument 'split'.
Otherwise, the tensor is split to equal sized parts.
)DOC")
.ShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);

if (ctx.getNumOutputs() == 0) {
  return;
}

auto axisAttr = ctx.getAttribute("axis");
int axis = axisAttr ? axisAttr->i() : 0;
std::vector<int64_t> split;
if (!getRepeatedAttribute(ctx, "split", split)) {
  if (!ctx.getInputType(0)->has_shape()) {
    return;
  }
  const auto& splitDim = ctx.getInputType(0)->shape().dim(axis);
  if (!splitDim.has_dim_value()) {
    return;
  }
  int splitDimValue = splitDim.dim_value();
  int chunkSize = splitDimValue / ctx.getNumOutputs();
  int leftOver = splitDimValue - (chunkSize * ctx.getNumOutputs());
  for (int i = 0; i < ctx.getNumOutputs(); i++) {
    split.push_back(i < leftOver ? chunkSize + 1 : chunkSize);
  }
}

for (int i = 0; i < ctx.getNumOutputs(); i++) {
  *ctx.getOutputType(i)->mutable_shape() = ctx.getInputType(0)->shape();
  ctx.getOutputType(i)->mutable_shape()->mutable_dim(axis)->set_dim_value(
      split[i]);
}
});
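When the 'split' attribute is absent, the code above divides the axis length as evenly as possible and hands the remainder to the leading outputs, one extra element each, so the parts are not always strictly equal. A standalone sketch (hypothetical helper):

// [sketch, not part of this commit]
#include <vector>

// Default split lengths: chunk = dim / n, and the first (dim % n)
// outputs each take one extra element, e.g. DefaultSplitSketch(7, 3)
// returns {3, 2, 2}.
std::vector<int> DefaultSplitSketch(int dim, int num_outputs) {
  int chunk = dim / num_outputs;
  int left_over = dim - chunk * num_outputs;
  std::vector<int> split;
  for (int i = 0; i < num_outputs; ++i) {
    split.push_back(i < left_over ? chunk + 1 : chunk);
  }
  return split;
}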

ONNX_OPERATOR_SCHEMA(Slice)
.SetDoc(R"DOC(
@@ -367,25 +377,25 @@ will be (2, 1, 3).
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.")
.ShapeInferenceFunction([](InferenceContext& ctx) {
if (!hasExactlyNInputTypes(ctx, 1, "Transpose")) {
return;
}
if (!ctx.getInputType(0)->has_shape()) {
return;
if (!hasExactlyNInputTypes(ctx, 1, "Transpose")) {
return;
}
if (!ctx.getInputType(0)->has_shape()) {
return;
}

std::vector<int64_t> perm;
if (!getRepeatedAttribute(ctx, "perm", perm)) {
for (int i = ctx.getInputType(0)->shape().dim_size() - 1; i >= 0; --i) {
perm.push_back(i);
}
}

propagateElemTypeFromInputToOutput(ctx, 0, 0);
for (size_t i = 0; i < perm.size(); ++i) {
appendSingleDimCopiedFromInputTypeToOutputType(ctx, 0, 0, perm[i]);
}
});
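The default 'perm' constructed above is simply the dimension indices in reverse order, i.e. a full transpose; a rank-3 input gets {2, 1, 0}. Sketch (hypothetical helper):

// [sketch, not part of this commit]
#include <cstdint>
#include <vector>

// Default permutation when 'perm' is missing: reversed dimension
// indices, e.g. DefaultPermSketch(3) == {2, 1, 0}.
std::vector<int64_t> DefaultPermSketch(int rank) {
  std::vector<int64_t> perm;
  for (int64_t i = rank - 1; i >= 0; --i) {
    perm.push_back(i);
  }
  return perm;
}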

ONNX_OPERATOR_SCHEMA(Gather)
.SetDoc(R"DOC(
@@ -454,29 +464,29 @@ Example 2:
{"tensor(int32)", "tensor(int64)"},
"Constrain indices to integer types")
.ShapeInferenceFunction([](InferenceContext& ctx) {
if (!hasExactlyNInputTypes(ctx, 2, "Gather")) {
return;
}
if (!hasExactlyNInputTypes(ctx, 2, "Gather")) {
return;
}

propagateElemTypeFromInputToOutput(ctx, 0, 0);

if (!ctx.getInputType(0)->has_shape() ||
!ctx.getInputType(1)->has_shape()) {
return;
}

int r = ctx.getInputType(0)->shape().dim_size();
int q = ctx.getInputType(1)->shape().dim_size();

int out_rank = q + r - 1;

if (out_rank == 0) {
ctx.getOutputType(0)->mutable_shape();
}
for (int i = 0; i < out_rank; ++i) {
ctx.getOutputType(0)->mutable_shape()->add_dim();
}
});
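The rank arithmetic above: indices of rank q select along one axis of rank-r data, so the output has rank q + r - 1, and a rank-0 result still gets mutable_shape() called so the scalar shape is recorded explicitly. A one-liner sketch (hypothetical helper):

// [sketch, not part of this commit]
// Gather output rank: e.g. data rank r = 2 with indices rank q = 1
// gives a rank-2 output; r = 1, q = 0 gives a scalar (rank 0).
int GatherOutRankSketch(int r, int q) {
  return q + r - 1;
}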

ONNX_OPERATOR_SCHEMA(Squeeze)
.Attr(
