Skip to content

Commit

Permalink
Adding support for stream_management to escalus_ws
Browse files Browse the repository at this point in the history
  • Loading branch information
JanuszJakubiec committed Jun 4, 2024
1 parent 6f6e688 commit ca24321
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 58 deletions.
8 changes: 4 additions & 4 deletions src/escalus_tcp.erl
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ separate_ack_requests({true, H0, inactive}, Stanzas) ->
Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)],
Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)],

case {length(Enabled),length(Resumed)} of
case {length(Enabled), length(Resumed)} of
%% Enabled SM: set the H param to 0 and activate counter.
{1,0} -> {{true, 0, active}, [], Stanzas};

Expand All @@ -450,12 +450,12 @@ separate_ack_requests({true, H0, active}, Stanzas) ->

make_ack(H) -> {escalus_stanza:sm_ack(H), H}.

reply_to_ack_requests({false,H,A}, _, _) -> {false, H, A};
reply_to_ack_requests({true,H,inactive}, _, _) -> {true, H, inactive};
reply_to_ack_requests({false, H, A}, _, _) -> {false, H, A};
reply_to_ack_requests({true, H, inactive}, _, _) -> {true, H, inactive};
reply_to_ack_requests({true, H0, active}, Acks, State) ->
{true,
% TODO: Maybe compress here?
lists:foldl(fun({Ack,H}, _) -> raw_send(exml:to_iolist(Ack), State), H end,
lists:foldl(fun({Ack, H}, _) -> raw_send(exml:to_iolist(Ack), State), H end,
H0, Acks),
active}.

Expand Down
198 changes: 144 additions & 54 deletions src/escalus_ws.erl
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
send/2,
is_connected/1,
reset_parser/1,
get_sm_h/1,
set_sm_h/2,
use_zlib/1,
upgrade_to_tls/2,
set_filter_predicate/2,
Expand All @@ -40,9 +42,24 @@
-define(SERVER, ?MODULE).

-record(state, {owner, socket, parser, legacy_ws, compress = false,
event_client, filter_pred, stream_ref}).
event_client, sm_state, filter_pred, stream_ref, sent_stanzas = []}).
-type state() :: #state{}.

-type sm_state() :: {boolean(), non_neg_integer(), 'active'|'inactive'}.

-type opts() :: #{
host => string(),
port => pos_integer(),
wspath => string(),
wslegacy => boolean(),
event_client => undefined | escalus_event:event_client(),
ssl => boolean(),
ssl_opts => [ssl:ssl_option()],
ws_upgrade_timeout => pos_integer(),
stream_management => boolean(),
manual_ack => boolean()
}.

%%%===================================================================
%%% API
%%%===================================================================
Expand All @@ -60,6 +77,15 @@ send(Pid, Elem) ->
is_connected(Pid) ->
erlang:is_process_alive(Pid).

-spec get_sm_h(pid()) -> non_neg_integer().
get_sm_h(Pid) ->
gen_server:call(Pid, get_sm_h).

-spec set_sm_h(pid(), non_neg_integer()) -> {ok, non_neg_integer()}.
set_sm_h(Pid, H) ->
gen_server:call(Pid, {set_sm_h, H}).


-spec reset_parser(pid()) -> ok.
reset_parser(Pid) ->
gen_server:cast(Pid, reset_parser).
Expand Down Expand Up @@ -155,31 +181,58 @@ assert_stream_end(StreamEndRep, Props) ->
%%% gen_server callbacks
%%%===================================================================

%% TODO: refactor all opt defaults taken from Args into a default_opts function,
%% so that we know what options the module actually expects
default_options() ->
#{host => "localhost",
port => 5280,
wspath => "/ws-xmpp",
wslegacy => false,
event_client => undefined,
ssl => false,
ssl_opts => [],
ws_upgrade_timeout => 5000,
stream_management => false,
manual_ack => false}.

