Skip to content

Commit

Permalink
Refactor tests.util.setup_test_homeserver and `tests.server.setup_t…
Browse files Browse the repository at this point in the history
…est_homeserver`. (matrix-org#11503)
  • Loading branch information
reivilibre authored Dec 21, 2021
1 parent b610223 commit e6897e7
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 185 deletions.
1 change: 1 addition & 0 deletions changelog.d/11503.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`.
199 changes: 190 additions & 9 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import logging
import time
import uuid
import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
Expand All @@ -27,6 +30,7 @@
Type,
Union,
)
from unittest.mock import Mock

import attr
from typing_extensions import Deque
Expand All @@ -53,11 +57,24 @@
from twisted.web.resource import IResource
from twisted.web.server import Request, Site

from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.types import JsonDict
from synapse.util import Clock

from tests.utils import setup_test_homeserver as _sth
from tests.utils import (
LEAVE_DB,
POSTGRES_BASE_DB,
POSTGRES_HOST,
POSTGRES_PASSWORD,
POSTGRES_USER,
USE_POSTGRES_FOR_TESTS,
MockClock,
default_config,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -450,14 +467,11 @@ def _(res):
return d


def setup_test_homeserver(cleanup_func, *args, **kwargs):
def _make_test_homeserver_synchronous(server: HomeServer) -> None:
"""
Set up a synchronous test server, driven by the reactor used by
the homeserver.
Make the given test homeserver's database interactions synchronous.
"""
server = _sth(cleanup_func, *args, **kwargs)

# Make the thread pool synchronous.
clock = server.get_clock()

for database in server.get_datastores().databases:
Expand Down Expand Up @@ -485,15 +499,14 @@ def runInteraction(interaction, *args, **kwargs):

pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
# Replace the thread pool with a threadless 'thread' pool
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True

# We've just changed the Databases to run DB transactions on the same
# thread, so we need to disable the dedicated thread behaviour.
server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False

return server


def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
Expand Down Expand Up @@ -673,3 +686,171 @@ def connect_client(
client.makeConnection(FakeTransport(server, reactor))

return client, server


class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore


def setup_test_homeserver(
cleanup_func,
name="test",
config=None,
reactor=None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs,
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
If no datastore is supplied, one is created and given to the homeserver.
Args:
cleanup_func : The function used to register a cleanup routine for
after the test.
Calling this method directly is deprecated: you should instead derive from
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor

if config is None:
config = default_config(name, parse=True)

config.ldap_enabled = False

if "clock" not in kwargs:
kwargs["clock"] = MockClock()

if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex

database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
"host": POSTGRES_HOST,
"password": POSTGRES_PASSWORD,
"user": POSTGRES_USER,
"cp_min": 1,
"cp_max": 5,
},
}
else:
database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}

if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]

database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]

db_engine = create_engine(database.config)

# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
cur.execute(
"CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
)
cur.close()
db_conn.close()

hs = homeserver_to_use(
name,
config=config,
version_string="Synapse/tests",
reactor=reactor,
)

# Install @cache_in_self attributes
for key, val in kwargs.items():
setattr(hs, "_" + key, val)

# Mock TLS
hs.tls_server_context_factory = Mock()
hs.tls_client_options_factory = Mock()

hs.setup()
if homeserver_to_use == TestHomeServer:
hs.setup_background_tasks()

if isinstance(db_engine, PostgresEngine):
database = hs.get_datastores().databases[0]

# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2

# Close all the db pools
database._db_pool.close()

dropped = False

# Drop the test database
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()

# Try a few times to drop the DB. Some things may hold on to the
# database for a few more seconds due to flakiness, preventing
# us from dropping it when the test is over. If we can't drop
# it, warn and move on.
for _ in range(5):
try:
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
db_conn.commit()
dropped = True
except psycopg2.OperationalError as e:
warnings.warn(
"Couldn't drop old db: " + str(e), category=UserWarning
)
time.sleep(0.5)

cur.close()
db_conn.close()

if not dropped:
warnings.warn("Failed to drop old DB.", category=UserWarning)

if not LEAVE_DB:
# Register the cleanup hook
cleanup_func(cleanup)

# bcrypt is far too slow to be doing in unit tests
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
async def hash(p):
return hashlib.md5(p.encode("utf8")).hexdigest()

hs.get_auth_handler().hash = hash

async def validate_hash(p, h):
return hashlib.md5(p.encode("utf8")).hexdigest() == h

hs.get_auth_handler().validate_hash = validate_hash

# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)

return hs
3 changes: 2 additions & 1 deletion tests/storage/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from synapse.storage.engines import create_engine

from tests import unittest
from tests.utils import TestHomeServer, default_config
from tests.server import TestHomeServer
from tests.utils import default_config


class SQLBaseStoreTestCase(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/storage/test_roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from synapse.types import UserID, create_requester

from tests import unittest
from tests.server import TestHomeServer
from tests.test_utils import event_injection
from tests.utils import TestHomeServer


class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
Expand Down
Loading

0 comments on commit e6897e7

Please sign in to comment.