Skip to content

Commit

Permalink
Merge pull request matrix-org#1654 from matrix-org/rav/no_more_refres…
Browse files Browse the repository at this point in the history
…h_tokens

Stop generating refresh_tokens
  • Loading branch information
richvdh authored Dec 1, 2016
2 parents 8379a74 + 6841d8f commit 4712000
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 210 deletions.
5 changes: 2 additions & 3 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token required (e.g. "access", "refresh",
type_string(str): The kind of token required (e.g. "access",
"delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
user_id (str): The user_id required
Expand Down Expand Up @@ -820,8 +820,7 @@ def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
else:
v.satisfy_general(lambda c: c.startswith("time < "))

# access_tokens and refresh_tokens include a nonce for uniqueness: any
# value is acceptable
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))

v.verify(macaroon, self.hs.config.macaroon_secret_key)
Expand Down
30 changes: 4 additions & 26 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,10 @@ def validate_password_login(self, user_id, password):
return self._check_password(user_id, password)

@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
def get_access_token_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
"""
Gets login tuple for the user with the given user ID.
Creates a new access/refresh token for the user.
Creates a new access token for the user with the given user ID.
The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case.
Expand All @@ -400,16 +398,13 @@ def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns:
A tuple of:
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)

# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
Expand All @@ -420,7 +415,7 @@ def get_login_tuple_for_user_id(self, user_id, device_id=None,
user_id, device_id, initial_display_name
)

defer.returnValue((access_token, refresh_token))
defer.returnValue(access_token)

@defer.inlineCallbacks
def check_user_exists(self, user_id):
Expand Down Expand Up @@ -531,13 +526,6 @@ def issue_access_token(self, user_id, device_id=None):
device_id)
defer.returnValue(access_token)

@defer.inlineCallbacks
def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
device_id)
defer.returnValue(refresh_token)

def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
Expand All @@ -551,16 +539,6 @@ def generate_access_token(self, user_id, extra_caveats=None):
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()

def generate_refresh_token(self, user_id):
m = self._generate_base_macaroon(user_id)
m.add_first_party_caveat("type = refresh")
# Important to add a nonce, because otherwise every refresh token for a
# user will be the same.
m.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
return m.serialize()

def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
Expand Down
28 changes: 10 additions & 18 deletions synapse/rest/client/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,13 @@ def do_password_login(self, login_submission):
password=login_submission["password"],
)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
Expand All @@ -161,16 +158,13 @@ def do_token_login(self, login_submission):
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
Expand Down Expand Up @@ -207,16 +201,14 @@ def do_jwt_login(self, login_submission):
device_id = yield self._register_device(
registered_user_id, login_submission
)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id,
login_submission.get("initial_device_display_name")
)
access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id,
login_submission.get("initial_device_display_name"),
)

result = {
"user_id": registered_user_id,
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
Expand Down
10 changes: 3 additions & 7 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,7 @@ def _register_email_threepid(self, user_id, threepid, token, bind_email):
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token
and refresh_token.
Allocates device_id if one was not given; also creates access_token.
Args:
(str) user_id: full canonical @user:id
Expand All @@ -386,8 +385,8 @@ def _create_registration_details(self, user_id, params):
"""
device_id = yield self._register_device(user_id, params)

access_token, refresh_token = (
yield self.auth_handler.get_login_tuple_for_user_id(
access_token = (
yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
)
Expand All @@ -397,7 +396,6 @@ def _create_registration_details(self, user_id, params):
"user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
"device_id": device_id,
})

Expand Down Expand Up @@ -441,8 +439,6 @@ def _do_guest_registration(self, params):
access_token = self.auth_handler.generate_access_token(
user_id, ["guest = true"]
)
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
# so long as we don't return a refresh_token here.
defer.returnValue((200, {
"user_id": user_id,
"device_id": device_id,
Expand Down
26 changes: 3 additions & 23 deletions synapse/rest/client/v2_alpha/tokenrefresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from twisted.internet import defer

from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet

from ._base import client_v2_patterns

Expand All @@ -30,30 +30,10 @@ class TokenRefreshRestServlet(RestServlet):

def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
self.hs = hs
self.store = hs.get_datastore()

@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_auth_handler()
refresh_result = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token
)
(user_id, new_refresh_token, device_id) = refresh_result
new_access_token = yield auth_handler.issue_access_token(
user_id, device_id
)
defer.returnValue((200, {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}))
except KeyError:
raise SynapseError(400, "Missing required key 'refresh_token'.")
except StoreError:
raise AuthError(403, "Did not recognize refresh token")
raise AuthError(403, "tokenrefresh is no longer supported.")


def register_servlets(hs, http_server):
Expand Down
1 change: 0 additions & 1 deletion synapse/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def __init__(self, db_conn, hs):
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
Expand Down
66 changes: 0 additions & 66 deletions synapse/storage/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,31 +68,6 @@ def add_access_token_to_user(self, user_id, token, device_id=None):
desc="add_access_token_to_user",
)

@defer.inlineCallbacks
def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
device_id (str): ID of the device to associate with the access
token
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._refresh_tokens_id_gen.get_next()

yield self._simple_insert(
"refresh_tokens",
{
"id": next_id,
"user_id": user_id,
"token": token,
"device_id": device_id,
},
desc="add_refresh_token_to_user",
)

def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False):
Expand Down Expand Up @@ -353,47 +328,6 @@ def get_user_by_access_token(self, token):
token
)

def exchange_refresh_token(self, refresh_token, token_generator):
"""Exchange a refresh token for a new one.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
tuple of (user_id, new_refresh_token, device_id)
Raises:
StoreError if no user was found with that refresh token.
"""
return self.runInteraction(
"exchange_refresh_token",
self._exchange_refresh_token,
refresh_token,
token_generator
)

def _exchange_refresh_token(self, txn, old_token, token_generator):
sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
device_id = rows[0]["device_id"]

# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.

new_token = token_generator(user_id)
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))

return user_id, new_token, device_id

@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
Expand Down
12 changes: 4 additions & 8 deletions tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,18 @@ def test_POST_appservice_registration_valid(self):
self.registration_handler.appservice_register = Mock(
return_value=user_id
)
self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
self.auth_handler.get_access_token_for_user_id = Mock(
return_value=token
)

(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname
}
self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)

@defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self):
Expand Down Expand Up @@ -126,8 +124,8 @@ def test_POST_user_valid(self):
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
self.auth_handler.get_access_token_for_user_id = Mock(
return_value=token
)
self.device_handler.check_device_registered = \
Mock(return_value=device_id)
Expand All @@ -137,12 +135,10 @@ def test_POST_user_valid(self):
det_data = {
"user_id": user_id,
"access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None)

Expand Down
Loading

0 comments on commit 4712000

Please sign in to comment.