diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 648b9b9b..6ae6004b 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -176,9 +176,15 @@ def _run_until_complete_cb(fut): class Server(events.AbstractServer): - def __init__(self, loop, sockets): + def __init__(self, loop, sockets, protocol_factory, ssl, backlog, *, + max_connections=None): self._loop = loop self.sockets = sockets + self._protocol_factory = protocol_factory + self._ssl = ssl + self._backlog = backlog + self._max_connections = max_connections + self._paused = False self._active_count = 0 self._waiters = [] @@ -188,14 +194,37 @@ def __repr__(self): def _attach(self): assert self.sockets is not None self._active_count += 1 + if self._max_connections is not None and \ + not self._paused and \ + self._active_count >= self._max_connections: + self.pause() def _detach(self): assert self._active_count > 0 self._active_count -= 1 if self._active_count == 0 and self.sockets is None: self._wakeup() + elif self._paused and self._max_connections is not None and \ + self._active_count < self._max_connections: + self.resume() + + def pause(self): + """Pause future calls to accept().""" + assert not self._paused + self._paused = True + for sock in self.sockets: + self._loop.remove_reader(sock.fileno()) + + def resume(self): + """Resume use of accept() on listening socket(s).""" + assert self._paused + self._paused = False + for sock in self.sockets: + self._loop._start_serving(self._protocol_factory, sock, self._ssl, + self, self._backlog) def close(self): + self._protocol_factory = None sockets = self.sockets if sockets is None: return @@ -943,7 +972,8 @@ def create_server(self, protocol_factory, host=None, port=None, backlog=100, ssl=None, reuse_address=None, - reuse_port=None): + reuse_port=None, + max_connections=None): """Create a TCP server. The host parameter can be a string, in that case the TCP server is bound @@ -1026,7 +1056,8 @@ def create_server(self, protocol_factory, host=None, port=None, raise ValueError('Neither host/port nor sock were specified') sockets = [sock] - server = Server(self, sockets) + server = Server(self, sockets, protocol_factory, ssl, backlog, + max_connections=max_connections) for sock in sockets: sock.listen(backlog) sock.setblocking(False) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 65b61db6..22b8e6fb 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -247,7 +247,8 @@ def create_unix_connection(self, protocol_factory, path, *, @coroutine def create_unix_server(self, protocol_factory, path=None, *, - sock=None, backlog=100, ssl=None): + sock=None, backlog=100, ssl=None, + max_connections=None): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') @@ -294,7 +295,8 @@ def create_unix_server(self, protocol_factory, path=None, *, 'A UNIX Domain Stream Socket was expected, got {!r}' .format(sock)) - server = base_events.Server(self, [sock]) + server = base_events.Server(self, [sock], protocol_factory, ssl, + backlog, max_connections=max_connections) sock.listen(backlog) sock.setblocking(False) self._start_serving(protocol_factory, sock, ssl, server) diff --git a/tests/test_events.py b/tests/test_events.py index 7df926f1..4b667f15 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1312,6 +1312,109 @@ def connection_made(self, transport): server.close() + def test_create_server_max_connections(self): + protos = [] + on_data = asyncio.Event(loop=self.loop) + + class MaxConnTestProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + protos.append(self) + def data_received(self, data): + super().data_received(data) + on_data.set() + + f = self.loop.create_server(lambda: MaxConnTestProto(loop=self.loop), + '0.0.0.0', 0, max_connections=2) + server = self.loop.run_until_complete(f) + port = server.sockets[0].getsockname()[1] + self._test_create_server_max_connections(server, socket.socket, + ('127.0.0.1', port), + protos, on_data) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_max_connections(self): + protos = [] + on_data = asyncio.Event(loop=self.loop) + + class MaxConnTestProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + protos.append(self) + def data_received(self, data): + super().data_received(data) + on_data.set() + + factory = lambda: MaxConnTestProto(loop=self.loop) + server, path = self._make_unix_server(factory, max_connections=2) + socket_factory = lambda: socket.socket(socket.AF_UNIX) + self._test_create_server_max_connections(server, socket_factory, path, + protos, on_data) + + def _test_create_server_max_connections(self, server, socket_factory, + connect_to, protos, on_data): + sock_fd = server.sockets[0].fileno() + + # Low water.. + c1 = socket_factory() + c1.connect(connect_to) + c1.sendall(b'x') + self.loop.run_until_complete(on_data.wait()) + on_data.clear() + self.assertFalse(server._paused) + self.loop._selector.get_key(sock_fd) # has reader + + # High water.. + c2 = socket_factory() + c2.connect(connect_to) + c2.sendall(b'x') + self.loop.run_until_complete(on_data.wait()) + on_data.clear() + self.assertEqual(server._active_count, 2) + self.assertTrue(server._paused) + self.assertRaises(KeyError, self.loop._selector.get_key, sock_fd) + + # Low water again.. + p = protos.pop(0) + p.transport.close() + self.loop.run_until_complete(p.done) + self.assertFalse(server._paused) + self.loop._selector.get_key(sock_fd) # has reader + + # cleanup + p = protos.pop(0) + p.transport.close() + self.loop.run_until_complete(p.done) + c1.close() + c2.close() + server.close() + self.assertFalse(protos) + + def test_create_server_pause_resume(self): + f = self.loop.create_server(lambda: None, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock_fd = server.sockets[0].fileno() + self._test_create_server_pause_resume(server, sock_fd) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_pause_resume(self): + server, path = self._make_unix_server(lambda: None) + sock_fd = server.sockets[0].fileno() + self._test_create_server_pause_resume(server, sock_fd) + + def _test_create_server_pause_resume(self, server, sock_fd): + server.pause() + self.assertTrue(server._paused) + self.assertRaises(KeyError, self.loop._selector.get_key, sock_fd) + self.assertRaises(AssertionError, server.pause) + + server.resume() + self.assertFalse(server._paused) + self.loop._selector.get_key(sock_fd) # has reader + self.assertRaises(AssertionError, server.resume) + + server.close() + def test_server_close(self): f = self.loop.create_server(MyProto, '0.0.0.0', 0) server = self.loop.run_until_complete(f) @@ -2162,6 +2265,12 @@ def test_create_datagram_endpoint(self): def test_remove_fds_after_closing(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_create_server_max_connections(self): + raise unittest.SkipTest("IocpEventLoop incompatible with max_connections") + + def test_create_server_pause_resume(self): + raise unittest.SkipTest("IocpEventLoop incompatible with Server pause") else: from asyncio import selectors