Skip to content

Commit

Permalink
Port storage/ to Python 3 (matrix-org#3725)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkowl authored Aug 30, 2018
1 parent 475253a commit 14e4d4f
Show file tree
Hide file tree
Showing 17 changed files with 208 additions and 36 deletions.
1 change: 1 addition & 0 deletions changelog.d/3725.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The synapse.storage module has been ported to Python 3.
2 changes: 1 addition & 1 deletion jenkins/prepare_synapse.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ $TOX_BIN/pip install 'setuptools>=18.5'
$TOX_BIN/pip install 'pip>=10'

{ python synapse/python_dependencies.py
echo lxml psycopg2
echo lxml
} | xargs $TOX_BIN/pip install
3 changes: 3 additions & 0 deletions synapse/python_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@
"affinity": {
"affinity": ["affinity"],
},
"postgres": {
"psycopg2>=2.6": ["psycopg2"]
}
}


Expand Down
32 changes: 31 additions & 1 deletion synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import threading
import time

from six import iteritems, iterkeys, itervalues
from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range

from canonicaljson import json
from prometheus_client import Histogram

from twisted.internet import defer
Expand Down Expand Up @@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass


def db_to_json(db_content):
    """
    Take some data from a database row and return a JSON-decoded object.

    Args:
        db_content (memoryview|buffer|bytes|bytearray|unicode): raw database
            column content. Binary types are converted to bytes and decoded
            as UTF-8 before parsing.

    Returns:
        The deserialized JSON value (dict, list, unicode, number, bool or
        None).

    Raises:
        Exception: re-raises whatever json.loads raises on invalid JSON,
            after logging a warning.
    """
    # psycopg2 on Python 3 returns memoryview objects, which we need to
    # cast to bytes to decode
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()

    # psycopg2 on Python 2 returns buffer objects, which we need to cast to
    # bytes to decode. (`str is bytes` only holds on Python 2, so the
    # `buffer` builtin -- which no longer exists on Python 3 -- is never
    # evaluated there.)
    if str is bytes and isinstance(db_content, buffer):  # noqa: F821
        db_content = bytes(db_content)

    # Decode it to a Unicode string before feeding it to json.loads, so we
    # consistently get a Unicode-containing object out.
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode('utf8')

    try:
        return json.loads(db_content)
    except Exception:
        logging.warning("Tried to decode '%r' as JSON and failed", db_content)
        raise
2 changes: 1 addition & 1 deletion synapse/storage/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
devices = messages_by_device.keys()
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (
Expand Down
10 changes: 5 additions & 5 deletions synapse/storage/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList

from ._base import Cache, SQLBaseStore
from ._base import Cache, SQLBaseStore, db_to_json

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -411,7 +411,7 @@ def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
Expand Down Expand Up @@ -466,7 +466,7 @@ def _get_cached_user_device(self, user_id, device_id):
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(json.loads(content))
defer.returnValue(db_to_json(content))

@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
Expand All @@ -479,7 +479,7 @@ def _get_cached_devices_for_user(self, user_id):
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: json.loads(device["content"])
device["device_id"]: db_to_json(device["content"])
for device in devices
})

Expand Down Expand Up @@ -511,7 +511,7 @@ def _get_devices_with_keys_by_user_txn(self, txn, user_id):

key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
# limitations under the License.
from six import iteritems

from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json

from twisted.internet import defer

from synapse.util.caches.descriptors import cached

from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json


class EndToEndKeyStore(SQLBaseStore):
Expand Down Expand Up @@ -90,7 +90,7 @@ def get_e2e_device_keys(

for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json"))
device_info["keys"] = db_to_json(device_info.pop("key_json"))

defer.returnValue(results)

Expand Down
9 changes: 7 additions & 2 deletions synapse/storage/engines/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,18 @@ def on_new_connection(self, db_conn):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)

# Set the bytea output to escape, vs the default of hex
cursor = db_conn.cursor()
cursor.execute("SET bytea_output TO escape")

# Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit:
cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF")
cursor.close()

cursor.close()

def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
Expand Down
10 changes: 5 additions & 5 deletions synapse/storage/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections import OrderedDict, deque, namedtuple
from functools import wraps

from six import iteritems
from six import iteritems, text_type
from six.moves import range

from canonicaljson import json
Expand Down Expand Up @@ -1220,7 +1220,7 @@ def event_dict(event):
"sender": event.sender,
"contains_url": (
"url" in event.content
and isinstance(event.content["url"], basestring)
and isinstance(event.content["url"], text_type)
),
}
for event, _ in events_and_contexts
Expand Down Expand Up @@ -1529,7 +1529,7 @@ def reindex_txn(txn):

