Skip to content

Commit

Permalink
Merge pull request matrix-org#3638 from matrix-org/rav/refactor_feder…
Browse files Browse the repository at this point in the history
…ation_client_exception_handling

Factor out exception handling in federation_client
  • Loading branch information
richvdh authored Aug 2, 2018
2 parents 704c3e6 + 38b98e5 commit bdae8f2
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 129 deletions.
1 change: 1 addition & 0 deletions changelog.d/3638.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Factor out exception handling in federation_client
277 changes: 148 additions & 129 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@
PDU_RETRY_TIME_MS = 1 * 60 * 1000


class InvalidResponseError(RuntimeError):
    """Signals that a remote server's response could not be parsed.

    Helper for ``_try_destination_list``: a callback raises this to indicate
    that the server returned a response we couldn't parse.
    """


class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
Expand Down Expand Up @@ -458,6 +465,61 @@ def get_event_auth(self, destination, room_id, event_id):
defer.returnValue(signed_auth)

@defer.inlineCallbacks
def _try_destination_list(self, description, destinations, callback):
    """Try an operation on a series of servers, until it succeeds

    Args:
        description (unicode): description of the operation we're doing, for
            logging
        destinations (Iterable[unicode]): list of server_names to try. Any
            entry equal to our own server_name is skipped.
        callback (callable): Function to run for each server. Passed a single
            argument: the server_name to try. May return a deferred.

            If the callback raises a CodeMessageException whose code is
            outside the 5xx range (e.g. a 300/400 code), attempts to perform
            the operation stop immediately and the exception is reraised.

            Otherwise, if the callback raises an Exception the error is logged
            and the next server tried. Normally the stacktrace is logged but
            this is suppressed if the exception is an InvalidResponseError or
            a CodeMessageException with a 5xx code.

    Returns:
        The [Deferred] result of callback, if it succeeds

    Raises:
        CodeMessageException if the chosen remote server returns a non-5xx
            error code.
        RuntimeError if every destination was skipped or failed.
    """
    for destination in destinations:
        # Skip our own server name; only remote servers are tried.
        if destination == self.server_name:
            continue

        try:
            res = yield callback(destination)
            defer.returnValue(res)
        except InvalidResponseError as e:
            # Unparseable response: the message is enough, no stacktrace.
            logger.warning(
                "Failed to %s via %s: %s",
                description, destination, e,
            )
        except CodeMessageException as e:
            # Only 5xx responses are treated as "try the next server";
            # anything else is propagated to the caller unchanged.
            if not 500 <= e.code < 600:
                raise
            logger.warning(
                "Failed to %s via %s: %i %s",
                description, destination, e.code, e.message,
            )
        except Exception:
            # Unexpected failure: log with the stacktrace and move on.
            logger.warning(
                "Failed to %s via %s",
                description, destination, exc_info=True,
            )

    # Exception constructors do not apply %-formatting to their arguments
    # (unlike logger calls), so the message must be interpolated here.
    raise RuntimeError("Failed to %s via any server" % (description,))

def make_membership_event(self, destinations, room_id, user_id, membership,
content={},):
"""
Expand Down Expand Up @@ -492,50 +554,35 @@ def make_membership_event(self, destinations, room_id, user_id, membership,
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
for destination in destinations:
if destination == self.server_name:
continue

try:
ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership
)
@defer.inlineCallbacks
def send_request(destination):
ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership
)

pdu_dict = ret["event"]
pdu_dict = ret["event"]

logger.debug("Got response to make_%s: %s", membership, pdu_dict)
logger.debug("Got response to make_%s: %s", membership, pdu_dict)

pdu_dict["content"].update(content)
pdu_dict["content"].update(content)

# The protoevent received over the JSON wire may not have all
# the required fields. Lets just gloss over that because
# there's some we never care about
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
# The protoevent received over the JSON wire may not have all
# the required fields. Lets just gloss over that because
# there's some we never care about
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []

ev = builder.EventBuilder(pdu_dict)
ev = builder.EventBuilder(pdu_dict)

defer.returnValue(
(destination, ev)
)
break
except CodeMessageException as e:
if not 500 <= e.code < 600:
raise
else:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
defer.returnValue(
(destination, ev)
)

raise RuntimeError("Failed to send to any server.")
return self._try_destination_list(
"make_" + membership, destinations, send_request,
)

@defer.inlineCallbacks
def send_join(self, destinations, pdu):
"""Sends a join event to one of a list of homeservers.
Expand All @@ -558,87 +605,70 @@ def send_join(self, destinations, pdu):
Fails with a ``RuntimeError`` if no servers were reachable.
"""

for destination in destinations:
if destination == self.server_name:
continue

try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

logger.debug("Got content: %s", content)
logger.debug("Got content: %s", content)

state = [
event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
state = [
event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]

auth_chain = [
event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
auth_chain = [
event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]

pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}
pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}

valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()),
outlier=True,
)
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, list(pdus.values()),
outlier=True,
)

valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}

# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}

signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]

# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]

auth_chain.sort(key=lambda e: e.depth)
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)

defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
except CodeMessageException as e:
if not 500 <= e.code < 600:
raise
else:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
auth_chain.sort(key=lambda e: e.depth)

raise RuntimeError("Failed to send to any server.")
defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
return self._try_destination_list("send_join", destinations, send_request)

@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
Expand All @@ -663,7 +693,6 @@ def send_invite(self, destination, room_id, event_id, pdu):

defer.returnValue(pdu)

@defer.inlineCallbacks
def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers.
Expand All @@ -681,34 +710,24 @@ def send_leave(self, destinations, pdu):
Deferred: resolves to None.
Fails with a ``CodeMessageException`` if the chosen remote server
returns a non-200 code.
returns a 300/400 code.
Fails with a ``RuntimeError`` if no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
continue

try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

logger.debug("Got content: %s", content)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
logger.debug("Got content: %s", content)
defer.returnValue(None)

raise RuntimeError("Failed to send to any server.")
return self._try_destination_list("send_leave", destinations, send_request)

def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None, include_all_networks=False,
Expand Down

0 comments on commit bdae8f2

Please sign in to comment.