Skip to content

Commit

Permalink
Increase protobuf size limit when parsing pb files (onnx#190)
Browse files Browse the repository at this point in the history
* Increase protobuf size limit when parsing pb files

* clang-format

* Separate out the python part from proto_utils
  • Loading branch information
bddppq authored and ezyang committed Nov 4, 2017
1 parent 6ac7a60 commit 3d4f101
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
15 changes: 8 additions & 7 deletions onnx/cpp2py_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
#include <pybind11/stl.h>
#include <unordered_map>

#include "onnx/defs/schema.h"
#include "onnx/checker.h"
#include "onnx/defs/schema.h"
#include "onnx/py_utils.h"

namespace onnx {

Expand Down Expand Up @@ -94,42 +95,42 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
checker.def("check_value_info", [](const py::bytes& bytes,
int ir_version) -> void {
std::unique_ptr<ValueInfoProto> proto(new ValueInfoProto());
proto->ParseFromString(bytes);
ParseProtoFromPyBytes(proto.get(), bytes);
checker::check_value_info(*proto, ir_version);
});

checker.def("check_tensor", [](const py::bytes& bytes,
int ir_version) -> void {
std::unique_ptr<TensorProto> proto(new TensorProto());
proto->ParseFromString(bytes);
ParseProtoFromPyBytes(proto.get(), bytes);
checker::check_tensor(*proto, ir_version);
});

checker.def("check_attribute", [](const py::bytes& bytes,
int ir_version) -> void {
std::unique_ptr<AttributeProto> proto(new AttributeProto());
proto->ParseFromString(bytes);
ParseProtoFromPyBytes(proto.get(), bytes);
checker::check_attribute(*proto, ir_version);
});

checker.def("check_node", [](const py::bytes& bytes,
int ir_version) -> void {
std::unique_ptr<NodeProto> proto(new NodeProto());
proto->ParseFromString(bytes);
ParseProtoFromPyBytes(proto.get(), bytes);
checker::check_node(*proto, ir_version);
});

checker.def("check_graph", [](const py::bytes& bytes,
int ir_version) -> void {
std::unique_ptr<GraphProto> proto(new GraphProto());
proto->ParseFromString(bytes);
ParseProtoFromPyBytes(proto.get(), bytes);
checker::check_graph(*proto, ir_version);
});

checker.def("check_model", [](const py::bytes& bytes,
int ir_version) -> void {
std::unique_ptr<ModelProto> proto(new ModelProto());
proto->ParseFromString(bytes);
ParseProtoFromPyBytes(proto.get(), bytes);
checker::check_model(*proto, ir_version);
});
}
Expand Down
17 changes: 17 additions & 0 deletions onnx/proto_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>

namespace onnx {

template <typename Proto>
bool ParseProtoFromBytes(Proto* proto, const char* buffer, size_t length) {
// Total bytes hard limit / warning limit are set to 1GB and 512MB
// respectively.
::google::protobuf::io::CodedInputStream coded_stream(
new google::protobuf::io::ArrayInputStream(buffer, length));
coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
return proto->ParseFromCodedStream(&coded_stream);
}
} // namespace onnx
18 changes: 18 additions & 0 deletions onnx/py_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include <pybind11/pybind11.h>
#include "onnx/proto_utils.h"

namespace onnx {
namespace py = pybind11;

template <typename Proto>
bool ParseProtoFromPyBytes(Proto* proto, const py::bytes& bytes) {
// Get the buffer from Python bytes object
char* buffer = nullptr;
Py_ssize_t length;
PyBytes_AsStringAndSize(bytes.ptr(), &buffer, &length);

ParseProtoFromBytes(proto, buffer, length);
}
} // namespace onnx

0 comments on commit 3d4f101

Please sign in to comment.