contains_url = "url" in content
if contains_url:
contains_url &= isinstance(content["url"], basestring)
contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
Expand Down Expand Up @@ -1910,9 +1910,9 @@ def _purge_history_txn(
(room_id,)
)
rows = txn.fetchall()
max_depth = max(row[0] for row in rows)
max_depth = max(row[1] for row in rows)

if max_depth <= token.topological:
if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremities)
Expand Down
9 changes: 5 additions & 4 deletions synapse/storage/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import logging
from collections import namedtuple
Expand Down Expand Up @@ -265,7 +266,7 @@ def _fetch_event_list(self, conn, event_list):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
event_id_lists = zip(*event_list)[0]
event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
Expand Down Expand Up @@ -299,14 +300,14 @@ def fire(lst, res):
logger.exception("do_fetch")

# We only want to resolve deferreds from the main thread
def fire(evs):
def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(e)
d.errback(exc)

with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list)
self.hs.get_reactor().callFromThread(fire, event_list, e)

@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
Expand Down
6 changes: 3 additions & 3 deletions synapse/storage/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json

from twisted.internet import defer

from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks

from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json


class FilteringStore(SQLBaseStore):
Expand All @@ -44,7 +44,7 @@ def get_user_filter(self, user_localpart, filter_id):
desc="get_user_filter",
)

defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
defer.returnValue(db_to_json(def_json))

def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
Expand Down
14 changes: 10 additions & 4 deletions synapse/storage/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# limitations under the License.

import logging
import types

import six

from canonicaljson import encode_canonical_json, json

Expand All @@ -27,25 +28,30 @@

logger = logging.getLogger(__name__)

if six.PY2:
db_binary_type = buffer
else:
db_binary_type = memoryview


class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
    """
    Decode the 'data' and 'pushkey' columns of pusher rows in place.

    Args:
        rows (list[dict]): pusher rows as returned from the database; each
            row's 'data' and 'pushkey' may arrive as the database binary
            type (buffer on Python 2, memoryview on Python 3).

    Returns:
        list[dict]: the same list, with 'data' JSON-decoded (or None if
        the stored JSON was invalid) and 'pushkey' coerced to a string.
    """
    for r in rows:
        dataJson = r['data']
        # Default to None so a decode failure still leaves a usable row.
        r['data'] = None
        try:
            if isinstance(dataJson, db_binary_type):
                # NOTE(review): on Python 3, str(memoryview) yields a repr
                # rather than the underlying bytes, and str has no
                # .decode() -- this branch looks Python-2-only; confirm
                # against what psycopg2/sqlite actually return here.
                dataJson = str(dataJson).decode("UTF8")

            r['data'] = json.loads(dataJson)
        except Exception as e:
            # Best-effort: log and carry on with data=None rather than
            # failing the whole pusher load for one bad row.
            logger.warn(
                "Invalid JSON in data for pusher %d: %s, %s",
                r['id'], dataJson, e.args[0],
            )
            pass

        if isinstance(r['pushkey'], db_binary_type):
            r['pushkey'] = str(r['pushkey']).decode("UTF8")

    return rows
Expand Down
7 changes: 4 additions & 3 deletions synapse/storage/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

import six

from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json

from twisted.internet import defer

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached

from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json

# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
Expand Down Expand Up @@ -95,7 +95,8 @@ def _get_received_txn_response(self, txn, transaction_id, origin):
)

if result and result["response_code"]:
return result["response_code"], json.loads(str(result["response_json"]))
return result["response_code"], db_to_json(result["response_json"])

else:
return None

Expand Down
14 changes: 10 additions & 4 deletions tests/rest/client/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,17 +240,23 @@ def register(self, user_id):
self.assertEquals(200, code)
defer.returnValue(response)

def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
    """
    PUT an m.room.message event into a room and assert the response code.

    Args:
        room_id (str): the room to send the message to.
        body (str|None): message body; defaults to "body_text_here".
        txn_id (str|None): transaction id; defaults to a time-based id.
        tok (str|None): access token, appended as a query parameter.
        expect_code (int): expected HTTP response code (default 200).

    Returns:
        The JSON body of the response channel.
    """
    if txn_id is None:
        # Time-based id keeps repeated sends in one test from colliding.
        txn_id = "m%s" % (str(time.time()))
    if body is None:
        body = "body_text_here"

    path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
    content = {"msgtype": "m.text", "body": body}
    if tok:
        path = path + "?access_token=%s" % tok

    request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
    render(request, self.resource, self.hs.get_reactor())

    assert int(channel.result["code"]) == expect_code, (
        "Expected: %d, got: %d, resp: %r"
        % (expect_code, int(channel.result["code"]), channel.result["body"])
    )

    return channel.json_body
Loading

0 comments on commit 14e4d4f

Please sign in to comment.