Skip to content

Commit

Permalink
Merge pull request tornadoweb#1935 from mirceaulinic/SOURCE-IP
Browse files Browse the repository at this point in the history
TCPClient: connect using specific source IP address and port
  • Loading branch information
bdarnell authored Feb 10, 2017
2 parents f8aab76 + 1e39fd7 commit e704489
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
38 changes: 32 additions & 6 deletions tornado/tcpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,27 @@ def close(self):

@gen.coroutine
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
max_buffer_size=None):
max_buffer_size=None, source_ip=None, source_port=None):
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
``ssl_options`` is not None).
Using the ``source_ip`` kwarg, one can specify the source
IP address to use when establishing the connection.
In case the user needs to resolve and
use a specific interface, it has to be handled outside
of Tornado as this depends very much on the platform.
Similarly, when the user requires a certain source port, it can
be specified using the ``source_port`` arg.
"""
addrinfo = yield self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo, self.io_loop,
functools.partial(self._create_stream, max_buffer_size))
functools.partial(self._create_stream, max_buffer_size,
source_ip=source_ip, source_port=source_port)
)
af, addr, stream = yield connector.start()
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on subsequent connections to
Expand All @@ -174,13 +185,28 @@ def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
server_hostname=host)
raise gen.Return(stream)

def _create_stream(self, max_buffer_size, af, addr):
def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
source_port=None):
# Always connect in plaintext; we'll convert to ssl if necessary
# after one connection has completed.
source_port_bind = source_port if isinstance(source_port, int) else 0
source_ip_bind = source_ip
if source_port_bind and not source_ip:
# User required a specific port, but did not specify
# a certain source IP, will bind to the default loopback.
source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
# Trying to use the same address family as the requested af socket:
# - 127.0.0.1 for IPv4
# - ::1 for IPv6
socket_obj = socket.socket(af)
if source_port_bind or source_ip_bind:
# If the user requires binding also to a specific IP/port.
socket_obj.bind((source_ip_bind, source_port_bind))
# Fail loudly if unable to use the IP/port.
try:
stream = IOStream(socket.socket(af),
io_loop=self.io_loop,
max_buffer_size=max_buffer_size)
stream = IOStream(socket_obj,
io_loop=self.io_loop,
max_buffer_size=max_buffer_size)
except socket.error as e:
fu = Future()
fu.set_exception(e)
Expand Down
35 changes: 32 additions & 3 deletions tornado/test/tcpclient_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port
from tornado.test.util import skipIfNoIPv6, unittest, refusing_port, skipIfNonUnix

# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
Expand Down Expand Up @@ -81,9 +81,11 @@ def skipIfLocalhostV4(self):
self.skipTest("localhost does not resolve to ipv6")

@gen_test
def do_test_connect(self, family, host):
def do_test_connect(self, family, host, source_ip=None, source_port=None):
port = self.start_server(family)
stream = yield self.client.connect(host, port)
stream = yield self.client.connect(host, port,
source_ip=source_ip,
source_port=source_port)
with closing(stream):
stream.write(b"hello")
data = yield self.server.streams[0].read_bytes(5)
Expand Down Expand Up @@ -125,6 +127,33 @@ def test_refused_ipv4(self):
with self.assertRaises(IOError):
yield self.client.connect('127.0.0.1', port)

def test_source_ip_fail(self):
'''
Fail when trying to use the source IP Address '8.8.8.8'.
'''
self.assertRaises(socket.error,
self.do_test_connect,
socket.AF_INET,
'127.0.0.1',
source_ip='8.8.8.8')

def test_source_ip_success(self):
'''
Success when trying to use the source IP Address '127.0.0.1'
'''
self.do_test_connect(socket.AF_INET, '127.0.0.1', source_ip='127.0.0.1')

@skipIfNonUnix
def test_source_port_fail(self):
'''
Fail when trying to use source port 1.
'''
self.assertRaises(socket.error,
self.do_test_connect,
socket.AF_INET,
'127.0.0.1',
source_port=1)


class TestConnectorSplit(unittest.TestCase):
def test_one_family(self):
Expand Down

0 comments on commit e704489

Please sign in to comment.