Skip to content

Commit

Permalink
Update protobuf of paddle model (PaddlePaddle#1840)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangjiajun authored Apr 21, 2023
1 parent 51be3fe commit 732d8e8
Showing 1 changed file with 60 additions and 37 deletions.
97 changes: 60 additions & 37 deletions paddle2onnx/proto/p2o_paddle.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ package paddle2onnx.framework.proto;
//
// Serailization and Deserialization codes should be modified in a way
// that supports old versions following the version and compatibility policy.
message Version {
optional int64 version = 1 [default = 0];
}
message Version { optional int64 version = 1 [ default = 0 ]; }

enum AttrType {
INT = 0;
Expand All @@ -40,18 +38,36 @@ enum AttrType {
FLOAT64S = 12;
VAR = 13;
VARS = 14;
FLOAT64 = 15;
SCALAR = 16;
SCALARS = 17;
}

message ProcessMeshDesc {
required int32 id = 1;
required int32 parent_id = 2;
repeated int32 topology = 3;
repeated int32 process_group = 4;

message Complex {
required double r = 1;
required double i = 2;
};

message Scalar {
enum Type {
BOOLEAN = 1;
LONG = 2;
FLOAT64 = 3;
COMPLEX128 = 4;
}
required Type type = 1;

optional bool b = 2;
optional int64 i = 3;
optional double r = 4;
optional Complex c = 5;
};

// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {

message Attr {
required string name = 1;
required AttrType type = 2;
Expand All @@ -70,6 +86,9 @@ message OpDesc {
repeated double float64s = 16;
optional string var_name = 17;
repeated string vars_name = 18;
optional double float64 = 19;
optional Scalar scalar = 20;
repeated Scalar scalars = 21;
};

message Var {
Expand All @@ -81,21 +100,22 @@ message OpDesc {
repeated Var inputs = 1;
repeated Var outputs = 2;
repeated Attr attrs = 4;
optional bool is_target = 5 [default = false];
optional bool is_target = 5 [ default = false ];
};

// OpProto describes a C++ framework::OperatorBase derived class.
message OpProto {

// VarProto describes the C++ type framework::Variable.
message Var {
required string name = 1;
required string comment = 2;

optional bool duplicable = 3 [default = false];
optional bool intermediate = 4 [default = false];
optional bool dispensable = 5 [default = false];
optional bool extra = 6 [default = false];
optional bool quant = 7 [default = false];
optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ];
optional bool dispensable = 5 [ default = false ];
optional bool extra = 6 [ default = false ];
optional bool quant = 7 [ default = false ];
}

// AttrProto describes the C++ type Attribute.
Expand All @@ -106,10 +126,10 @@ message OpProto {
// If that attribute is generated, it means the Paddle third
// language binding has responsibility to fill that
// attribute. End-User should not set that attribute.
optional bool generated = 4 [default = false];
optional bool extra = 5 [default = false];
optional bool quant = 6 [default = false];
optional bool support_tensor = 7 [default = false];
optional bool generated = 4 [ default = false ];
optional bool extra = 5 [ default = false ];
optional bool quant = 6 [ default = false ];
optional bool support_tensor = 7 [ default = false];
}

required string type = 1;
Expand All @@ -129,7 +149,7 @@ message VarType {
FP16 = 4;
FP32 = 5;
FP64 = 6;
// Tensor<size_t> is used in C++.
// phi::DenseTensor<size_t> is used in C++.
SIZE_T = 19;
UINT8 = 20;
INT8 = 21;
Expand Down Expand Up @@ -157,45 +177,50 @@ message VarType {
STRINGS = 26;
VOCAB = 27;
FEED_LIST = 28;
// The data type of phi::StringTensor
PSTRING = 29;
// the data type of phi::SparseCooTensor
SPARSE_COO = 30;
// the data type of phi::SparseCsrTensor
SPARSE_CSR = 31;
}

required Type type = 1;

message TensorDesc {
// Should only be PODType. Is enforced in C++
required Type data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
optional TensorDesc selected_rows = 2;

message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [default = 0];
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorDesc lod_tensor = 3;

message LoDTensorArrayDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [default = 0];
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorArrayDesc tensor_array = 4;

message ReaderDesc {
repeated LoDTensorDesc lod_tensor = 1;
}
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
optional ReaderDesc reader = 5;

message Tuple {
repeated Type element_type = 1;
}
message Tuple { repeated Type element_type = 1; }
optional Tuple tuple = 7;

optional TensorDesc string = 8;
optional TensorDesc strings = 9;
optional TensorDesc vocab = 10;
optional TensorDesc sparse_coo = 11;
optional TensorDesc sparse_csr = 12;
}

message VarDesc {

message Attr {
required string name = 1;
required AttrType type = 2;
Expand All @@ -206,12 +231,12 @@ message VarDesc {

required string name = 1;
required VarType type = 2;
optional bool persistable = 3 [default = false];
optional bool persistable = 3 [ default = false ];
// True if the variable is an input data and
// have to check the feed data shape and dtype
optional bool need_check_feed = 4 [default = false];
optional bool is_parameter = 5 [default = false];
optional bool stop_gradient = 6 [default = false];
optional bool need_check_feed = 4 [ default = false ];
optional bool is_parameter = 5 [ default = false ];
optional bool stop_gradient = 6 [ default = false ];
repeated Attr attrs = 7;
}

Expand All @@ -220,14 +245,12 @@ message BlockDesc {
required int32 parent_idx = 2;
repeated VarDesc vars = 3;
repeated OpDesc ops = 4;
optional int32 forward_block_idx = 5 [default = -1];
optional int32 forward_block_idx = 5 [ default = -1 ];
}

// In some cases, Paddle may perform operator definition iterations,
// and the operator uses OpVersionMap for compatibility testing.
message OpVersion {
required int32 version = 1;
}
message OpVersion { required int32 version = 1; }
message OpVersionMap {
message OpVersionPair {
required string op_name = 1;
Expand All @@ -242,7 +265,7 @@ message OpVersionMap {
// TODO(panyx0718): A model can have multiple programs. Need a
// way to distinguish them. Maybe ID or name?
message ProgramDesc {
reserved 2, 3; // For backward compatibility.
reserved 2, 3; // For backward compatibility.
repeated BlockDesc blocks = 1;
optional Version version = 4;
optional OpVersionMap op_version_map = 5;
Expand Down

0 comments on commit 732d8e8

Please sign in to comment.