Skip to content

Commit

Permalink
Fix bug with on_connect for target hosts specified as a domain name
Browse files Browse the repository at this point in the history
Make 'on_connect' a property of Connection object when it's created, so
that it can be unambiguously associated with a specific connection
immediately after it's created.

Create a type definition for on_connect functions named OnConnectFunc.
  • Loading branch information
rostislav committed Nov 27, 2019
1 parent d4ddcb3 commit 373fd13
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 32 deletions.
8 changes: 6 additions & 2 deletions src/server/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import random
import time
from asyncio import StreamReader, StreamWriter
from typing import Any, List, Optional
from typing import Any, AsyncGenerator, Callable, List, Optional

from src.server.outbound_message import Message, NodeType
from src.server.outbound_message import Message, NodeType, OutboundMessage
from src.types.peer_info import PeerInfo
from src.util import cbor
from src.util.ints import uint16
Expand All @@ -13,6 +13,8 @@
LENGTH_BYTES: int = 4
log = logging.getLogger(__name__)

OnConnectFunc = Optional[Callable[[], AsyncGenerator[OutboundMessage, None]]]


class Connection:
"""
Expand All @@ -28,6 +30,7 @@ def __init__(
sr: StreamReader,
sw: StreamWriter,
server_port: int,
on_connect: OnConnectFunc,
):
self.local_type = local_type
self.connection_type = connection_type
Expand All @@ -41,6 +44,7 @@ def __init__(
self.peer_port = self.writer.get_extra_info("peername")[1]
self.peer_server_port: Optional[int] = None
self.node_id = None
self.on_connect = on_connect

# Connection metrics
self.creation_type = time.time()
Expand Down
52 changes: 22 additions & 30 deletions src/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import time
import os
from yaml import safe_load
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, List, Optional, Tuple
from aiter import aiter_forker, iter_to_aiter, join_aiters, map_aiter, push_aiter
from aiter.server import start_server_aiter
from definitions import ROOT_DIR
from src.protocols.shared_protocol import Handshake, HandshakeAck, protocol_version
from src.server.connection import Connection, PeerConnections
from src.server.connection import Connection, OnConnectFunc, PeerConnections
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.types.peer_info import PeerInfo
from src.util import partial_func
Expand Down Expand Up @@ -51,9 +51,8 @@ class ChiaServer:
# Aiter used to broadcase messages
_outbound_aiter: push_aiter

# These will get called after a handshake is performed
_on_connect_callbacks: Dict[PeerInfo, Callable] = {}
_on_connect_generic_callback: Optional[Callable] = None
# Called for inbound connections after successful handshake
_on_inbound_connect: OnConnectFunc = None

def __init__(self, port: int, api: Any, local_type: NodeType):
self._port = port # TCP port to identify our node
Expand All @@ -69,9 +68,7 @@ def __init__(self, port: int, api: Any, local_type: NodeType):
async def start_server(
self,
host: str,
on_connect: Optional[
Callable[[], AsyncGenerator[OutboundMessage, None]]
] = None,
on_connect: OnConnectFunc = None,
) -> bool:
"""
Launches a listening server on host and port specified, to connect to NodeType nodes. On each
Expand All @@ -86,12 +83,12 @@ async def start_server(
self._port, host=None, reuse_address=True
)
if on_connect is not None:
self._on_connect_generic_callback = on_connect
self._on_inbound_connect = on_connect

def add_connection_type(
srw: Tuple[asyncio.StreamReader, asyncio.StreamWriter]
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
return (srw[0], srw[1])
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter, None]:
return (srw[0], srw[1], None)

srwt_aiter = map_aiter(add_connection_type, aiter)

Expand All @@ -104,9 +101,7 @@ def add_connection_type(
async def start_client(
self,
target_node: PeerInfo,
on_connect: Optional[
Callable[[], AsyncGenerator[OutboundMessage, None]]
] = None,
on_connect: OnConnectFunc = None,
) -> bool:
"""
Tries to connect to the target node, adding one connection into the pipeline, if successful.
Expand Down Expand Up @@ -146,14 +141,13 @@ async def start_client(
)
self.global_connections.peers.remove(target_node)
return False
if on_connect is not None:
self._on_connect_callbacks[target_node] = on_connect
asyncio.create_task(self._add_to_srwt_aiter(iter_to_aiter([(reader, writer)])))
asyncio.create_task(self._add_to_srwt_aiter(iter_to_aiter([(reader, writer, on_connect)])))
return True

async def _add_to_srwt_aiter(
self,
aiter: AsyncGenerator[Tuple[asyncio.StreamReader, asyncio.StreamWriter], None],
aiter: AsyncGenerator[Tuple[asyncio.StreamReader, asyncio.StreamWriter,
OnConnectFunc], None],
):
"""
Adds all swrt from aiter into the instance variable srwt_aiter, adding them to the pipeline.
Expand Down Expand Up @@ -258,14 +252,16 @@ async def serve_forever():
return asyncio.get_running_loop().create_task(serve_forever())

async def stream_reader_writer_to_connection(
self, swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter], server_port: int
self,
swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter, OnConnectFunc],
server_port: int
) -> Connection:
"""
Maps a pair of (StreamReader, StreamWriter) to a Connection object,
Maps a tuple of (StreamReader, StreamWriter, on_connect) to a Connection object,
which also stores the type of connection (str). It is also added to the global list.
"""
sr, sw = swrt
con = Connection(self._local_type, None, sr, sw, server_port)
sr, sw, on_connect = swrt
con = Connection(self._local_type, None, sr, sw, server_port, on_connect)

log.info(f"Connection with {con.get_peername()} established")
return con
Expand All @@ -276,14 +272,10 @@ async def connection_to_outbound(
"""
Async generator which calls the on_connect async generator method, and yields any outbound messages.
"""
peer = PeerInfo(connection.peer_host, connection.peer_port)
if peer in self._on_connect_callbacks:
on_connect = self._on_connect_callbacks[peer]
async for outbound_message in on_connect():
yield connection, outbound_message
if self._on_connect_generic_callback:
async for outbound_message in self._on_connect_generic_callback():
yield connection, outbound_message
for func in connection.on_connect, self._on_inbound_connect:
if func:
async for outbound_message in func():
yield connection, outbound_message

async def perform_handshake(
self, connection: Connection
Expand Down

0 comments on commit 373fd13

Please sign in to comment.