Skip to content

Commit

Permalink
Replace hard-coded constants with Protobuf enums (onnx#4638)
Browse files Browse the repository at this point in the history
Changed hard-coded type constants in `PrimitiveTypeNameMap` or `AttributeTypeNameMap` into
enums generated by Protobuf, (e.g. 1 -> TensorProto_DataType_FLOAT)

Signed-off-by: Chanjung Kim <[email protected]>

Signed-off-by: Chanjung Kim <[email protected]>
Co-authored-by: G. Ramalingam <[email protected]>
  • Loading branch information
paxbun and gramalingam authored Nov 29, 2022
1 parent 3fc9ad3 commit 37ba77e
Showing 1 changed file with 31 additions and 31 deletions.
62 changes: 31 additions & 31 deletions onnx/defs/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,22 @@ class StringIntMap {
class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
public:
PrimitiveTypeNameMap() : StringIntMap() {
map_["float"] = 1;
map_["uint8"] = 2;
map_["int8"] = 3;
map_["uint16"] = 4;
map_["int16"] = 5;
map_["int32"] = 6;
map_["int64"] = 7;
map_["string"] = 8;
map_["bool"] = 9;
map_["float16"] = 10;
map_["double"] = 11;
map_["uint32"] = 12;
map_["uint64"] = 13;
map_["complex64"] = 14;
map_["complex128"] = 15;
map_["bfloat16"] = 16;
map_["float"] = TensorProto_DataType_FLOAT;
map_["uint8"] = TensorProto_DataType_UINT8;
map_["int8"] = TensorProto_DataType_INT8;
map_["uint16"] = TensorProto_DataType_UINT16;
map_["int16"] = TensorProto_DataType_INT16;
map_["int32"] = TensorProto_DataType_INT32;
map_["int64"] = TensorProto_DataType_INT64;
map_["string"] = TensorProto_DataType_STRING;
map_["bool"] = TensorProto_DataType_BOOL;
map_["float16"] = TensorProto_DataType_FLOAT16;
map_["double"] = TensorProto_DataType_DOUBLE;
map_["uint32"] = TensorProto_DataType_UINT32;
map_["uint64"] = TensorProto_DataType_UINT64;
map_["complex64"] = TensorProto_DataType_COMPLEX64;
map_["complex128"] = TensorProto_DataType_COMPLEX128;
map_["bfloat16"] = TensorProto_DataType_BFLOAT16;
}

static bool IsTypeName(const std::string& dtype) {
Expand All @@ -98,20 +98,20 @@ class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
class AttributeTypeNameMap : public StringIntMap<AttributeTypeNameMap> {
public:
AttributeTypeNameMap() : StringIntMap() {
map_["float"] = 1;
map_["int"] = 2;
map_["string"] = 3;
map_["tensor"] = 4;
map_["graph"] = 5;
map_["sparse_tensor"] = 11;
map_["type_proto"] = 13;
map_["floats"] = 6;
map_["ints"] = 7;
map_["strings"] = 8;
map_["tensors"] = 9;
map_["graphs"] = 10;
map_["sparse_tensors"] = 12;
map_["type_protos"] = 14;
map_["float"] = AttributeProto_AttributeType_FLOAT;
map_["int"] = AttributeProto_AttributeType_INT;
map_["string"] = AttributeProto_AttributeType_STRING;
map_["tensor"] = AttributeProto_AttributeType_TENSOR;
map_["graph"] = AttributeProto_AttributeType_GRAPH;
map_["sparse_tensor"] = AttributeProto_AttributeType_SPARSE_TENSOR;
map_["type_proto"] = AttributeProto_AttributeType_TYPE_PROTO;
map_["floats"] = AttributeProto_AttributeType_FLOATS;
map_["ints"] = AttributeProto_AttributeType_INTS;
map_["strings"] = AttributeProto_AttributeType_STRINGS;
map_["tensors"] = AttributeProto_AttributeType_TENSORS;
map_["graphs"] = AttributeProto_AttributeType_GRAPHS;
map_["sparse_tensors"] = AttributeProto_AttributeType_SPARSE_TENSORS;
map_["type_protos"] = AttributeProto_AttributeType_TYPE_PROTOS;
}
};

Expand Down Expand Up @@ -428,4 +428,4 @@ class OnnxParser : public ParserBase {
bool NextIsType();
};

} // namespace ONNX_NAMESPACE
} // namespace ONNX_NAMESPACE

0 comments on commit 37ba77e

Please sign in to comment.