Skip to content

Commit

Permalink
Change the way we specify if we require auth or not
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Aug 5, 2016
1 parent 32fc39f commit 597c79b
Showing 1 changed file with 55 additions and 40 deletions.
95 changes: 55 additions & 40 deletions synapse/federation/transport/server.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request, parse_string
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter

import functools
@@ -60,32 +60,35 @@ def register_servlets(self):
)


class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
pass


class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
pass


class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname

# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request):
def authenticate_request(self, request, content):
json_request = {
"method": request.method,
"uri": request.uri,
"destination": self.server_name,
"signatures": {},
}

content = None
origin = None
if content is not None:
json_request["content"] = content

if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
json_request["content"] = content
except:
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
origin = None

def parse_auth_header(header_str):
try:
@@ -103,14 +106,14 @@ def strip_quotes(value):
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except:
raise SynapseError(
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)

auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

if not auth_headers:
raise SynapseError(
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)

@@ -121,7 +124,7 @@ def strip_quotes(value):
json_request["signatures"].setdefault(origin, {})[key] = sig

if not json_request["signatures"]:
raise SynapseError(
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)

@@ -130,40 +133,59 @@ def strip_quotes(value):
logger.info("Request from %s", origin)
request.authenticated_entity = origin

defer.returnValue((origin, content))
defer.returnValue(origin)


class BaseFederationServlet(object):
REQUIRE_AUTH = True

def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler

def _wrap(self, code):
def _wrap(self, func):
authenticator = self.authenticator
ratelimiter = self.ratelimiter

@defer.inlineCallbacks
@functools.wraps(code)
def new_code(request, *args, **kwargs):
@functools.wraps(func)
def new_func(request, *args, **kwargs):
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
content = parse_json_object_from_request(request)

try:
(origin, content) = yield authenticator.authenticate_request(request)
origin = yield authenticator.authenticate_request(request, content)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
logger.exception("authenticate_request failed")
raise
except:
logger.exception("authenticate_request failed")
raise

if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield code(
response = yield func(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("authenticate_request failed")
raise
else:
response = yield func(
origin, content, request.args, *args, **kwargs
)

defer.returnValue(response)

# Extra logic that functools.wraps() doesn't finish
new_code.__self__ = code.__self__
new_func.__self__ = func.__self__

return new_code
return new_func

def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
@@ -429,9 +451,10 @@ def on_POST(self, origin, content, query, room_id):
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"

REQUIRE_AUTH = False

@defer.inlineCallbacks
def on_POST(self, request):
content = parse_json_object_from_request(request)
def on_POST(self, origin, content, query):
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -453,11 +476,6 @@ def on_POST(self, request):
raise last_exception
defer.returnValue((200, {}))

# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code


class OpenIdUserInfo(BaseFederationServlet):
"""
@@ -478,9 +496,11 @@ class OpenIdUserInfo(BaseFederationServlet):

PATH = "/openid/userinfo"

REQUIRE_AUTH = False

@defer.inlineCallbacks
def on_GET(self, request):
token = parse_string(request, "access_token")
def on_GET(self, origin, content, query):
token = query.get("access_token", [None])[0]
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -497,11 +517,6 @@ def on_GET(self, request):

defer.returnValue((200, {"sub": user_id}))

# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code


class PublicRoomList(BaseFederationServlet):
"""

0 comments on commit 597c79b

Please sign in to comment.