Skip to content

Commit

Permalink
Run more of web_test in wsgi_test. Fix a bug with 304 in wsgi.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdarnell committed Sep 10, 2012
1 parent 070e08a commit 7610920
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 67 deletions.
136 changes: 77 additions & 59 deletions tornado/test/web_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,31 @@
import socket
import sys

wsgi_safe = []

class SimpleHandlerTestCase(AsyncHTTPTestCase):
class WebTestCase(AsyncHTTPTestCase):
"""Base class for web tests that also supports WSGI mode.
Override get_handlers and get_app_kwargs instead of get_app.
Append to wsgi_safe to have it run in wsgi_test as well.
"""
def get_app(self):
self.app = Application(self.get_handlers(), **self.get_app_kwargs())
return self.app

def get_handlers(self):
raise NotImplementedError()

def get_app_kwargs(self):
return {}

class SimpleHandlerTestCase(WebTestCase):
"""Simplified base class for tests that work with a single handler class.
To use, define a nested class named ``Handler``.
"""
def get_app(self):
return Application([('/', self.Handler)],
log_function=lambda x: None)
def get_handlers(self):
return [('/', self.Handler)]


class CookieTestRequestHandler(RequestHandler):
Expand Down Expand Up @@ -82,8 +98,8 @@ def test_arbitrary_bytes(self):
self.assertEqual(handler.get_secure_cookie('foo'), b('\xe9'))


class CookieTest(AsyncHTTPTestCase):
def get_app(self):
class CookieTest(WebTestCase):
def get_handlers(self):
class SetCookieHandler(RequestHandler):
def get(self):
# Try setting cookies with different argument types
Expand Down Expand Up @@ -117,13 +133,12 @@ def get(self):
# Attributes from the first call are not carried over.
self.set_cookie("a", "e")

return Application([
("/set", SetCookieHandler),
return [("/set", SetCookieHandler),
("/get", GetCookieHandler),
("/set_domain", SetCookieDomainHandler),
("/special_char", SetCookieSpecialCharHandler),
("/set_overwrite", SetCookieOverwriteHandler),
])
]

def test_set_cookie(self):
response = self.fetch("/set")
Expand Down Expand Up @@ -191,12 +206,12 @@ def get(self):
self.send_error(500)


class AuthRedirectTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/relative', AuthRedirectRequestHandler,
dict(login_url='/login')),
('/absolute', AuthRedirectRequestHandler,
dict(login_url='http://example.com/login'))])
class AuthRedirectTest(WebTestCase):
def get_handlers(self):
return [('/relative', AuthRedirectRequestHandler,
dict(login_url='/login')),
('/absolute', AuthRedirectRequestHandler,
dict(login_url='http://example.com/login'))]

def test_relative_auth_redirect(self):
self.http_client.fetch(self.get_url('/relative'), self.stop,
Expand Down Expand Up @@ -227,9 +242,9 @@ def on_connection_close(self):
self.test.on_connection_close()


class ConnectionCloseTest(AsyncHTTPTestCase):
def get_app(self):
return Application([('/', ConnectionCloseHandler, dict(test=self))])
class ConnectionCloseTest(WebTestCase):
def get_handlers(self):
return [('/', ConnectionCloseHandler, dict(test=self))]

def test_connection_close(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
Expand Down Expand Up @@ -272,12 +287,11 @@ def get(self, *path_args):
args=recursive_unicode(self.request.arguments)))


class RequestEncodingTest(AsyncHTTPTestCase):
def get_app(self):
return Application([
("/group/(.*)", EchoHandler),
class RequestEncodingTest(WebTestCase):
def get_handlers(self):
return [("/group/(.*)", EchoHandler),
("/slashes/([^/]*)/([^/]*)", EchoHandler),
])
]

def fetch_json(self, path):
return json_decode(self.fetch(path).body)
Expand Down Expand Up @@ -457,13 +471,9 @@ def get(self):


