Skip to content

Commit

Permalink
Allow preconstructed HTTPRequest objects in websocket_connect.
Browse files Browse the repository at this point in the history
In particular this allows for headers to be passed in to simulate
browser authentication behavior.
  • Loading branch information
bdarnell committed Sep 8, 2013
1 parent 0352fe0 commit b5ec807
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
18 changes: 16 additions & 2 deletions tornado/test/websocket_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.web import Application, RequestHandler
Expand All @@ -18,6 +17,11 @@ def on_close(self):
self.close_future.set_result(None)


class HeaderHandler(WebSocketHandler):
def open(self):
self.write_message(self.request.headers.get('X-Test', ''))


class NonWebSocketHandler(RequestHandler):
def get(self):
self.write('ok')
Expand All @@ -29,6 +33,7 @@ def get_app(self):
return Application([
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
('/header', HeaderHandler),
])

@gen_test
Expand Down Expand Up @@ -85,3 +90,12 @@ def test_websocket_close_buffered_data(self):
ws.write_message('world')
ws.stream.close()
yield self.close_future

@gen_test
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
headers={'X-Test': 'hello'}))
response = yield ws.read_message()
self.assertEqual(response, 'hello')
11 changes: 9 additions & 2 deletions tornado/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from tornado.concurrent import TracebackFuture
from tornado.escape import utf8, native_str
from tornado import httpclient
from tornado import httpclient, httputil
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
Expand Down Expand Up @@ -862,7 +862,14 @@ def websocket_connect(url, io_loop=None, callback=None, connect_timeout=None):
"""
if io_loop is None:
io_loop = IOLoop.current()
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
if isinstance(url, httpclient.HTTPRequest):
assert connect_timeout is None
request = url
# Copy and convert the headers dict/object (see comments in
# AsyncHTTPClient.fetch)
request.headers = httputil.HTTPHeaders(request.headers)
else:
request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
request = httpclient._RequestProxy(
request, httpclient.HTTPRequest._DEFAULTS)
conn = WebSocketClientConnection(io_loop, request)
Expand Down

0 comments on commit b5ec807

Please sign in to comment.