Skip to content

Commit

Permalink
function proto for composite op. (onnx#802)
Browse files Browse the repository at this point in the history
* function for composite op.

* update checker logic to not allow main graph using attribute reference.

* update format

* format checker.cc

* function only contains attribute name too, same as input/output. type constraints should be inferred from function body.

* update proto files.

* add more comments for ref_attr_name field in AttributeProto
  • Loading branch information
linkerzhang authored Apr 26, 2018
1 parent cd58928 commit 485b787
Show file tree
Hide file tree
Showing 12 changed files with 319 additions and 48 deletions.
63 changes: 50 additions & 13 deletions onnx/checker.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "onnx/checker.h"
#include "onnx/defs/schema.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"
#include "onnx/defs/schema.h"

#include <unordered_set>

Expand Down Expand Up @@ -57,14 +57,21 @@ void check_value_info(const ValueInfoProto& value_info, const CheckerContext&) {
} break;
#endif
default:
fail_check("Unrecognized type value case (value_info name: ", value_info.name(), "): ", value_case);
fail_check(
"Unrecognized type value case (value_info name: ",
value_info.name(),
"): ",
value_case);
}
}

void check_tensor(const TensorProto& tensor, const CheckerContext& /*ctx*/) {
enforce_has_field(tensor, data_type);
if (tensor.data_type() == TensorProto::UNDEFINED) {
fail_check("setting data_type field (tensor name: ", tensor.name(), ") to UNDEFINED is not allowed");
fail_check(
"setting data_type field (tensor name: ",
tensor.name(),
") to UNDEFINED is not allowed");
}

int num_value_fields = 0;
Expand All @@ -89,11 +96,17 @@ void check_tensor(const TensorProto& tensor, const CheckerContext& /*ctx*/) {
#undef check_data_field

if (num_value_fields != 1) {
fail_check("TensorProto (tensor name: ", tensor.name(), ") should contain one and only one value field.");
fail_check(
"TensorProto (tensor name: ",
tensor.name(),
") should contain one and only one value field.");
}
if (has_raw_data) {
if (tensor.data_type() == TensorProto::STRING) {
fail_check("STRING data (tensor name: ", tensor.name(), ") should not be stored in raw_data field");
fail_check(
"STRING data (tensor name: ",
tensor.name(),
") should not be stored in raw_data field");
}
return;
} else {
Expand Down Expand Up @@ -140,7 +153,11 @@ void check_tensor(const TensorProto& tensor, const CheckerContext& /*ctx*/) {
break;

default:
fail_check("Unrecognized data_type (tensor name: ", tensor.name(), "): ", tensor.data_type());
fail_check(
"Unrecognized data_type (tensor name: ",
tensor.name(),
"): ",
tensor.data_type());
}
}

Expand All @@ -161,9 +178,10 @@ void check_attribute(

int used_fields = 0;

#define check_type(expected_type) \
if (attr.has_type() && attr.type() != expected_type) { \
fail_check("type field and data field mismatch in attribute ", attr.name(), "."); \
#define check_type(expected_type) \
if (attr.has_type() && attr.type() != expected_type) { \
fail_check( \
"type field and data field mismatch in attribute ", attr.name(), "."); \
}

#define check_singular_field(field, type) \
Expand Down Expand Up @@ -193,8 +211,21 @@ void check_attribute(
#undef check_singular_field
#undef check_repeated_field

if (used_fields != 1) {
fail_check("Attribute (name: ", attr.name(), ") should contain one and only one value field.");
if (ctx.is_main_graph()) {
if (used_fields != 1) {
fail_check(
"Attribute (name: ",
attr.name(),
") should contain one and only one value field.");
}
} else {
// It's an attribute of a node in function body.
if (used_fields != 1 && (used_fields != 0 || !attr.has_ref_attr_name())) {
fail_check(
"Attribute (name: ",
attr.name(),
") should contain one value field or refer to attribute declared in function.");
}
}

if (attr.has_t()) {
Expand All @@ -220,7 +251,12 @@ void check_node(
enforce_non_empty_field(node, op_type);

if (node.input().empty() && node.output().empty()) {
fail_check("NodeProto (name: ", node.name(), ", type: ", node.op_type(), ") has zero input and zero output.");
fail_check(
"NodeProto (name: ",
node.name(),
", type: ",
node.op_type(),
") has zero input and zero output.");
}

// Resolve domain for node
Expand Down Expand Up @@ -345,7 +381,8 @@ void check_model(const ModelProto& model) {
ctx.set_ir_version(static_cast<int>(model.ir_version()));
std::unordered_map<std::string, int> opset_imports;
for (const auto& opset_import : model.opset_import()) {
opset_imports[opset_import.domain()] = static_cast<int>(opset_import.version());
opset_imports[opset_import.domain()] =
static_cast<int>(opset_import.version());
}
if (model.ir_version() >= 3) {
if (opset_imports.empty())
Expand Down
25 changes: 17 additions & 8 deletions onnx/checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,32 @@ class ValidationError final : public std::runtime_error {
throw ONNX_NAMESPACE::checker::ValidationError(ONNX_NAMESPACE::MakeString(__VA_ARGS__));

class CheckerContext final {
int ir_version;
std::unordered_map<std::string, int> opset_imports;

public:
int get_ir_version() const {
return ir_version;
return ir_version_;
}
void set_ir_version(int v) {
ir_version = v;
ir_version_ = v;
}
const std::unordered_map<std::string, int>& get_opset_imports() const {
return opset_imports;
return opset_imports_;
}
void set_opset_imports(std::unordered_map<std::string, int> imps) {
opset_imports = std::move(imps);
opset_imports_ = std::move(imps);
}
bool is_main_graph() const {
return is_main_graph_;
}
void set_is_main_graph(bool is_main_graph) {
is_main_graph_ = is_main_graph;
}
explicit CheckerContext() : ir_version(-1) {}

explicit CheckerContext() : ir_version_(-1) {}

private:
int ir_version_;
std::unordered_map<std::string, int> opset_imports_;
bool is_main_graph_ = true;
};

struct LexicalScopeContext final {
Expand Down
6 changes: 6 additions & 0 deletions onnx/onnx-ml.proto
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ message AttributeProto {

// The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute

// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
optional string ref_attr_name = 21;

// A human-readable documentation for this attribute. Markdown is allowed.
optional string doc_string = 13;
Expand Down
6 changes: 6 additions & 0 deletions onnx/onnx-ml.proto3
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ message AttributeProto {

// The name field MUST be present for this version of the IR.
string name = 1; // namespace Attribute

// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
string ref_attr_name = 21;

// A human-readable documentation for this attribute. Markdown is allowed.
string doc_string = 13;
Expand Down
49 changes: 44 additions & 5 deletions onnx/onnx-operators-ml.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,45 @@ import "onnx-ml.proto";
// that describes the ONNX standard operators.
//

// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}

message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
optional string name = 1;

// The first version of a function set which contains this function.
// When there's any breaking change for this function, the function set
// contains the function needs to bump its version, and since_version of
// the updated function will be changed to the updated function set version.
optional int64 since_version = 2;

// This field indicates whether the syntax, semantics, or presence
// of this function is in an experimental or stable stage. Once an
// function is published as STABLE, its syntax and semantics MUST NOT
// change in subsequent versions of the operator set.
// When a function is published as EXPERIMENTAL, the syntax and semantics
// of the function MAY change across operator set versions.
// Functions "become" stable by deprecating the experimental version and
// introducing a new stable function with the same name.
optional OperatorStatus status = 3;

// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;

// The attributes of the function.
repeated string attribute= 6;

// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
optional string doc_string = 8;
}

// An OperatorProto represents the immutable specification of the signature
// and semantics of an operator.
//
Expand All @@ -45,11 +84,7 @@ import "onnx-ml.proto";
// *since_version* is the version of the operator set that
// this operator was initially declared in.
//
message OperatorProto {
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message OperatorProto {

// The name of the operator within a domain.
// This field MUST be present in this version of the IR.
Expand Down Expand Up @@ -129,4 +164,8 @@ message OperatorSetProto {
// The operators specified by this operator set.
// The (name, version) MUST be unique across all OperatorProtos in operator
repeated OperatorProto operator = 8;

// The functions specified by this operator set.
// The (name, version) MUST be unique across all OperatorProtos/FunctionProtos in operator/functions
repeated FunctionProto functions = 5;
}
49 changes: 44 additions & 5 deletions onnx/onnx-operators-ml.proto3
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,45 @@ import "onnx-ml.proto3";
// that describes the ONNX standard operators.
//

// Operator/function status.
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}

message FunctionProto {
// The name of the function, similar usage of op_type in OperatorProto.
string name = 1;

// The first version of a function set which contains this function.
// When there's any breaking change for this function, the function set
// contains the function needs to bump its version, and since_version of
// the updated function will be changed to the updated function set version.
int64 since_version = 2;

// This field indicates whether the syntax, semantics, or presence
// of this function is in an experimental or stable stage. Once an
// function is published as STABLE, its syntax and semantics MUST NOT
// change in subsequent versions of the operator set.
// When a function is published as EXPERIMENTAL, the syntax and semantics
// of the function MAY change across operator set versions.
// Functions "become" stable by deprecating the experimental version and
// introducing a new stable function with the same name.
OperatorStatus status = 3;

// The inputs and outputs of the function.
repeated string input = 4;
repeated string output = 5;

// The attributes of the function.
repeated string attribute= 6;

// The nodes in the function.
repeated NodeProto node = 7;
// A human-readable documentation for this function. Markdown is allowed.
string doc_string = 8;
}

// An OperatorProto represents the immutable specification of the signature
// and semantics of an operator.
//
Expand All @@ -45,11 +84,7 @@ import "onnx-ml.proto3";
// *since_version* is the version of the operator set that
// this operator was initially declared in.
//
message OperatorProto {
enum OperatorStatus {
EXPERIMENTAL = 0;
STABLE = 1;
}
message OperatorProto {

// The name of the operator within a domain.
// This field MUST be present in this version of the IR.
Expand Down Expand Up @@ -129,4 +164,8 @@ message OperatorSetProto {
// The operators specified by this operator set.
// The (name, version) MUST be unique across all OperatorProtos in operator
repeated OperatorProto operator = 8;

// The functions specified by this operator set.
// The (name, version) MUST be unique across all OperatorProtos/FunctionProtos in operator/functions
repeated FunctionProto functions = 5;
}
Loading

0 comments on commit 485b787

Please sign in to comment.