Skip to content

Commit

Permalink
Protobuf typing (onnx#982)
Browse files Browse the repository at this point in the history
Add type stubs for google.protobuf

Don't add "type: ignore" to google.protobuf imports

Make the files generated by protoc-gen-mypy pass in mypy strict mode

Stop ignoring type errors in the generated files
  • Loading branch information
smessmer authored May 17, 2018
1 parent 321d874 commit ba86ec2
Show file tree
Hide file tree
Showing 23 changed files with 466 additions and 29 deletions.
4 changes: 2 additions & 2 deletions onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import onnx.checker # noqa
import onnx.defs # noqa

import google.protobuf.message # type: ignore
import google.protobuf.message

from typing import Union, Text, IO, Optional, cast, TypeVar, Any

Expand Down Expand Up @@ -49,7 +49,7 @@ def _serialize(proto): # type: (Union[bytes, google.protobuf.message.Message])
if isinstance(proto, bytes):
return proto
elif hasattr(proto, 'SerializeToString') and callable(proto.SerializeToString):
result = proto.SerializeToString() # type: bytes
result = proto.SerializeToString()
return result
else:
raise ValueError('No SerializeToString method is detected. '
Expand Down
2 changes: 1 addition & 1 deletion onnx/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
IR_VERSION)
import onnx.onnx_cpp2py_export.checker as C
import onnx.defs
from google.protobuf.message import Message # type: ignore
from google.protobuf.message import Message
from typing import TypeVar, Callable, Any, Type, cast


Expand Down
4 changes: 2 additions & 2 deletions onnx/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numbers
from six import text_type, integer_types, binary_type

import google.protobuf.message # type: ignore
import google.protobuf.message
from onnx import TensorProto, AttributeProto, ValueInfoProto, TensorShapeProto, \
NodeProto, ModelProto, GraphProto, OperatorSetIdProto, TypeProto, IR_VERSION
import onnx.defs as defs
Expand Down Expand Up @@ -217,7 +217,7 @@ def make_attribute(
attr.ints.extend(int(v) for v in value)
attr.type = AttributeProto.INTS
elif all(byte_array):
attr.strings.extend(byte_array)
attr.strings.extend(cast(List[bytes], byte_array))
attr.type = AttributeProto.STRINGS
elif all(isinstance(v, TensorProto) for v in value):
attr.tensors.extend(value)
Expand Down
4 changes: 1 addition & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ no_implicit_optional = True
disallow_untyped_decorators = True
warn_unused_configs = True

# Ignore errors in setup.py and in generated protobuf python files
# Ignore errors in setup.py
[mypy-setup]
ignore_errors = True
[mypy-onnx.onnx*_pb2.*]
ignore_errors = True
4 changes: 4 additions & 0 deletions stubs/google/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Type stubs taken from https://github.com/python/typeshed/tree/f4d19d9f612ae24dfe3e35770d7820c3bd4045d2/third_party/2/google
Which unfortunately only defines it for python 2, but we also need it for python 3.

This stub can be deleted once we updated to a mypy version including https://github.com/python/typeshed/pull/2140
Empty file added stubs/google/__init__.pyi
Empty file.
1 change: 1 addition & 0 deletions stubs/google/protobuf/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = ... # type: str
Empty file.
161 changes: 161 additions & 0 deletions stubs/google/protobuf/descriptor.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import Any

from .message import Message

class Error(Exception): ...
class TypeTransformationError(Error): ...

class DescriptorMetaclass(type):
def __instancecheck__(cls, obj): ...

class DescriptorBase:
__metaclass__ = DescriptorMetaclass
has_options = ... # type: Any
def __init__(self, options, options_class_name) -> None: ...
def GetOptions(self): ...

class _NestedDescriptorBase(DescriptorBase):
name = ... # type: Any
full_name = ... # type: Any
file = ... # type: Any
containing_type = ... # type: Any
def __init__(self, options, options_class_name, name, full_name, file, containing_type, serialized_start=..., serialized_end=...) -> None: ...
def GetTopLevelContainingType(self): ...
def CopyToProto(self, proto): ...

class Descriptor(_NestedDescriptorBase):
def __new__(cls, name, full_name, filename, containing_type, fields, nested_types, enum_types, extensions, options=..., is_extendable=..., extension_ranges=..., oneofs=..., file=..., serialized_start=..., serialized_end=..., syntax=...): ...
fields = ... # type: Any
fields_by_number = ... # type: Any
fields_by_name = ... # type: Any
nested_types = ... # type: Any
nested_types_by_name = ... # type: Any
enum_types = ... # type: Any
enum_types_by_name = ... # type: Any
enum_values_by_name = ... # type: Any
extensions = ... # type: Any
extensions_by_name = ... # type: Any
is_extendable = ... # type: Any
extension_ranges = ... # type: Any
oneofs = ... # type: Any
oneofs_by_name = ... # type: Any
syntax = ... # type: Any
def __init__(self, name, full_name, filename, containing_type, fields, nested_types, enum_types, extensions, options=..., is_extendable=..., extension_ranges=..., oneofs=..., file=..., serialized_start=..., serialized_end=..., syntax=...) -> None: ...
def EnumValueName(self, enum, value): ...
def CopyToProto(self, proto): ...

class FieldDescriptor(DescriptorBase):
TYPE_DOUBLE = ... # type: Any
TYPE_FLOAT = ... # type: Any
TYPE_INT64 = ... # type: Any
TYPE_UINT64 = ... # type: Any
TYPE_INT32 = ... # type: Any
TYPE_FIXED64 = ... # type: Any
TYPE_FIXED32 = ... # type: Any
TYPE_BOOL = ... # type: Any
TYPE_STRING = ... # type: Any
TYPE_GROUP = ... # type: Any
TYPE_MESSAGE = ... # type: Any
TYPE_BYTES = ... # type: Any
TYPE_UINT32 = ... # type: Any
TYPE_ENUM = ... # type: Any
TYPE_SFIXED32 = ... # type: Any
TYPE_SFIXED64 = ... # type: Any
TYPE_SINT32 = ... # type: Any
TYPE_SINT64 = ... # type: Any
MAX_TYPE = ... # type: Any
CPPTYPE_INT32 = ... # type: Any
CPPTYPE_INT64 = ... # type: Any
CPPTYPE_UINT32 = ... # type: Any
CPPTYPE_UINT64 = ... # type: Any
CPPTYPE_DOUBLE = ... # type: Any
CPPTYPE_FLOAT = ... # type: Any
CPPTYPE_BOOL = ... # type: Any
CPPTYPE_ENUM = ... # type: Any
CPPTYPE_STRING = ... # type: Any
CPPTYPE_MESSAGE = ... # type: Any
MAX_CPPTYPE = ... # type: Any
LABEL_OPTIONAL = ... # type: Any
LABEL_REQUIRED = ... # type: Any
LABEL_REPEATED = ... # type: Any
MAX_LABEL = ... # type: Any
MAX_FIELD_NUMBER = ... # type: Any
FIRST_RESERVED_FIELD_NUMBER = ... # type: Any
LAST_RESERVED_FIELD_NUMBER = ... # type: Any
def __new__(cls, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=..., file=..., has_default_value=..., containing_oneof=...): ...
name = ... # type: Any
full_name = ... # type: Any
index = ... # type: Any
number = ... # type: Any
type = ... # type: Any
cpp_type = ... # type: Any
label = ... # type: Any
has_default_value = ... # type: Any
default_value = ... # type: Any
containing_type = ... # type: Any
message_type = ... # type: Any
enum_type = ... # type: Any
is_extension = ... # type: Any
extension_scope = ... # type: Any
containing_oneof = ... # type: Any
def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=..., file=..., has_default_value=..., containing_oneof=...) -> None: ...
@staticmethod
def ProtoTypeToCppProtoType(proto_type): ...

class EnumDescriptor(_NestedDescriptorBase):
def __new__(cls, name, full_name, filename, values, containing_type=..., options=..., file=..., serialized_start=..., serialized_end=...): ...
values = ... # type: Any
values_by_name = ... # type: Any
values_by_number = ... # type: Any
def __init__(self, name, full_name, filename, values, containing_type=..., options=..., file=..., serialized_start=..., serialized_end=...) -> None: ...
def CopyToProto(self, proto): ...

class EnumValueDescriptor(DescriptorBase):
def __new__(cls, name, index, number, type=..., options=...): ...
name = ... # type: Any
index = ... # type: Any
number = ... # type: Any
type = ... # type: Any
def __init__(self, name, index, number, type=..., options=...) -> None: ...

class OneofDescriptor:
def __new__(cls, name, full_name, index, containing_type, fields): ...
name = ... # type: Any
full_name = ... # type: Any
index = ... # type: Any
containing_type = ... # type: Any
fields = ... # type: Any
def __init__(self, name, full_name, index, containing_type, fields) -> None: ...

class ServiceDescriptor(_NestedDescriptorBase):
index = ... # type: Any
methods = ... # type: Any
def __init__(self, name, full_name, index, methods, options=..., file=..., serialized_start=..., serialized_end=...) -> None: ...
def FindMethodByName(self, name): ...
def CopyToProto(self, proto): ...

class MethodDescriptor(DescriptorBase):
name = ... # type: Any
full_name = ... # type: Any
index = ... # type: Any
containing_service = ... # type: Any
input_type = ... # type: Any
output_type = ... # type: Any
def __init__(self, name, full_name, index, containing_service, input_type, output_type, options=...) -> None: ...

class FileDescriptor(DescriptorBase):
def __new__(cls, name, package, options=..., serialized_pb=..., dependencies=..., syntax=...): ...
_options = ... # type: Any
message_types_by_name = ... # type: Any
name = ... # type: Any
package = ... # type: Any
syntax = ... # type: Any
serialized_pb = ... # type: Any
enum_types_by_name = ... # type: Any
extensions_by_name = ... # type: Any
dependencies = ... # type: Any
def __init__(self, name, package, options=..., serialized_pb=..., dependencies=..., syntax=...) -> None: ...
def CopyToProto(self, proto): ...

def MakeDescriptor(desc_proto, package=..., build_file_if_cpp=..., syntax=...): ...
def _ParseOptions(message: Message, string: str) -> Message: ...
18 changes: 18 additions & 0 deletions stubs/google/protobuf/descriptor_pool.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any, Optional

class DescriptorPool:
def __new__(cls, descriptor_db: Optional[Any] = ...): ...
def __init__(self, descriptor_db: Optional[Any] = ...) -> None: ...
def Add(self, file_desc_proto): ...
def AddSerializedFile(self, serialized_file_desc_proto): ...
def AddDescriptor(self, desc): ...
def AddEnumDescriptor(self, enum_desc): ...
def AddFileDescriptor(self, file_desc): ...
def FindFileByName(self, file_name): ...
def FindFileContainingSymbol(self, symbol): ...
def FindMessageTypeByName(self, full_name): ...
def FindEnumTypeByName(self, full_name): ...
def FindFieldByName(self, full_name): ...
def FindExtensionByName(self, full_name): ...

def Default(): ...
Empty file.
35 changes: 35 additions & 0 deletions stubs/google/protobuf/internal/containers.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from google.protobuf.descriptor import Descriptor
from google.protobuf.internal.message_listener import MessageListener
from google.protobuf.message import Message
from typing import (
MutableSequence, Sequence, TypeVar, Generic, Any, Iterator, Iterable,
Union, Optional, Callable
)

_T = TypeVar('_T')
class BaseContainer(Generic[_T], MutableSequence[_T]):
def __init__(self, message_listener: MessageListener) -> None: ...
def __len__(self) -> int: ...
def __ne__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
def __repr__(self) -> str: ...
def sort(self, *, key: Optional[Callable[[_T], Any]] = ..., reverse: bool = ...) -> None: ...

class RepeatedScalarFieldContainer(Generic[_T], BaseContainer[_T]):
def __init__(self, message_listener: MessageListener, message_descriptor: Descriptor) -> None: ...
def MergeFrom(self, other: RepeatedScalarFieldContainer[_T]) -> None: ...

class RepeatedCompositeFieldContainer(Generic[_T], BaseContainer[_T]):
def __init__(self, message_listener: MessageListener, type_checker: Any) -> None: ...
def add(self, **kwargs: Any) -> _T: ...
def MergeFrom(self, other: RepeatedCompositeFieldContainer[_T]) -> None: ...

# Classes not yet typed
class Mapping(Any):
pass
class MutableMapping(Mapping):
pass
class ScalarMap(MutableMapping):
pass
class MessageMap(MutableMapping):
pass
30 changes: 30 additions & 0 deletions stubs/google/protobuf/internal/decoder.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any

def ReadTag(buffer, pos): ...
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): ...

