Optional inputs, schema validation and docs prettification
Dmytro Dzhulgakov committed Nov 3, 2017
1 parent 4b1ac03 commit 78ea478
Showing 8 changed files with 1,745 additions and 1,185 deletions.
2,699 changes: 1,568 additions & 1,131 deletions docs/Operators.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion onnx/cpp2py_export.cc
@@ -21,13 +21,14 @@ PYBIND11_PLUGIN(onnx_cpp2py_export) {
   op_schema
       .def_property_readonly("file", &OpSchema::file)
       .def_property_readonly("line", &OpSchema::line)
       .def_property_readonly("support_level", &OpSchema::support_level)
       .def_property_readonly(
           "doc", &OpSchema::doc, py::return_value_policy::reference)
       .def_property_readonly("min_input", &OpSchema::min_input)
       .def_property_readonly("max_input", &OpSchema::max_input)
       .def_property_readonly("min_output", &OpSchema::min_output)
       .def_property_readonly("max_output", &OpSchema::max_output)
+      .def_property_readonly("optional_inputs", &OpSchema::optional_inputs)
       .def_property_readonly("attributes", &OpSchema::attributes)
       .def_property_readonly("input_desc", &OpSchema::input_desc)
       .def_property_readonly("output_desc", &OpSchema::output_desc)
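The new `optional_inputs` property mirrors the C++ side's set of optional input indices; the documentation generator in gen_doc.py below reads it to tag individual inputs as `(optional)`.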
28 changes: 7 additions & 21 deletions onnx/defs/experiments/defs.cc
@@ -8,7 +8,7 @@ using SupportType = onnx::OpSchema::SupportType;

 OPERATOR_SCHEMA(ConstantFill)
     .SetSupportLevel(SupportType::EXPERIMENTAL)
-    .NumInputs(0, 1)
+    .NumInputs(1)
     .NumOutputs(1)
     .AllowConsumed({{0, 0}})
     .SetDoc(R"DOC(
@@ -55,7 +55,8 @@ NOTE: Currently, it supports data type of float, int32, int64, and bool.
         "1D tensor containing the desired output shape. First input must be in "
         "CPU context.",
         AttrType::INT)
-    .Input(0, "input", "Input tensor (optional) to provide shape information.")
+    .Input(0, "input", "Input tensor (optional) to provide shape information.",
+           true /*optional*/)
     .Output(
         0,
         "output",
@@ -233,6 +234,8 @@ OPERATOR_SCHEMA(Normalize)
     .SetSupportLevel(SupportType::EXPERIMENTAL)
     .NumInputs(1)
     .NumOutputs(1)
+    .Input(0, "input", "Input matrix")
+    .Output(0, "output", "Matrix after normalization")
     .SetDoc(R"DOC(
 Given a matrix, apply L2-normalization along the last dimension.
 )DOC");
@@ -241,6 +244,8 @@ OPERATOR_SCHEMA(Scale)
     .SetSupportLevel(SupportType::EXPERIMENTAL)
     .NumInputs(1)
     .NumOutputs(1)
+    .Input(0, "input", "Input data to be scaled")
+    .Output(0, "output", "Output data after scaling")
     .AllowConsumed({{0, 0}})
     .SetDoc(R"DOC(
 Scale takes one input data (Tensor<float>) and produces one output data
@@ -250,25 +255,6 @@ Scale takes one input data (Tensor<float>) and produces one output data
         "(float, default 1.0) the scale to apply.",
         AttrType::FLOAT);

-OPERATOR_SCHEMA(RecurrentNetwork)
-    .SetSupportLevel(SupportType::EXPERIMENTAL)
-    .NumInputs(1, INT_MAX)
-    .NumOutputs(2, INT_MAX)
-    .SetDoc(R"DOC(
-Run the input network in a recurrent fashion. This can be used to
-implement fairly general recurrent neural networks (RNNs).
-The operator proceeds as follows.
-- First, initialize the states from the input recurrent states
-- For each timestep T, apply the links (that map offsets from input/output
-tensors into the inputs/outputs for the `step` network)
-- Finally, alias the recurrent states to the specified output blobs.
-This is a fairly special-case meta-operator, and so the implementation
-is somewhat complex. It trades off generality (and frankly usability)
-against performance and control (compared to e.g. TF dynamic_rnn,
-Theano scan, etc.).
-See the usage examples for a flavor of how to use it.
-)DOC");
-
 OPERATOR_SCHEMA(GRUUnit)
     .SetSupportLevel(SupportType::EXPERIMENTAL)
     .NumInputs(4)
58 changes: 36 additions & 22 deletions onnx/defs/gen_doc.py
@@ -17,6 +17,15 @@ def display_number(v):
     return str(v)


+def display_attr_type(v):
+    assert isinstance(v, OpSchema.AttrType)
+    s = str(v)
+    s = s[s.rfind('.')+1:].lower()
+    if s[-1] == 's':
+        s = 'list of ' + s
+    return s
+
+
 def support_level_str(level):
     return \
         "<sub>experimental</sub> " if level == OpSchema.SupportType.EXPERIMENTAL else ""
@@ -58,41 +67,46 @@ def main(args):

         # attributes
         if schema.attributes:
-            s += '  * **attribute**:\n'
-            s += '    <dl>\n'
+            s += '\n#### Attributes\n\n'
+            s += '<dl>\n'
             for _, attr in sorted(schema.attributes.items()):
-                s += '      <dt>{}</dt>\n'.format(attr.name)
-                s += '      <dd>{}</dd>\n'.format(attr.description)
-            s += '    </dl>\n'
+                s += '<dt><tt>{}</tt> : {}{}</dt>\n'.format(
+                    attr.name,
+                    display_attr_type(attr.type),
+                    ' (required)' if attr.required else '')
+                s += '<dd>{}</dd>\n'.format(attr.description)
+            s += '</dl>\n'


         # inputs
-        s += '  * **input**:'
+        s += '\n#### Inputs'
         if schema.min_input != schema.max_input:
-            s += '{} - {}'.format(display_number(schema.min_input),
-                                  display_number(schema.max_input))
-        s += '\n'
+            s += ' ({} - {})'.format(display_number(schema.min_input),
+                                     display_number(schema.max_input))
+        s += '\n\n'
         if schema.input_desc:
-            s += '    <dl>\n'
-            for input_name, input_desc in schema.input_desc:
-                s += '      <dt>{}</dt>\n'.format(input_name)
-                s += '      <dd>{}</dd>\n'.format(input_desc)
-            s += '    </dl>\n'
+            s += '<dl>\n'
+            for idx, (input_name, input_desc) in enumerate(schema.input_desc):
+                s += '<dt><tt>{}</tt>{}</dt>\n'.format(
+                    input_name,
+                    ' (optional)' if idx in schema.optional_inputs else '')
+                s += '<dd>{}</dd>\n'.format(input_desc)
+            s += '</dl>\n'

         # outputs
-        s += '  * **output**:'
+        s += '\n#### Outputs'
         if schema.min_output != schema.max_output:
-            s += '{} - {}'.format(display_number(schema.min_output),
-                                  display_number(schema.max_output))
+            s += ' ({} - {})'.format(display_number(schema.min_output),
+                                     display_number(schema.max_output))
         s += '\n'
         if schema.output_desc:
-            s += '    <dl>\n'
+            s += '<dl>\n'
             for output_name, output_desc in schema.output_desc:
-                s += '      <dt>{}</dt>\n'.format(output_name)
-                s += '      <dd>{}</dd>\n'.format(output_desc)
-            s += '    </dl>\n'
+                s += '<dt><tt>{}</tt></dt>\n'.format(output_name)
+                s += '<dd>{}</dd>\n'.format(output_desc)
+            s += '</dl>\n'

-        s += '\n\n'
+        s += '\n\n---\n\n'
         args.output.write(s)
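Taken together, each operator entry in the regenerated Operators.md now gets `#### Attributes` / `#### Inputs` / `#### Outputs` headings, typed attribute entries, `(optional)` markers on inputs, and a `---` rule separating operators. For a hypothetical operator with one required attribute and a trailing optional input, the emitted markdown would look roughly like this (names are illustrative; the format follows the strings above):

```markdown
#### Attributes

<dl>
<dt><tt>axis</tt> : int (required)</dt>
<dd>Which axis to operate on</dd>
</dl>

#### Inputs (1 - 2)

<dl>
<dt><tt>input</tt></dt>
<dd>The input tensor</dd>
<dt><tt>mask</tt> (optional)</dt>
<dd>Optional mask tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt></dt>
<dd>The result tensor</dd>
</dl>

---
```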
58 changes: 57 additions & 1 deletion onnx/defs/schema.cc
@@ -3,6 +3,7 @@

 #include "schema.h"
 #include <unordered_set>
+#include <stdexcept>

 namespace onnx {

@@ -48,6 +49,23 @@ bool OpSchema::Verify(const NodeProto& node) const {
     }
   }

+  // Check the values of inputs / outputs
+  for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) {
+    if (node.input(in_idx).empty() && !optional_inputs_.count(in_idx)) {
+      std::cerr
+          << "Input " << in_idx
+          << " is not marked optional but has an empty string in the graph";
+      return false;
+    }
+  }
+  for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) {
+    if (node.output(out_idx).empty()) {
+      std::cerr << "Output " << out_idx
+                << " has an empty string in the graph";
+      return false;
+    }
+  }
+
   // Check attributes
   std::unordered_set<std::string> seen_attr_names{};
   const AttributeProto * consume_attr = nullptr;
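A quick sketch of what the new value checks catch, assuming a schema (e.g. Scale from experiments/defs.cc above) whose input 0 is not registered as optional:

```cpp
// Sketch: this node leaves a required input slot empty, so Verify() logs
// "Input 0 is not marked optional but has an empty string in the graph"
// and returns false. Empty output names are rejected unconditionally.
onnx::NodeProto node;
node.set_op_type("Scale");
node.add_input("");    // slot 0 is required
node.add_output("y");
```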
@@ -393,11 +411,14 @@ OpSchema& OpSchema::AllowUncheckedAttributes() {
   return *this;
 }

-OpSchema& OpSchema::Input(const int n, const char* name, const char* description) {
+OpSchema& OpSchema::Input(const int n, const char* name, const char* description, bool optional) {
   if (int(input_desc_.size()) <= n) {
     input_desc_.resize(n + 1);
   }
   input_desc_[n] = std::make_pair(name, description);
+  if (optional) {
+    optional_inputs_.insert(n);
+  }
   return *this;
 }

@@ -426,6 +447,41 @@ int OpSchema::CalculateOutput(int num_input) const {
   }
 }

+void OpSchema::Finalize() {
+#define ENFORCE(x) do { if (!(x)) throw std::logic_error("ONNX Schema " + name_ + ": failed validating the check: " + #x); } while (0)
+  ENFORCE(min_input_ <= max_input_);
+  ENFORCE(min_output_ <= max_output_);
+  ENFORCE(input_desc_.size() >= min_input_);
+  ENFORCE(output_desc_.size() >= min_output_);
+  ENFORCE(input_desc_.size() <= max_input_);
+  ENFORCE(output_desc_.size() <= max_output_);
+  // if the max limit is finite, all names must be specified
+  if (max_input_ < std::numeric_limits<int>::max()) {
+    ENFORCE(input_desc_.size() == max_input_);
+  }
+  if (max_output_ < std::numeric_limits<int>::max()) {
+    ENFORCE(output_desc_.size() == max_output_);
+  }
+  // all inputs and outputs must have names
+  for (const auto& it : input_desc_) {
+    ENFORCE(it.first);
+  }
+  for (const auto& it : output_desc_) {
+    ENFORCE(it.first);
+  }
+  // TODO: also cover checks for arbitrary numbers of inputs
+  // allow extra trailing inputs to be absent if all inputs at the end are
+  // marked as optional
+  if (max_input_ < std::numeric_limits<int>::max()) {
+    int ind = max_input_;
+    while (ind > 0 && optional_inputs_.count(ind-1)) {
+      --ind;
+    }
+    min_input_ = std::min(min_input_, ind);
+  }
+}
+
+
 std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
   if (!schema.attributes_.empty()) {
     out << "Attributes:" << std::endl;
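The last block of Finalize() is what makes trailing optional inputs droppable: it walks backwards from `max_input_` over optional slots and relaxes `min_input_` accordingly. A sketch with a hypothetical schema (`MyFill` is made up for illustration):

```cpp
// Hypothetical schema: exactly one declared input, marked optional.
OPERATOR_SCHEMA(MyFill)
    .NumInputs(1)   // min_input_ = max_input_ = 1
    .NumOutputs(1)
    .Input(0, "shape_like", "Optional tensor supplying the output shape",
           true /*optional*/)
    .Output(0, "output", "The filled tensor");
// In Finalize(), ind walks from 1 down to 0 (slot 0 is optional), so
// min_input_ becomes std::min(1, 0) = 0: nodes may pass zero inputs or one.
```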
49 changes: 42 additions & 7 deletions onnx/defs/schema.h
@@ -43,9 +43,9 @@ class OpSchema {
     EXPERIMENTAL, // This OP is experimental and can be changed or removed in the future.
   };

-  OpSchema() : file_("unknown"), line_(0), support_(SupportType::COMMON) {}
-  OpSchema(const std::string& file, const int line)
-      : file_(file), line_(line), support_(SupportType::COMMON) {}
+  OpSchema() : name_("unknown"), file_("unknown"), line_(0), support_(SupportType::COMMON) {}
+  OpSchema(const std::string& name, const std::string& file, const int line)
+      : name_(name), file_(file), line_(line), support_(SupportType::COMMON) {}

   /**
    * @brief Returns the file that the op schema is registered from.
@@ -192,12 +192,21 @@ class OpSchema {
       bool required = false);
   OpSchema& AllowUncheckedAttributes();

-  OpSchema& Input(const int n, const char* name, const char* description);
-  OpSchema& Output(const int n, const char* name, const char* description);
+  // optional = true means that the input might have an empty value
+  // (represented as "") in the graph even though later inputs have values.
+  // It's useful for complex situations where there are several independent
+  // optional inputs.
+  OpSchema& Input(const int n, const char *name, const char *description,
+                  bool optional = false);
+  OpSchema& Output(const int n, const char *name, const char *description);
   // Calls the passed function with `this` as an argument. Useful for
   // adding docs for templated/macro ops.
   OpSchema& FillUsing(std::function<void(OpSchema&)> populator);

+
+  // Verifies that the schema is valid and that all specifications are compatible.
+  void Finalize();
+
   /**
    * @brief A function to allow one to get the number of outputs based on the
    * number of inputs, if this schema supports it.
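The comment on `Input` above is the heart of the feature: a trailing optional input can simply be left off (Finalize() in schema.cc relaxes the minimum arity), but an optional input in the middle of the list is skipped by giving it an empty name while later inputs keep theirs. A sketch with a hypothetical three-input op:

```cpp
// Hypothetical op with inputs (X, W, B), where W was registered as optional.
onnx::NodeProto node;
node.set_op_type("MyOp");
node.add_input("X");
node.add_input("");    // W omitted; "" keeps B in the correct slot
node.add_input("B");   // a later input still carries a value
node.add_output("Y");
```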
@@ -215,6 +224,9 @@ class OpSchema {
   const std::vector<std::pair<const char*, const char*>>& output_desc() const {
     return output_desc_;
   }
+  const std::set<int> optional_inputs() const {
+    return optional_inputs_;
+  }
   int min_input() const {
     return min_input_;
   }
@@ -232,6 +244,7 @@ class OpSchema {
   }

 private:
+  std::string name_;
   std::string file_;
   std::string doc_;
   std::map<std::string, Attribute> attributes_{};
@@ -244,6 +257,7 @@ class OpSchema {
   int max_input_ = std::numeric_limits<int>::max();
   int min_output_ = 0;
   int max_output_ = std::numeric_limits<int>::max();
+  std::set<int> optional_inputs_;
   std::function<bool(int)> num_inputs_allowed_
       = [](int) { return true; };
   std::function<bool(int)> num_outputs_allowed_
@@ -257,6 +271,27 @@ class OpSchema {
       = [](int){ return std::make_pair(UseType::DEFAULT, 0); };
 };

+/**
+ * Internal class used in schema declarations.
+ */
+class OpSchemaHolder {
+ public:
+  OpSchemaHolder(OpSchema& schema) : schema_(&schema) {
+    // TODO: when we fix all issues, we can add abort() here
+    try {
+      schema.Finalize();
+    } catch (const std::exception& e) {
+      std::cerr << "Schema error: " << e.what() << std::endl;
+    }
+  }
+  const OpSchema* operator->() const {
+    return schema_;
+  }
+
+ private:
+  const OpSchema* schema_;
+};
+
 /**
  * @brief A registry to hold all the operator schemas.
  */
@@ -273,7 +308,7 @@ class OpSchemaRegistry {
           << schema.file() << " line " << schema.line();
       abort();
     }
-    m.emplace(std::make_pair(key, OpSchema(file, line)));
+    m.emplace(std::make_pair(key, OpSchema(key, file, line)));
     return m[key];
   }
@@ -308,7 +343,7 @@ class OpSchemaRegistry {
 };

 #define OPERATOR_SCHEMA(name) \
-  static onnx::OpSchema& (op_schema_##name) = \
+  static onnx::OpSchemaHolder (op_schema_##name) = \
       onnx::OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__)

 // Helper function
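Note the ordering this macro relies on: `NewSchema()` returns an `OpSchema&`, every chained builder call in a schema declaration runs against that reference first, and only the final reference initializes the static `OpSchemaHolder`, so `Finalize()` fires exactly once, right after the declaration is complete. Roughly, for a declaration like Concat's:

```cpp
// What OPERATOR_SCHEMA(Concat).NumInputs(1, INT_MAX).NumOutputs(1)
// expands to (sketch):
static onnx::OpSchemaHolder op_schema_Concat =
    onnx::OpSchemaRegistry::NewSchema("Concat", __FILE__, __LINE__)
        .NumInputs(1, INT_MAX)   // builder chain runs first...
        .NumOutputs(1);          // ...then the holder's ctor calls Finalize()
// For now a failed check only prints "Schema error: ..." rather than
// aborting, per the TODO in the OpSchemaHolder constructor.
```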
6 changes: 4 additions & 2 deletions onnx/defs/tensor/defs.cc
@@ -36,9 +36,9 @@ OPERATOR_SCHEMA(Reshape)
     .AllowConsumed({{0, 0}})
     .SetDoc(R"DOC(
 Reshape the input tensor similar to numpy.reshape.
 It takes a tensor as input and an argument `shape`. It outputs the reshaped tensor.
 At most one dimension of the new shape can be -1. In this case, the value is
 inferred from the size of the tensor and the remaining dimensions. A dimension
 could also be 0, in which case the actual dimension value is going to be copied
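(For example, reshaping a (2, 3, 4) tensor with `shape = (0, -1)` copies the first dimension and infers the second, yielding a (2, 12) tensor.)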
@@ -54,13 +54,15 @@ OPERATOR_SCHEMA(Concat)
         "Which axis to concat on",
         AttrType::INT)
     .SetDoc("Concatenate a list of tensors into a single tensor")
+    .Input(0, "inputs...", "List of tensors for concatenation")
+    .Output(0, "concat_result", "Concatenated tensor");

 OPERATOR_SCHEMA(Split)
     .NumInputs(1, 2)
     .NumOutputs(1, INT_MAX)
+    .Input(0, "input", "The tensor to split")
+    .Input(1, "split", "Optional list of output lengths (see also arg 'split')")
+    .Output(0, "outputs...", "One or more outputs forming a list of tensors after splitting")
     .Attr("axis",
         "Which axis to split on",
         AttrType::INT)