-spec get_stream_management_opt(opts()) -> sm_state().
get_stream_management_opt(#{stream_management := false}) ->
{false, 0, inactive};
get_stream_management_opt(#{manual_ack := true}) ->
{false, 0, inactive};
get_stream_management_opt(#{stream_management := true, manual_ack := false}) ->
{true, 0, inactive}.

overwrite_default_opts(GivenOpts, DefaultOpts) ->
maps:merge(DefaultOpts, GivenOpts).

do_connect(#{ssl := true, ssl_opts := SSLOpts} = Opts) ->
TransportOpts = #{transport => tls, protocols => [http],
tls_opts => SSLOpts},
do_connect(Opts, TransportOpts);
do_connect(Opts) ->
do_connect(Opts, #{transport => tcp, protocols => [http]}).

do_connect(#{host := Host, port := Port}, TransportOpts) ->
Host1 = maybe_binary_to_list(Host),
{ok, ConnPid} = gun:open(Host1, Port, TransportOpts),
{ok, http} = gun:await_up(ConnPid),
ConnPid.

-spec init(list()) -> {ok, state()}.
init([Args, Owner]) ->
Host = get_host(Args, "localhost"),
Port = get_port(Args, 5280),
Resource = get_resource(Args, "/ws-xmpp"),
LegacyWS = get_legacy_ws(Args, false),
EventClient = proplists:get_value(event_client, Args),
SSL = proplists:get_value(ssl, Args, false),
SSLOpts = proplists:get_value(ssl_opts, Args, []),
init([Opts, Owner]) ->
Opts1 = overwrite_default_opts(maps:from_list(Opts), default_options()),
#{wspath := Resource,
wslegacy := LegacyWS,
event_client := EventClient,
ws_upgrade_timeout := Timeout} = Opts1,
SM = get_stream_management_opt(Opts1),
Resource1 = maybe_binary_to_list(Resource),

ConnPid = do_connect(Opts1),

%% Disable http2 in protocols
TransportOpts = case SSL of
true ->
#{transport => tls, protocols => [http],
tls_opts => SSLOpts};
_ ->
#{transport => tcp, protocols => [http]}
end,
{ok, ConnPid} = gun:open(Host, Port, TransportOpts),
{ok, http} = gun:await_up(ConnPid),
WSUpgradeHeaders = [{<<"sec-websocket-protocol">>, <<"xmpp">>}],
StreamRef = gun:ws_upgrade(ConnPid, Resource, WSUpgradeHeaders,
StreamRef = gun:ws_upgrade(ConnPid, Resource1, WSUpgradeHeaders,
#{protocols => [{<<"xmpp">>, gun_ws_h}]}),
Timeout = get_option(ws_upgrade_timeout, Args, 5000),
wait_for_ws_upgrade(ConnPid, StreamRef, Timeout),
ParserOpts = case LegacyWS of
true -> [];
Expand All @@ -190,6 +243,7 @@ init([Args, Owner]) ->
socket = ConnPid,
parser = Parser,
legacy_ws = LegacyWS,
sm_state = SM,
event_client = EventClient,
stream_ref = StreamRef}}.

Expand All @@ -208,6 +262,11 @@ wait_for_ws_upgrade(ConnPid, StreamRef, Timeout) ->

-spec handle_call(term(), {pid(), term()}, state()) ->
{reply, term(), state()} | {stop, normal, ok, state()}.
handle_call(get_sm_h, _From, #state{sm_state = {_, H, _}} = State) ->
{reply, H, State};
handle_call({set_sm_h, H}, _From, #state{sm_state = {A, _OldH, S}} = State) ->
NewState = State#state{sm_state={A, H, S}},
{reply, {ok, H}, NewState};
handle_call(use_zlib, _, #state{parser = Parser} = State) ->
Zin = zlib:open(),
Zout = zlib:open(),
Expand Down Expand Up @@ -270,48 +329,81 @@ handle_data(Data, State = #state{parser = Parser,
Decompressed = iolist_to_binary(zlib:inflate(Zin, Data)),
exml_stream:parse(Parser, Decompressed)
end,
NewState = State#state{parser = NewParser},
escalus_connection:maybe_forward_to_owner(NewState#state.filter_pred,
NewState,
FwdState = State#state{parser = NewParser, sent_stanzas = []},
escalus_connection:maybe_forward_to_owner(FwdState#state.filter_pred,
FwdState,
Stanzas,
fun forward_to_owner/3, Timestamp),
case lists:filter(fun(Stanza) -> is_stream_end(Stanza, State) end, Stanzas) of
[] -> {noreply, NewState};
_ -> {stop, normal, NewState}
end.
fun forward_to_owner/3, Timestamp).

-spec is_stream_end(exml_stream:element(), state()) -> boolean().
is_stream_end(#xmlstreamend{}, #state{legacy_ws = true}) -> true;
is_stream_end(#xmlel{name = <<"close">>}, #state{legacy_ws = false}) -> true;
is_stream_end(_, _) -> false.

forward_to_owner(Stanzas, #state{owner = Owner,
event_client = EventClient}, Timestamp) ->
forward_to_owner(Stanzas0, #state{owner = Owner,
sm_state = SM0,
event_client = EventClient} = State, Timestamp) ->
{SM1, AckRequests, StanzasNoRs} = separate_ack_requests(SM0, Stanzas0),
reply_to_ack_requests(SM1, AckRequests, State),

lists:foreach(fun(Stanza) ->
escalus_event:incoming_stanza(EventClient, Stanza),
Owner ! escalus_connection:stanza_msg(Stanza,
#{recv_timestamp => Timestamp})
end, Stanzas).
escalus_event:incoming_stanza(EventClient, Stanza),
Owner ! escalus_connection:stanza_msg(Stanza, #{recv_timestamp => Timestamp})
end, StanzasNoRs),

case lists:keyfind(xmlstreamend, 1, StanzasNoRs) of
false -> ok;
_ -> gen_server:cast(self(), stop)
end,

{noreply, State#state{sm_state = SM1, sent_stanzas = StanzasNoRs}}.

separate_ack_requests({false, H0, A}, Stanzas) ->
%% Don't keep track of H
{{false, H0, A}, [], Stanzas};
separate_ack_requests({true, H0, inactive}, Stanzas) ->
Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)],
Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)],

case {length(Enabled), length(Resumed)} of
%% Enabled SM: set the H param to 0 and activate counter.
{1,0} -> {{true, 0, active}, [], Stanzas};

%% Resumed SM: keep the H param and activate counter.
{0,1} -> {{true, H0, active}, [], Stanzas};

%% No new SM state: continue as usual
{0,0} -> {{true, H0, inactive}, [], Stanzas}
end;
separate_ack_requests({true, H0, active}, Stanzas) ->
%% Count H and construct appropriate acks
F = fun(Stanza, {H, Acks, NonAckRequests}) ->
case escalus_pred:is_sm_ack_request(Stanza) of
true -> {H, [make_ack(H)|Acks], NonAckRequests};
false -> {H+1, Acks, [Stanza|NonAckRequests]}
end
end,
{H, Acks, Others} = lists:foldl(F, {H0, [], []}, Stanzas),
{{true, H, active}, lists:reverse(Acks), lists:reverse(Others)}.

make_ack(H) -> {escalus_stanza:sm_ack(H), H}.

reply_to_ack_requests({false, H, A}, _, _) ->
{false, H, A};
reply_to_ack_requests({true, H, inactive}, _, _) ->
{true, H, inactive};
reply_to_ack_requests({true, H0, active}, Acks, State) ->
{true,
lists:foldl(fun({Ack, H}, _) ->
Ack1 = exml:to_iolist(Ack),
gun:ws_send(State#state.socket, State#state.stream_ref, {text, Ack1}),
H
end, H0, Acks),
active}.

common_terminate(_Reason, #state{parser = Parser}) ->
exml_stream:free_parser(Parser).

-spec get_port(list(), inet:port_number()) -> inet:port_number().
get_port(Args, Default) ->
get_option(port, Args, Default).

-spec get_host(list(), string()) -> string().
get_host(Args, Default) ->
maybe_binary_to_list(get_option(host, Args, Default)).

-spec get_resource(list(), string()) -> string().
get_resource(Args, Default) ->
maybe_binary_to_list(get_option(wspath, Args, Default)).

-spec get_legacy_ws(list(), boolean()) -> boolean().
get_legacy_ws(Args, Default) ->
get_option(wslegacy, Args, Default).

-spec maybe_binary_to_list(binary() | string()) -> string().
maybe_binary_to_list(B) when is_binary(B) -> binary_to_list(B);
maybe_binary_to_list(S) when is_list(S) -> S.
Expand All @@ -336,5 +428,3 @@ close_compression_streams({zlib, {Zin, Zout}}) ->
ok = zlib:close(Zin),
ok = zlib:close(Zout)
end.


0 comments on commit ca24321

Please sign in to comment.