Skip to content

Commit

Permalink
Optimize protocol read interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
David Hull committed Jun 7, 2016
1 parent 6c9b0d7 commit 5dedd22
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 68 deletions.
55 changes: 28 additions & 27 deletions src/ox_thrift_protocol.hrl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ encode_struct ([], _Record, _I) ->
%% `MessageType' is `?tMessageType_CALL', `?tMessageType_ONEWAY', `?tMessageReply', or `?tMessageException'.
decode_message (ServiceModule, Buffer0) ->
{Buffer1, #protocol_message_begin{name=FunctionBin, type=ThriftMessageType, seqid=SeqId}} =
read(Buffer0, message_begin),
read_message_begin(Buffer0),
Function = binary_to_atom(FunctionBin, latin1),
case ThriftMessageType of
?tMessageType_CALL ->
Expand All @@ -239,7 +239,7 @@ decode_message (ServiceModule, Buffer0) ->
MessageSpec = ?tApplicationException_Structure,
{Buffer2, Args} = decode_record(Buffer1, application_exception, MessageSpec)
end,
{<<>>, ok} = read(Buffer2, message_end),
<<>> = read_message_end(Buffer2),
%% io:format(standard_error, "decode\nspec ~p\nargs ~p\n", [ MessageSpec, Args ]),
{Function, MessageType, SeqId, Args}.

Expand Down Expand Up @@ -267,25 +267,25 @@ decode (Buffer, {struct, {Schema, StructName}})
decode_record(Buffer, StructName, Schema:struct_info(StructName));

decode (Buffer0, _T={list, Type}) ->
{Buffer1, #protocol_list_begin{etype=EType, size=Size}} = read(Buffer0, list_begin),
{Buffer1, #protocol_list_begin{etype=EType, size=Size}} = read_list_begin(Buffer0),
?VALIDATE_TYPE(Type, EType, [ Buffer0, _T ]),
{Buffer2, List} = decode_list(Buffer1, Type, [], Size),
{Buffer3, ok} = read(Buffer2, list_end),
Buffer3 = read_list_end(Buffer2),
{Buffer3, List};

decode (Buffer0, _T={map, KeyType, ValType}) ->
{Buffer1, #protocol_map_begin{ktype=KType, vtype=VType, size=Size}} = read(Buffer0, map_begin),
{Buffer1, #protocol_map_begin{ktype=KType, vtype=VType, size=Size}} = read_map_begin(Buffer0),
?VALIDATE_TYPE(KeyType, KType, [ Buffer0, _T ]),
?VALIDATE_TYPE(ValType, VType, [ Buffer0, _T ]),
{Buffer2, List} = decode_map(Buffer1, {KeyType, ValType}, [], Size),
{Buffer3, ok} = read(Buffer2, map_end),
Buffer3 = read_map_end(Buffer2),
{Buffer3, dict:from_list(List)};

decode (Buffer0, _T={set, Type}) ->
{Buffer1, #protocol_set_begin{etype=EType, size=Size}} = read(Buffer0, set_begin),
{Buffer1, #protocol_set_begin{etype=EType, size=Size}} = read_set_begin(Buffer0),
?VALIDATE_TYPE(Type, EType, [ Buffer0, _T ]),
{Buffer2, List} = decode_set(Buffer1, Type, [], Size),
{Buffer3, ok} = read(Buffer2, set_end),
Buffer3 = read_set_end(Buffer2),
{Buffer3, sets:from_list(List)};

decode (Buffer0, Type) when is_atom(Type) ->
Expand All @@ -297,19 +297,19 @@ decode (Buffer0, Type) when is_atom(Type) ->
decode_record (Buffer0, Name, {struct, StructDef})
when is_atom(Name), is_list(StructDef) ->
%% Decode a record from a struct definition.
{Buffer1, ok} = read(Buffer0, struct_begin),
Buffer1 = read_struct_begin(Buffer0),
%% If we were going to handle field defaults we could create the initialize
%% here. It might be better to wait until after the struct is parsed,
%% however, to avoid unnecessarily creating initializers for fields that
%% don't need them. @@
{Buffer2, Record} = decode_struct(Buffer1, StructDef, [ {1, Name} ]),
{Buffer3, ok} = read(Buffer2, struct_end),
Buffer3 = read_struct_end(Buffer2),
{Buffer3, Record}.


-spec decode_struct(BufferIn::binary(), FieldList::list(), Acc::list()) -> {binary(), tuple()}.
decode_struct (Buffer0, FieldList, Acc) ->
{Buffer1, #protocol_field_begin{type=FieldTId, id=FieldId}} = read(Buffer0, field_begin),
{Buffer1, #protocol_field_begin{type=FieldTId, id=FieldId}} = read_field_begin(Buffer0),
case FieldTId of
field_stop ->
Record = erlang:make_tuple(length(FieldList)+1, undefined, Acc),
Expand All @@ -319,12 +319,12 @@ decode_struct (Buffer0, FieldList, Acc) ->
{FieldTypeAtom, N} ->
?VALIDATE_TYPE(FieldTypeAtom, FieldTId, [ Buffer0, FieldList, Acc ]),
{Buffer2, Val} = decode(Buffer1, FieldTypeAtom),
{Buffer3, ok} = read(Buffer2, field_end),
Buffer3 = read_field_end(Buffer2),
decode_struct(Buffer3, FieldList, [ {N, Val} | Acc ]);
false ->
%% io:format("field ~p not found in ~p\n", [ FieldId, FieldList ]),
{Buffer2, _} = skip(Buffer1, FieldTId),
{Buffer3, ok} = read(Buffer2, field_end),
Buffer2 = skip(Buffer1, FieldTId),
Buffer3 = read_field_end(Buffer2),
decode_struct(Buffer3, FieldList, Acc)
end
end.
Expand Down Expand Up @@ -354,51 +354,52 @@ decode_set (Buffer0, EType, Acc, N) ->

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

-spec skip(Buffer0::binary(), Type::atom()) -> {Buffer1::binary(), Dummy::term()}.
-spec skip(Buffer0::binary(), Type::atom()) -> Buffer1::binary().
skip (Buffer0, struct) ->
{Buffer1, _} = read(Buffer0, struct_begin),
Buffer1 = read_struct_begin(Buffer0),
Buffer2 = skip_struct(Buffer1),
read(Buffer2, struct_end);
read_struct_end(Buffer2);

skip (Buffer0, list) ->
{Buffer1, #protocol_list_begin{etype=EType, size=Size}} = read(Buffer0, list_begin),
{Buffer1, #protocol_list_begin{etype=EType, size=Size}} = read_list_begin(Buffer0),
Buffer2 = foldn(fun (BufferL0) ->
{BufferL1, _} = decode(BufferL0, EType),
BufferL1
end, Buffer1, Size),
read(Buffer2, list_end);
read_list_end(Buffer2);

skip (Buffer0, map) ->
{Buffer1, #protocol_map_begin{ktype=KType, vtype=VType, size=Size}} = read(Buffer0, map_begin),
{Buffer1, #protocol_map_begin{ktype=KType, vtype=VType, size=Size}} = read_map_begin(Buffer0),
Buffer2 = foldn(fun (BufferL0) ->
{BufferL1, _} = decode(BufferL0, KType),
{BufferL2, _} = decode(BufferL1, VType),
BufferL2
end, Buffer1, Size),
read(Buffer2, map_end);
read_map_end(Buffer2);

skip (Buffer0, set) ->
{Buffer1, #protocol_set_begin{etype=EType, size=Size}} = read(Buffer0, set_begin),
{Buffer1, #protocol_set_begin{etype=EType, size=Size}} = read_set_begin(Buffer0),
Buffer2 = foldn(fun (BufferL0) ->
{BufferL1, _} = decode(BufferL0, EType),
BufferL1
end, Buffer1, Size),
read(Buffer2, set_end);
read_set_end(Buffer2);

skip (Buffer0, Type) when is_atom(Type) ->
%% Skip the basic types.
read(Buffer0, Type).
{Buffer, _Value} = read(Buffer0, Type),
Buffer.


-spec skip_struct (Buffer0::binary()) -> Buffer1::binary().
skip_struct (Buffer0) ->
{Buffer1, #protocol_field_begin{type=Type}} = read(Buffer0, field_begin),
{Buffer1, #protocol_field_begin{type=Type}} = read_field_begin(Buffer0),
case Type of
field_stop ->
Buffer1;
_ ->
{Buffer2, _} = skip(Buffer1, Type),
{Buffer3, ok} = read(Buffer2, field_end),
Buffer2 = skip(Buffer1, Type),
Buffer3 = read_field_end(Buffer2),
skip_struct(Buffer3)
end.

Expand Down
90 changes: 49 additions & 41 deletions src/ox_thrift_protocol_binary.erl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@
, write_struct_begin/1
, write_struct_end/0
, write/2
, read_message_begin/1
, read_message_end/1
, read_struct_begin/1
, read_struct_end/1
, read_field_begin/1
, read_field_end/1
, read_map_begin/1
, read_map_end/1
, read_list_begin/1
, read_list_end/1
, read_set_begin/1
, read_set_end/1
, read/2 ]}).
%% -compile(inline_list_funcs).

Expand Down Expand Up @@ -150,7 +162,7 @@ read(DataIn, Type) ->
-endif. %% ! DEBUG_READ


?READ (Data0, message_begin) ->
read_message_begin (Data0) ->
Version = binary_part(Data0, {0, 2}),
case Version of
?VERSION_1 ->
Expand All @@ -163,27 +175,27 @@ read(DataIn, Type) ->
_ ->
%% Unexpected version number.
error({bad_binary_protocol_version, Version})
end;
end.

?READ (Data, message_end) ->
{Data, ok};
read_message_end (Data) ->
Data.

?READ (Data, struct_begin) ->
{Data, ok};
read_struct_begin (Data) ->
Data.

?READ (Data, struct_end) ->
{Data, ok};
read_struct_end (Data) ->
Data.

?READ (Data0, field_begin) ->
read_field_begin (Data0) ->
case Data0 of
<<?tType_STOP:8/big-signed, Data1/binary>> ->
{Data1, #protocol_field_begin{type = field_stop}};
<<Type:8/big-signed, Id:16/big-signed, Data1/binary>> ->
{Data1, #protocol_field_begin{type = wire_to_term(Type), id = Id}}
end;
end.

?READ (Data, field_end) ->
{Data, ok};
read_field_end (Data) ->
Data.

%% This isn't necessary, since we never explicitly read a `field_stop', we
%% just find it when trying to read a `field_begin'.
Expand All @@ -192,26 +204,27 @@ read(DataIn, Type) ->
%% {?tType_STOP, Data1} = read(?tType_BYTE, Data0),
%% {ok, Data1};

?READ (Data0, map_begin) ->
read_map_begin (Data0) ->
<<KType:8/big-signed, VType:8/big-signed, Size:32/big-signed, Data1/binary>> = Data0,
{Data1, #protocol_map_begin{ktype = wire_to_term(KType), vtype = wire_to_term(VType), size = Size}};
{Data1, #protocol_map_begin{ktype = wire_to_term(KType), vtype = wire_to_term(VType), size = Size}}.

?READ (Data, map_end) ->
{Data, ok};
read_map_end (Data) ->
Data.

?READ (Data0, list_begin) ->
read_list_begin (Data0) ->
<<EType:8/big-signed, Size:32/big-signed, Data1/binary>> = Data0,
{Data1, #protocol_list_begin{etype = wire_to_term(EType), size = Size}};
{Data1, #protocol_list_begin{etype = wire_to_term(EType), size = Size}}.

?READ (Data, list_end) ->
{Data, ok};
read_list_end (Data) ->
Data.

?READ (Data0, set_begin) ->
read_set_begin (Data0) ->
<<EType:8/big-signed, Size:32/big-signed, Data1/binary>> = Data0,
{Data1, #protocol_set_begin{etype = wire_to_term(EType), size = Size}};
{Data1, #protocol_set_begin{etype = wire_to_term(EType), size = Size}}.

read_set_end (Data) ->
Data.

?READ (Data, set_end) ->
{Data, ok};

?READ (Data0, bool) ->
<<Bool:8/big-signed, Data1/binary>> = Data0,
Expand Down Expand Up @@ -239,12 +252,7 @@ read(DataIn, Type) ->

?READ (Data0, string) ->
<<Size:32/big-signed, String:Size/binary, Data1/binary>> = Data0,
{Data1, String};

?READ (Data0, ui32) ->
%% Used for reading the message version header.
<<Val:32/integer-unsigned-big, Data1/binary>> = Data0,
{Data1, Val}.
{Data1, String}.


%% -spec read_data(DataIn::binary(), Size::non_neg_integer()) -> {DataOut::binary(), Value::binary()}.
Expand All @@ -269,48 +277,48 @@ message_test () ->
Type = ?tMessageType_CALL,
SeqId = 16#7FFFFFF0,
P = #protocol_message_begin{name = Name, type = Type, seqid = SeqId},
?assertEqual({<<>>, P}, read(iolist_to_binary(write_message_begin(Name, Type, SeqId)), message_begin)),
?assertEqual({<<>>, P}, read_message_begin(iolist_to_binary(write_message_begin(Name, Type, SeqId)))),

%% New-style message header.
?assertEqual({<<>>, P}, read(<<?VERSION_1/binary, 0, Type, NameLen:32/big, Name/binary, SeqId:32/big>>, message_begin)),
?assertEqual({<<>>, P}, read_message_begin(<<?VERSION_1/binary, 0, Type, NameLen:32/big, Name/binary, SeqId:32/big>>)),

%% Old-style message header.
?assertEqual({<<>>, P}, read(<<NameLen:32/big, Name/binary, Type, SeqId:32/big>>, message_begin)).
?assertEqual({<<>>, P}, read_message_begin(<<NameLen:32/big, Name/binary, Type, SeqId:32/big>>)).

field_test () ->
Name = <<"field">>,
Type = i32,
Id = 16#7FF0,
%% Name is not sent in binary protocol.
P = #protocol_field_begin{name = undefined, type = Type, id = Id},
?assertEqual({<<>>, P}, read(iolist_to_binary(write_field_begin(Name, Type, Id)), field_begin)),
?assertEqual({<<>>, P}, read_field_begin(iolist_to_binary(write_field_begin(Name, Type, Id)))),

?assertEqual({<<>>, ok}, read(iolist_to_binary(write_field_end()), field_end)).
?assertEqual(<<>>, read_field_end(iolist_to_binary(write_field_end()))).

map_test () ->
KType = byte,
VType = string,
Size = 16#7FFFFFF1,
P = #protocol_map_begin{ktype = KType, vtype = VType, size = Size},
?assertEqual({<<>>, P}, read(iolist_to_binary(write_map_begin(KType, VType, Size)), map_begin)),
?assertEqual({<<>>, P}, read_map_begin(iolist_to_binary(write_map_begin(KType, VType, Size)))),

?assertEqual({<<>>, ok}, read(iolist_to_binary(write_map_end()), map_end)).
?assertEqual(<<>>, read_map_end(iolist_to_binary(write_map_end()))).

list_test () ->
EType = byte,
Size = 16#7FFFFFF2,
P = #protocol_list_begin{etype = EType, size = Size},
?assertEqual({<<>>, P}, read(iolist_to_binary(write_list_begin(EType, Size)), list_begin)),
?assertEqual({<<>>, P}, read_list_begin(iolist_to_binary(write_list_begin(EType, Size)))),

?assertEqual({<<>>, ok}, read(iolist_to_binary(write_list_end()), list_end)).
?assertEqual(<<>>, read_list_end(iolist_to_binary(write_list_end()))).

set_test () ->
EType = byte,
Size = 16#7FFFFFF3,
P = #protocol_set_begin{etype = EType, size = Size},
?assertEqual({<<>>, P}, read(iolist_to_binary(write_set_begin(EType, Size)), set_begin)),
?assertEqual({<<>>, P}, read_set_begin(iolist_to_binary(write_set_begin(EType, Size)))),

?assertEqual({<<>>, ok}, read(iolist_to_binary(write_set_end()), set_end)).
?assertEqual(<<>>, read_set_end(iolist_to_binary(write_set_end()))).

basic_test () ->
B = 16#7F,
Expand Down

0 comments on commit 5dedd22

Please sign in to comment.