Int32Decoder = ... # type: Any
Int64Decoder = ... # type: Any
UInt32Decoder = ... # type: Any
UInt64Decoder = ... # type: Any
SInt32Decoder = ... # type: Any
SInt64Decoder = ... # type: Any
Fixed32Decoder = ... # type: Any
Fixed64Decoder = ... # type: Any
SFixed32Decoder = ... # type: Any
SFixed64Decoder = ... # type: Any
FloatDecoder = ... # type: Any
DoubleDecoder = ... # type: Any
BoolDecoder = ... # type: Any

def StringDecoder(field_number, is_repeated, is_packed, key, new_default): ...
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): ...
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): ...
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): ...

MESSAGE_SET_ITEM_TAG = ... # type: Any

def MessageSetItemDecoder(extensions_by_number): ...
def MapDecoder(field_descriptor, new_default, is_message_map): ...

SkipField = ... # type: Any
34 changes: 34 additions & 0 deletions stubs/google/protobuf/internal/encoder.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any

Int32Sizer = ... # type: Any
UInt32Sizer = ... # type: Any
SInt32Sizer = ... # type: Any
Fixed32Sizer = ... # type: Any
Fixed64Sizer = ... # type: Any
BoolSizer = ... # type: Any

def StringSizer(field_number, is_repeated, is_packed): ...
def BytesSizer(field_number, is_repeated, is_packed): ...
def GroupSizer(field_number, is_repeated, is_packed): ...
def MessageSizer(field_number, is_repeated, is_packed): ...
def MessageSetItemSizer(field_number): ...
def MapSizer(field_descriptor): ...
def TagBytes(field_number, wire_type): ...

