Skip to content

Commit

Permalink
Fix additional type hints from Twisted 21.2.0. (matrix-org#9591)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Mar 12, 2021
1 parent 1e67bff commit 55da8df
Show file tree
Hide file tree
Showing 18 changed files with 187 additions and 119 deletions.
1 change: 1 addition & 0 deletions changelog.d/9591.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix incorrect type hints.
2 changes: 1 addition & 1 deletion synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_public_keys(self, invite_event):

async def get_user_by_req(
self,
request: Request,
request: SynapseRequest,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
Expand Down
8 changes: 5 additions & 3 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,9 @@ def __init__(self, hs: "HomeServer"):
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
self.query_handlers = (
{}
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]

# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
Expand Down Expand Up @@ -914,7 +916,7 @@ def register_edu_handler(
self.edu_handlers[edu_type] = handler

def register_query_handler(
self, query_type: str, handler: Callable[[dict], defer.Deferred]
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
Expand Down Expand Up @@ -987,7 +989,7 @@ async def on_edu(self, edu_type: str, origin: str, content: dict):
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)

async def on_query(self, query_type: str, args: dict):
async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
Expand Down
9 changes: 5 additions & 4 deletions synapse/handlers/oidc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from typing_extensions import TypedDict

from twisted.web.client import readBody
from twisted.web.http_headers import Headers

from synapse.config import ConfigError
from synapse.config.oidc_config import (
Expand Down Expand Up @@ -538,7 +539,7 @@ async def _exchange_code(self, code: str) -> Token:
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
headers = {
raw_headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
Expand All @@ -552,10 +553,10 @@ async def _exchange_code(self, code: str) -> Token:
body = urlencode(args, True)

# Fill the body/headers with credentials
uri, headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=headers, body=body
uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=raw_headers, body=body
)
headers = {k: [v] for (k, v) in headers.items()}
headers = Headers({k: [v] for (k, v) in raw_headers.items()})

# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
Expand Down
9 changes: 8 additions & 1 deletion synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
from twisted.web.iweb import (
UNKNOWN_LENGTH,
IAgent,
IBodyProducer,
IPolicyForHTTPS,
IResponse,
)

from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
Expand Down Expand Up @@ -870,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8")


@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
Expand Down
23 changes: 15 additions & 8 deletions synapse/logging/_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.tcp import Connection
from twisted.python.failure import Failure

logger = logging.getLogger(__name__)
Expand All @@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string.
"""

transport = attr.ib(type=ITransport)
# This is essentially ITCPTransport, but that is missing certain fields
# (connected and registerProducer) which are part of the implementation.
transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
Expand Down Expand Up @@ -149,8 +152,6 @@ def _connect(self) -> None:
if self._connection_waiter:
return

self._connection_waiter = self._service.whenConnected(failAfterFailures=1)

def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
Expand All @@ -163,9 +164,13 @@ def fail(failure: Failure) -> None:
self._connect()

def writer(result: Protocol) -> None:
# Force recognising transport as a Connection and not the more
# generic ITransport.
transport = result.transport # type: Connection # type: ignore

# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
if self._producer and result.transport is self._producer.transport:
if self._producer and transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
Expand All @@ -177,14 +182,16 @@ def writer(result: Protocol) -> None:
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
transport=result.transport,
transport=transport,
format=self.format,
)
result.transport.registerProducer(self._producer, True)
transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None

self._connection_waiter.addCallbacks(writer, fail)
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
deferred.addCallbacks(writer, fail)
self._connection_waiter = deferred

def _handle_pressure(self) -> None:
"""
Expand Down
4 changes: 2 additions & 2 deletions synapse/push/emailpusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional

from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.internet.interfaces import IDelayedCall

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, ThrottleParams
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer

self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey
self.timed_call = None # type: Optional[DelayedCall]
self.timed_call = None # type: Optional[IDelayedCall]
self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False

Expand Down
44 changes: 24 additions & 20 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
Expand Down Expand Up @@ -82,7 +82,7 @@

# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]


Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(self, hs: "HomeServer"):

# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
self._connections = [] # type: List[IReplicationConnection]

LaterGauge(
"synapse_replication_tcp_resource_total_connections",
Expand All @@ -197,7 +197,7 @@ def __init__(self, hs: "HomeServer"):

# For each connection, the incoming stream names that have received a POSITION
# from that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]

LaterGauge(
"synapse_replication_tcp_command_queue",
Expand All @@ -220,7 +220,7 @@ def __init__(self, hs: "HomeServer"):
self._server_notices_sender = hs.get_server_notices_sender()

def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
Expand Down Expand Up @@ -267,7 +267,7 @@ async def _unsafe_process_queue(self, stream_name: str):
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
conn: AbstractConnection,
conn: IReplicationConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
Expand Down Expand Up @@ -321,10 +321,10 @@ def get_streams_to_replicate(self) -> List[Stream]:
"""Get a list of streams that this instances replicates."""
return self._streams_to_replicate

def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)

def send_positions_to_connection(self, conn: AbstractConnection):
def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
Expand All @@ -347,7 +347,7 @@ def send_positions_to_connection(self, conn: AbstractConnection):
)

def on_USER_SYNC(
self, conn: AbstractConnection, cmd: UserSyncCommand
self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()

Expand All @@ -359,21 +359,23 @@ def on_USER_SYNC(
return None

def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]:
if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None

def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
def on_FEDERATION_ACK(
self, conn: IReplicationConnection, cmd: FederationAckCommand
):
federation_ack_counter.inc()

if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)

def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand
self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()

Expand All @@ -395,7 +397,7 @@ async def _handle_user_ip(self, cmd: UserIpCommand):
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)

def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
Expand All @@ -412,7 +414,7 @@ def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
self._add_command_to_stream_queue(conn, cmd)

async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
Expand Down Expand Up @@ -486,7 +488,7 @@ async def on_rdata(
stream_name, instance_name, token, rows
)

def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
Expand All @@ -496,7 +498,7 @@ def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
self._add_command_to_stream_queue(conn, cmd)

async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
Expand Down Expand Up @@ -553,7 +555,9 @@ async def _process_position(

self._streams_by_connection.setdefault(conn, set()).add(stream_name)

def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
def on_REMOTE_SERVER_UP(
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)

Expand All @@ -576,7 +580,7 @@ def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpComma
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)

def new_connection(self, connection: AbstractConnection):
def new_connection(self, connection: IReplicationConnection):
"""Called when we have a new connection."""
self._connections.append(connection)

Expand All @@ -603,7 +607,7 @@ def new_connection(self, connection: AbstractConnection):
UserSyncCommand(self._instance_id, user_id, True, now)
)

def lost_connection(self, connection: AbstractConnection):
def lost_connection(self, connection: IReplicationConnection):
"""Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
Expand All @@ -624,7 +628,7 @@ def connected(self) -> bool:
return bool(self._connections)

def send_command(
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
"""Send a command to all connected connections.
Expand Down
Loading

0 comments on commit 55da8df

Please sign in to comment.