# This test is shared with wsgi_test.py
class WSGISafeWebTest(AsyncHTTPTestCase):
class WSGISafeWebTest(WebTestCase):
COOKIE_SECRET = "WebTest.COOKIE_SECRET"

def get_app(self):
self.app = Application(self.get_handlers(), **self.get_app_kwargs())
return self.app

def get_app_kwargs(self):
loader = DictLoader({
"linkify.html": "{% module linkify(message) %}",
Expand Down Expand Up @@ -604,15 +614,14 @@ def test_get_argument(self):
self.assertEqual(response.body, b(""))
response = self.fetch("/get_argument")
self.assertEqual(response.body, b("default"))
wsgi_safe.append(WSGISafeWebTest)


class NonWSGIWebTests(AsyncHTTPTestCase):
def get_app(self):
urls = [
("/flow_control", FlowControlHandler),
("/empty_flush", EmptyFlushCallbackHandler),
]
return Application(urls)
class NonWSGIWebTests(WebTestCase):
def get_handlers(self):
return [("/flow_control", FlowControlHandler),
("/empty_flush", EmptyFlushCallbackHandler),
]

def test_flow_control(self):
self.assertEqual(self.fetch("/flow_control").body, b("123"))
Expand All @@ -622,8 +631,8 @@ def test_empty_flush(self):
self.assertEqual(response.body, b("ok"))


class ErrorResponseTest(AsyncHTTPTestCase):
def get_app(self):
class ErrorResponseTest(WebTestCase):
def get_handlers(self):
class DefaultHandler(RequestHandler):
def get(self):
if self.get_argument("status", None):
Expand Down Expand Up @@ -665,12 +674,11 @@ def get(self):
def write_error(self, status_code, **kwargs):
raise Exception("exception in write_error")

return Application([
url("/default", DefaultHandler),
return [url("/default", DefaultHandler),
url("/write_error", WriteErrorHandler),
url("/get_error_html", GetErrorHtmlHandler),
url("/failed_write_error", FailedWriteErrorHandler),
])
]

def test_default(self):
with ExpectLog(app_log, "Uncaught exception"):
Expand Down Expand Up @@ -707,10 +715,10 @@ def test_failed_write_error(self):
response = self.fetch("/failed_write_error")
self.assertEqual(response.code, 500)
self.assertEqual(b(""), response.body)
wsgi_safe.append(ErrorResponseTest)


class StaticFileTest(AsyncHTTPTestCase):
def get_app(self):
class StaticFileTest(WebTestCase):
def get_handlers(self):
class StaticUrlHandler(RequestHandler):
def get(self, path):
self.write(self.static_url(path))
Expand Down Expand Up @@ -742,10 +750,13 @@ def get(self, path):
result = (check_override == -1 and check_regular == 0)
self.write(str(result))

return Application([('/static_url/(.*)', StaticUrlHandler),
('/abs_static_url/(.*)', AbsoluteStaticUrlHandler),
('/override_static_url/(.*)', OverrideStaticUrlHandler)],
static_path=os.path.join(os.path.dirname(__file__), 'static'))
return [('/static_url/(.*)', StaticUrlHandler),
('/abs_static_url/(.*)', AbsoluteStaticUrlHandler),
('/override_static_url/(.*)', OverrideStaticUrlHandler)]

def get_app_kwargs(self):
return dict(static_path=os.path.join(os.path.dirname(__file__),
'static'))

def test_static_files(self):
response = self.fetch('/robots.txt')
Expand Down Expand Up @@ -779,10 +790,10 @@ def test_static_304(self):
self.assertEqual(response2.code, 304)
self.assertTrue('Content-Length' not in response2.headers)
self.assertTrue('Last-Modified' not in response2.headers)
wsgi_safe.append(StaticFileTest)


class CustomStaticFileTest(AsyncHTTPTestCase):
def get_app(self):
class CustomStaticFileTest(WebTestCase):
def get_handlers(self):
class MyStaticFileHandler(StaticFileHandler):
def get(self, path):
path = self.parse_url_path(path)
Expand All @@ -809,35 +820,41 @@ class StaticUrlHandler(RequestHandler):
def get(self, path):
self.write(self.static_url(path))

return Application([("/static_url/(.*)", StaticUrlHandler)],
static_path="dummy",
static_handler_class=MyStaticFileHandler)
self.static_handler_class = MyStaticFileHandler

return [("/static_url/(.*)", StaticUrlHandler)]

def get_app_kwargs(self):
return dict(static_path="dummy",
static_handler_class=self.static_handler_class)

def test_serve(self):
response = self.fetch("/static/foo.42.txt")
self.assertEqual(response.body, b("bar"))

def test_static_url(self):
with ExpectLog(gen_log, "Could not open static file"):
with ExpectLog(gen_log, "Could not open static file", required=False):
response = self.fetch("/static_url/foo.txt")
self.assertEqual(response.body, b("/static/foo.42.txt"))
wsgi_safe.append(CustomStaticFileTest)


class NamedURLSpecGroupsTest(AsyncHTTPTestCase):
def get_app(self):
class NamedURLSpecGroupsTest(WebTestCase):
def get_handlers(self):
class EchoHandler(RequestHandler):
def get(self, path):
self.write(path)

return Application([("/str/(?P<path>.*)", EchoHandler),
(u"/unicode/(?P<path>.*)", EchoHandler)])
return [("/str/(?P<path>.*)", EchoHandler),
(u"/unicode/(?P<path>.*)", EchoHandler)]

def test_named_urlspec_groups(self):
response = self.fetch("/str/foo")
self.assertEqual(response.body, b("foo"))

response = self.fetch("/unicode/bar")
self.assertEqual(response.body, b("bar"))
wsgi_safe.append(NamedURLSpecGroupsTest)


class ClearHeaderTest(SimpleHandlerTestCase):
Expand All @@ -852,7 +869,7 @@ def test_clear_header(self):
response = self.fetch("/")
self.assertTrue("h1" not in response.headers)
self.assertEqual(response.headers["h2"], "bar")

wsgi_safe.append(ClearHeaderTest)

class Header304Test(SimpleHandlerTestCase):
class Handler(RequestHandler):
Expand All @@ -872,3 +889,4 @@ def test_304_headers(self):
self.assertTrue("Content-Language" not in response2.headers)
# Not an entity header, but should not be added to 304s by chunking
self.assertTrue("Transfer-Encoding" not in response2.headers)
wsgi_safe.append(Header304Test)
15 changes: 11 additions & 4 deletions tornado/test/wsgi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,14 @@ def get_app(self):
return WSGIContainer(validator(WSGIApplication(self.get_handlers())))


class WSGIWebTest(web_test.WSGISafeWebTest):
def get_app(self):
self.app = WSGIApplication(self.get_handlers(), **self.get_app_kwargs())
return WSGIContainer(validator(self.app))
def wrap_web_tests():
result = {}
for cls in web_test.wsgi_safe:
class WSGIWrappedTest(cls):
def get_app(self):
self.app = WSGIApplication(self.get_handlers(),
**self.get_app_kwargs())
return WSGIContainer(validator(self.app))
result["WSGIWrapped_" + cls.__name__] = WSGIWrappedTest
return result
globals().update(wrap_web_tests())
9 changes: 5 additions & 4 deletions tornado/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,11 @@ def start_response(status, response_headers, exc_info=None):
headers = data["headers"]
header_set = set(k.lower() for (k, v) in headers)
body = escape.utf8(body)
if "content-length" not in header_set:
headers.append(("Content-Length", str(len(body))))
if "content-type" not in header_set:
headers.append(("Content-Type", "text/html; charset=UTF-8"))
if status_code != 304:
if "content-length" not in header_set:
headers.append(("Content-Length", str(len(body))))
if "content-type" not in header_set:
headers.append(("Content-Type", "text/html; charset=UTF-8"))
if "server" not in header_set:
headers.append(("Server", "TornadoServer/%s" % tornado.version))

Expand Down

0 comments on commit 7610920

Please sign in to comment.