Int32Encoder = ... # type: Any
UInt32Encoder = ... # type: Any
SInt32Encoder = ... # type: Any
Fixed32Encoder = ... # type: Any
Fixed64Encoder = ... # type: Any
SFixed32Encoder = ... # type: Any
SFixed64Encoder = ... # type: Any
FloatEncoder = ... # type: Any
DoubleEncoder = ... # type: Any

def BoolEncoder(field_number, is_repeated, is_packed): ...
def StringEncoder(field_number, is_repeated, is_packed): ...
def BytesEncoder(field_number, is_repeated, is_packed): ...
def GroupEncoder(field_number, is_repeated, is_packed): ...
def MessageEncoder(field_number, is_repeated, is_packed): ...
def MessageSetItemEncoder(field_number): ...
def MapEncoder(field_descriptor): ...
11 changes: 11 additions & 0 deletions stubs/google/protobuf/internal/enum_type_wrapper.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Any, List, Tuple

class EnumTypeWrapper(object):
def __init__(self, enum_type: Any) -> None: ...
def Name(self, number: int) -> str: ...
def Value(self, name: str) -> int: ...
def keys(self) -> List[str]: ...
def values(self) -> List[int]: ...

@classmethod
def items(cls) -> List[Tuple[str, int]]: ...
5 changes: 5 additions & 0 deletions stubs/google/protobuf/internal/message_listener.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class MessageListener(object):
def Modified(self) -> None: ...

class NullMessageListener(MessageListener):
def Modified(self) -> None: ...
Loading

0 comments on commit ba86ec2

Please sign in to comment.