Skip to content

Commit

Permalink
misc fixes for issues found in ort integration (onnx#4681)
Browse files Browse the repository at this point in the history
  • Loading branch information
liqunfu authored Dec 8, 2022
1 parent 37ba77e commit 45f508b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 16 deletions.
4 changes: 4 additions & 0 deletions onnx/defs/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,10 @@ class OpSchemaRegistry final : public ISchemaRegistry {
auto& op_name = op_schema.Name();
auto& op_domain = op_schema.domain();
auto ver = op_schema.SinceVersion();
if (OpSchema::kUninitializedSinceVersion == ver) {
op_schema.SinceVersion(1);
ver = op_schema.SinceVersion();
}
// Stops because the opset_version is higher than opset_version_to_load
if (opset_version_to_load != 0 && ver > opset_version_to_load) {
return;
Expand Down
55 changes: 41 additions & 14 deletions onnx/defs/traditionalml/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,44 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
}
const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
const auto input_ndim = input_shape.dim_size();

if (input_ndim == 1) {
return;
}
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
// This operator only applies to the last dimension; thus -1
for (int i = 0; i < input_ndim - 1; ++i) {
*output_shape->add_dim() = input_shape.dim(i);
}
// The length of second input is the length of the last dimension of the output

// value of the output's last dimension is the total amount of indices
// set Unknown length for the last dimension if it cannot be calculated
auto last_dim = output_shape->add_dim();
if (hasInputShape(ctx, 1)) {
const auto& indices_shape = getInputShape(ctx, 1);
if (indices_shape.dim_size() > 0) {
auto dim = indices_shape.dim(0);
*output_shape->add_dim() = dim;
return;
int64_t num_indices = 1;
std::string single_symbolic_dim;
for (int i = 0; i < indices_shape.dim_size(); i++) {
if (indices_shape.dim(i).has_dim_value()) {
num_indices *= indices_shape.dim(i).dim_value();
} else if (indices_shape.dim(i).has_dim_param()) {
if (single_symbolic_dim.empty()) {
// it is possible to set symbolic dimension param if the rest dim values are all value 1
single_symbolic_dim = indices_shape.dim(i).dim_param();
} else {
return;
}
} else {
return;
}
}
if (single_symbolic_dim.empty()) {
last_dim->set_dim_value(num_indices);
} else if (num_indices == 1) {
last_dim->set_dim_param(single_symbolic_dim);
}
}
}
// Unknown length of the last dimension
output_shape->add_dim();
})
.TypeConstraint(
"T",
Expand Down Expand Up @@ -851,9 +872,9 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
"Only one of the attributes 'base_values', 'base_values_as_tensor' should be specified.");
}

std::vector<std::string> label_strs;
auto result = getRepeatedAttribute(ctx, "classlabels_strings", label_strs);
bool using_strings = (result && !label_strs.empty());
std::vector<std::string> classlabels_strings;
auto result = getRepeatedAttribute(ctx, "classlabels_strings", classlabels_strings);
bool using_strings = (result && !classlabels_strings.empty());
if (using_strings) {
updateOutputElemType(ctx, 0, TensorProto::STRING);
} else {
Expand All @@ -864,10 +885,16 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
checkInputRank(ctx, 0, 2);
Dim N, E;
unifyInputDim(ctx, 0, 0, N);
std::vector<int64_t> class_ids;
auto has_ids = getRepeatedAttribute(ctx, "class_ids", class_ids);
if (has_ids) {
unifyDim(E, class_ids.size());

if (using_strings) {
unifyDim(E, classlabels_strings.size());
} else {
std::vector<int64_t> classlabels_int64s;
result = getRepeatedAttribute(ctx, "classlabels_int64s", classlabels_int64s);
if (!result || classlabels_int64s.empty()) {
fail_shape_inference("Non of classlabels_int64s or classlabels_strings is set.");
}
unifyDim(E, classlabels_int64s.size());
}
updateOutputShape(ctx, 0, {N});
updateOutputShape(ctx, 1, {N, E});
Expand Down
2 changes: 1 addition & 1 deletion onnx/test/shape_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8567,7 +8567,7 @@ def test_tree_ensemble_classifier(self) -> None:
"TreeEnsembleClassifier",
["x"],
["y", "z"],
class_ids=[0, 1, 2, 3, 4],
classlabels_int64s=[0, 1, 2, 3, 4],
domain=ONNX_ML_DOMAIN,
)
graph = self._make_graph(
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ max-line-length = 88
# type comments.
# E203 is needed to support black formatting.
# https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html
# B905 `zip()` without an explicit `strict=` parameter.
# B950 is ignored because too many lines are too long. This is OK because black handles most cases.
ignore = E127, E128, E265, E266, E402, E501, E722, F405, P207, P208, W503, F401, E203, B950
ignore = E127, E128, E265, E266, E402, E501, E722, F405, P207, P208, W503, F401, E203, B905, B950
exclude =
.git,
__pycache__,
Expand Down

0 comments on commit 45f508b

Please sign in to comment.