Skip to content

Commit

Permalink
Store key validity time in the storage layer
Browse files Browse the repository at this point in the history
This is a first step to checking that the key is valid at the required moment.

The idea here is that, rather than passing VerifyKey objects in and out of the
storage layer, we instead pass FetchKeyResult objects, which simply wrap the
VerifyKey and add a valid_until_ts field.
  • Loading branch information
richvdh committed May 23, 2019
1 parent 84660d9 commit b75537b
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 46 deletions.
1 change: 1 addition & 0 deletions changelog.d/5237.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Store key validity time in the storage layer.
47 changes: 33 additions & 14 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from six import raise_from
from six.moves import urllib

import nacl.signing
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
Expand All @@ -43,6 +42,7 @@
RequestSendFailed,
SynapseError,
)
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
LoggingContext,
Expand Down Expand Up @@ -307,11 +307,15 @@ def do_iterations():
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
key = result_keys.get(key_id)
if key:
fetch_key_result = result_keys.get(key_id)
if fetch_key_result:
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, key)
(
server_name,
key_id,
fetch_key_result.verify_key,
)
)
break
else:
Expand Down Expand Up @@ -348,12 +352,12 @@ def on_err(err):
def get_keys_from_store(self, server_name_and_key_ids):
"""
Args:
server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Returns:
Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
server_name -> key_id -> VerifyKey
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
map from server_name -> key_id -> FetchKeyResult
"""
keys_to_fetch = (
(server_name, key_id)
Expand Down Expand Up @@ -430,6 +434,18 @@ def get_keys_from_server(self, server_name_and_key_ids):
def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys
):
"""
Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
from server_name -> key_id -> FetchKeyResult
"""
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
Expand Down Expand Up @@ -506,7 +522,7 @@ def get_server_verify_key_v2_indirect(

@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, nacl.signing.VerifyKey]
keys = {} # type: dict[str, FetchKeyResult]

for requested_key_id in key_ids:
if requested_key_id in keys:
Expand Down Expand Up @@ -583,9 +599,9 @@ def process_v2_response(
actually in the response
Returns:
Deferred[dict[str, nacl.signing.VerifyKey]]:
map from key_id to key object
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
ts_valid_until_ms = response_json[u"valid_until_ts"]

# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
Expand All @@ -595,7 +611,9 @@ def process_v2_response(
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = verify_key
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)

# TODO: improve this signature checking
server_name = response_json["server_name"]
Expand All @@ -606,15 +624,17 @@ def process_v2_response(
)

verify_signed_json(
response_json, server_name, verify_keys[key_id]
response_json, server_name, verify_keys[key_id].verify_key
)

for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = verify_key
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)

# re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server
Expand All @@ -623,7 +643,6 @@ def process_v2_response(
)

signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]

# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
Expand Down
31 changes: 21 additions & 10 deletions synapse/storage/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import six

import attr
from signedjson.key import decode_verify_key_bytes

from synapse.util import batch_iter
Expand All @@ -36,6 +37,12 @@
db_binary_type = memoryview


@attr.s(slots=True, frozen=True)
class FetchKeyResult(object):
verify_key = attr.ib() # VerifyKey: the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for


class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
Expand All @@ -54,8 +61,8 @@ def get_server_verify_keys(self, server_name_and_key_ids):
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
map from (server_name, key_id) -> VerifyKey, or None if the key is
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
"""
keys = {}
Expand All @@ -65,17 +72,19 @@ def _get_keys(txn, batch):

# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
"SELECT server_name, key_id, verify_key FROM server_signature_keys "
"WHERE 1=0"
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
"FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)

txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))

for row in txn:
server_name, key_id, key_bytes = row
keys[(server_name, key_id)] = decode_verify_key_bytes(
key_id, bytes(key_bytes)
server_name, key_id, key_bytes, ts_valid_until_ms = row
res = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
)
keys[(server_name, key_id)] = res

def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50):
Expand All @@ -89,20 +98,21 @@ def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
Args:
from_server (str): Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added
verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]):
verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
key_values = []
value_values = []
invalidations = []
for server_name, key_id, verify_key in verify_keys:
for server_name, key_id, fetch_result in verify_keys:
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
db_binary_type(verify_key.encode()),
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
)
# invalidate takes a tuple corresponding to the params of
Expand All @@ -125,6 +135,7 @@ def _invalidate(res):
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=value_values,
Expand Down
23 changes: 23 additions & 0 deletions synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/* When we can use this key until, before we have to refresh it. */
ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;

UPDATE server_signature_keys SET ts_valid_until_ms = (
SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
skj.server_name = server_signature_keys.server_name AND
skj.key_id = server_signature_keys.key_id
);
22 changes: 14 additions & 8 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
from synapse.crypto.keyring import KeyLookupError
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext

Expand Down Expand Up @@ -201,7 +202,7 @@ def test_verify_json_for_server(self):
(
"server9",
key1_id,
signedjson.key.get_verify_key(key1),
FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
),
],
)
Expand Down Expand Up @@ -251,9 +252,10 @@ def get_json(destination, path, **kwargs):
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
self.assertEqual(k.verify_key.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
Expand Down Expand Up @@ -321,9 +323,10 @@ def post_json(destination, path, data, **kwargs):
keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
self.assertEqual(k.verify_key.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
Expand All @@ -346,7 +349,10 @@ def post_json(destination, path, data, **kwargs):

@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx"):
with LoggingContext("testctx") as ctx:
# we set the "request" prop to make it easier to follow what's going on in the
# logs.
ctx.request = "testctx"
rv = yield f(*args, **kwargs)
defer.returnValue(rv)

Expand Down
44 changes: 30 additions & 14 deletions tests/storage/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from twisted.internet.defer import Deferred

from synapse.storage.keys import FetchKeyResult

import tests.unittest

KEY_1 = signedjson.key.decode_verify_key_base64(
Expand All @@ -37,8 +39,8 @@ def test_get_server_verify_keys(self):
"from_server",
10,
[
("server1", key_id_1, KEY_1),
("server1", key_id_2, KEY_2),
("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
)
self.get_success(d)
Expand All @@ -50,13 +52,15 @@ def test_get_server_verify_keys(self):

self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)]
self.assertEqual(res1, KEY_1)
self.assertEqual(res1.version, "key1")
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.verify_key.version, "key1")
self.assertEqual(res1.valid_until_ts, 100)

res2 = res[("server1", key_id_2)]
self.assertEqual(res2, KEY_2)
self.assertEqual(res2.verify_key, KEY_2)
# version comes from the ID it was stored with
self.assertEqual(res2.version, "KEY_ID_2")
self.assertEqual(res2.verify_key.version, "KEY_ID_2")
self.assertEqual(res2.valid_until_ts, 200)

# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
Expand All @@ -73,35 +77,47 @@ def test_cache(self):
"from_server",
0,
[
("srv1", key_id_1, KEY_1),
("srv1", key_id_2, KEY_2),
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
)
self.get_success(d)

d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_2)], KEY_2)

res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)

res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, KEY_2)
self.assertEqual(res2.valid_until_ts, 200)

# we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)])
if isinstance(res, Deferred):
res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)

new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
d = store.store_server_verify_keys(
"from_server", 10, [("srv1", key_id_2, new_key_2)]
"from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
)
self.get_success(d)

d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_2)], new_key_2)

res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)

res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, new_key_2)
self.assertEqual(res2.valid_until_ts, 300)

0 comments on commit b75537b

Please sign in to comment.