Skip to content

Commit

Permalink
Add body_producer argument to httpclient.HTTPRequest.
Browse files Browse the repository at this point in the history
This allows for sending non-contiguous or asynchronously-produced
request bodies, including chunked encoding when the content-length
is not known in advance.
  • Loading branch information
bdarnell committed Mar 29, 2014
1 parent e63b4cb commit 681e51b
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 32 deletions.
61 changes: 37 additions & 24 deletions tornado/http1connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ class HTTP1Connection(object):
We parse HTTP headers and bodies, and execute the request callback
until the HTTP connection is closed.
"""
def __init__(self, stream, address, no_keep_alive=False, protocol=None):
def __init__(self, stream, address, is_client,
no_keep_alive=False, protocol=None):
self.is_client = is_client
self.stream = stream
self.address = address
# Save the socket's address family now so we know how to
Expand Down Expand Up @@ -83,7 +85,7 @@ def _server_request_loop(self, delegate, gzip=False):
if gzip:
request_delegate = _GzipMessageDelegate(request_delegate)
try:
ret = yield self._read_message(request_delegate, False)
ret = yield self._read_message(request_delegate)
except iostream.StreamClosedError:
self.close()
return
Expand All @@ -93,18 +95,18 @@ def _server_request_loop(self, delegate, gzip=False):
def read_response(self, delegate, method, use_gzip=False):
    """Read one HTTP response and deliver it to ``delegate``.

    :arg delegate: message delegate that receives the parsed response.
    :arg method: HTTP method of the request being answered; it is
        forwarded to ``_read_message``, which skips the body for ``HEAD``.
    :arg use_gzip: if true, ``delegate`` is wrapped in a
        ``_GzipMessageDelegate`` before reading.
    :returns: the result of ``self._read_message``.
    """
    if use_gzip:
        delegate = _GzipMessageDelegate(delegate)
    # The client/server role is stored on ``self.is_client``, so it is no
    # longer passed positionally; doing so would collide with ``method``.
    return self._read_message(delegate, method=method)

@gen.coroutine
def _read_message(self, delegate, is_client, method=None):
def _read_message(self, delegate, method=None):
assert isinstance(delegate, httputil.HTTPMessageDelegate)
self.message_delegate = delegate
try:
header_data = yield self.stream.read_until_regex(b"\r?\n\r?\n")
self._reading = True
self._finish_future = Future()
start_line, headers = self._parse_headers(header_data)
if is_client:
if self.is_client:
start_line = httputil.parse_response_start_line(start_line)
else:
start_line = httputil.parse_request_start_line(start_line)
Expand All @@ -120,7 +122,7 @@ def _read_message(self, delegate, is_client, method=None):
# TODO: where else do we need to check for detach?
raise gen.Return(False)
skip_body = False
if is_client:
if self.is_client:
if method == 'HEAD':
skip_body = True
code = start_line.code
Expand All @@ -130,12 +132,12 @@ def _read_message(self, delegate, is_client, method=None):
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
yield self._read_message(self.message_delegate,
is_client, method=method)
method=method)
else:
if headers.get("Expect") == "100-continue":
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(is_client, headers)
body_future = self._read_body(headers)
if body_future is not None:
yield body_future
self._reading = False
Expand Down Expand Up @@ -194,19 +196,29 @@ def detach(self):
self.stream = None
return stream

def write_headers(self, start_line, headers, chunk=None, callback=None):
self._chunking = (
# TODO: should this use self._version or start_line.version?
self._version == 'HTTP/1.1' and
# 304 responses have no body (not even a zero-length body), and so
# should not have either Content-Length or Transfer-Encoding
# headers.
start_line.code != 304 and
# No need to chunk the output if a Content-Length is specified.
'Content-Length' not in headers and
# Applications are discouraged from touching Transfer-Encoding,
# but if they do, leave it alone.
'Transfer-Encoding' not in headers)
def write_headers(self, start_line, headers, chunk=None, callback=None,
has_body=True):
if self.is_client:
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking = (
has_body and
'Content-Length' not in headers and
'Transfer-Encoding' not in headers)
else:
self._chunking = (
has_body and
# TODO: should this use self._version or start_line.version?
self._version == 'HTTP/1.1' and
# 304 responses have no body (not even a zero-length body), and so
# should not have either Content-Length or Transfer-Encoding
# headers.
start_line.code != 304 and
# No need to chunk the output if a Content-Length is specified.
'Content-Length' not in headers and
# Applications are discouraged from touching Transfer-Encoding,
# but if they do, leave it alone.
'Transfer-Encoding' not in headers)
if self._chunking:
headers['Transfer-Encoding'] = 'chunked'
lines = [utf8("%s %s %s" % start_line)]
Expand Down Expand Up @@ -293,7 +305,8 @@ def _finish_request(self):
# Turn Nagle's algorithm back on, leaving the stream in its
# default state for the next request.
self.stream.set_nodelay(False)
self._finish_future.set_result(None)
if self._finish_future is not None:
self._finish_future.set_result(None)

def _parse_headers(self, data):
data = native_str(data.decode('latin1'))
Expand All @@ -307,7 +320,7 @@ def _parse_headers(self, data):
data[eol:100])
return start_line, headers

def _read_body(self, is_client, headers):
def _read_body(self, headers):
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
Expand All @@ -316,7 +329,7 @@ def _read_body(self, is_client, headers):
return self._read_fixed_body(content_length)
if headers.get("Transfer-Encoding") == "chunked":
return self._read_chunked_body()
if is_client:
if self.is_client:
return self._read_body_until_close()
return None

Expand Down
22 changes: 20 additions & 2 deletions tornado/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,23 @@ def __init__(self, url, method="GET", headers=None, body=None,
proxy_password=None, allow_nonstandard_methods=None,
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None):
client_key=None, client_cert=None, body_producer=None):
r"""All parameters except ``url`` are optional.
:arg string url: URL to fetch
:arg string method: HTTP method, e.g. "GET" or "POST"
:arg headers: Additional HTTP headers to pass on the request
:arg body: HTTP body to pass on the request
:type headers: `~tornado.httputil.HTTPHeaders` or `dict`
:arg body: HTTP request body as a string (byte or unicode; if unicode
the utf-8 encoding will be used)
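:arg body_producer: Callable used for lazy/asynchronous request bodies.
It is called with one argument, a ``write`` function, and should call
``write`` with each chunk of the body as it becomes available. It may
return a `.Future`; if it does, the request is not finished until
that future resolves.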
Only one of ``body`` and ``body_producer`` may
be specified. ``body_producer`` is not supported on
``curl_httpclient``. When using ``body_producer`` it is recommended
to pass a ``Content-Length`` in the headers as otherwise chunked
encoding will be used, and many servers do not support chunked
encoding on requests.
:arg string auth_username: Username for HTTP authentication
:arg string auth_password: Password for HTTP authentication
:arg string auth_mode: Authentication mode; default is "basic".
Expand Down Expand Up @@ -348,6 +357,7 @@ def __init__(self, url, method="GET", headers=None, body=None,
self.url = url
self.method = method
self.body = body
self.body_producer = body_producer
self.auth_username = auth_username
self.auth_password = auth_password
self.auth_mode = auth_mode
Expand Down Expand Up @@ -388,6 +398,14 @@ def body(self):
def body(self, value):
self._body = utf8(value)

@property
def body_producer(self):
    """The body-producing callable stored on this request."""
    return self._body_producer

@body_producer.setter
def body_producer(self, value):
    # Store the result of stack_context.wrap rather than the raw callable.
    wrapped_producer = stack_context.wrap(value)
    self._body_producer = wrapped_producer

@property
def streaming_callback(self):
return self._streaming_callback
Expand Down
2 changes: 1 addition & 1 deletion tornado/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __init__(self, request_callback, no_keep_alive=False, io_loop=None,
**kwargs)

def handle_stream(self, stream, address):
    """Wrap ``stream`` in a server-side connection and start serving on it.

    :arg stream: the stream handed to ``HTTP1Connection``.
    :arg address: the peer address passed through to the connection.
    """
    # ``is_client=False`` marks this end as the server side, so incoming
    # start lines are parsed as requests rather than responses.
    conn = HTTP1Connection(stream, address=address, is_client=False,
                           no_keep_alive=self.no_keep_alive,
                           protocol=self.protocol)
    conn.start_serving(self, gzip=self.gzip)
Expand Down
30 changes: 26 additions & 4 deletions tornado/simple_httpclient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement

from tornado.concurrent import is_future
from tornado.escape import utf8, _unicode
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
from tornado import httputil
Expand Down Expand Up @@ -303,16 +304,20 @@ def _on_connect(self):
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
if self.request.method in ("POST", "PATCH", "PUT"):
if self.request.body is None:
if (self.request.body is None and
self.request.body_producer is None):
raise AssertionError(
'Body must not be empty for "%s" request'
% self.request.method)
else:
if self.request.body is not None:
if (self.request.body is not None or
self.request.body_producer is not None):
raise AssertionError(
'Body must be empty for "%s" request'
% self.request.method)
if self.request.body is not None:
# When body_producer is used the caller is responsible for
# setting Content-Length (or else chunked encoding will be used).
self.request.headers["Content-Length"] = str(len(
self.request.body))
if (self.request.method == "POST" and
Expand All @@ -324,13 +329,30 @@ def _on_connect(self):
(('?' + self.parsed.query) if self.parsed.query else ''))
self.stream.set_nodelay(True)
self.connection = HTTP1Connection(
self.stream, self._sockaddr,
self.stream, self._sockaddr, is_client=True,
no_keep_alive=True, protocol=self.parsed.scheme)
start_line = httputil.RequestStartLine(self.request.method,
req_path, 'HTTP/1.1')
self.connection.write_headers(start_line, self.request.headers)
self.connection.write_headers(
start_line, self.request.headers,
has_body=(self.request.body is not None or
self.request.body_producer is not None))
if self.request.body is not None:
self.connection.write(self.request.body)
self.connection.finish()
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if is_future(fut):
def on_body_written(fut):
fut.result()
self.connection.finish()
self._read_response()
self.io_loop.add_future(fut, on_body_written)
return
self.connection.finish()
self._read_response()

def _read_response(self):
# Ensure that any exception raised in read_response ends up in our
# stack context.
self.io_loop.add_future(
Expand Down
2 changes: 1 addition & 1 deletion tornado/test/httpserver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def data_received(self, chunk):

def finish(self):
callback(b''.join(chunks))
conn = HTTP1Connection(stream, None)
conn = HTTP1Connection(stream, None, is_client=True)
conn.read_response(Delegate(), method='GET')


Expand Down
44 changes: 44 additions & 0 deletions tornado/test/simple_httpclient_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def get(self):
stream.close()


class EchoPostHandler(RequestHandler):
    """Test handler that writes the POST request body back as the response."""

    def post(self):
        request_body = self.request.body
        self.write(request_body)


class SimpleHTTPClientTestMixin(object):
def get_app(self):
# callable objects to finish pending /trigger requests
Expand All @@ -126,6 +131,7 @@ def get_app(self):
url("/see_other_get", SeeOtherGetHandler),
url("/host_echo", HostEchoHandler),
url("/no_content_length", NoContentLengthHandler),
url("/echo_post", EchoPostHandler),
], gzip=True)

def test_singleton(self):
Expand Down Expand Up @@ -331,6 +337,44 @@ def test_no_content_length(self):
response = self.fetch("/no_content_length")
self.assertEquals(b"hello", response.body)

def sync_body_producer(self, write):
    """Synchronously emit the 8-byte test body as two 4-byte chunks."""
    for piece in (b'1234', b'5678'):
        write(piece)

@gen.coroutine
def async_body_producer(self, write):
    """Emit the test body in two chunks with an IOLoop callback in between."""
    # TODO: write should return a Future.
    # wrap it in simple_httpclient or change http1connection?
    first_chunk, second_chunk = b'1234', b'5678'
    yield gen.Task(write, first_chunk)
    # Wait for an IOLoop callback before sending the second chunk.
    yield gen.Task(IOLoop.current().add_callback)
    yield gen.Task(write, second_chunk)

def test_sync_body_producer_chunked(self):
    """POST a synchronously produced body without a Content-Length header."""
    producer = self.sync_body_producer
    resp = self.fetch("/echo_post", method="POST", body_producer=producer)
    resp.rethrow()
    self.assertEqual(resp.body, b"12345678")

def test_sync_body_producer_content_length(self):
    """POST a synchronously produced body with an explicit Content-Length."""
    request_headers = {'Content-Length': '8'}
    resp = self.fetch("/echo_post", method="POST",
                      body_producer=self.sync_body_producer,
                      headers=request_headers)
    resp.rethrow()
    self.assertEqual(resp.body, b"12345678")

def test_async_body_producer_chunked(self):
    """POST a coroutine-produced body without a Content-Length header."""
    producer = self.async_body_producer
    resp = self.fetch("/echo_post", method="POST", body_producer=producer)
    resp.rethrow()
    self.assertEqual(resp.body, b"12345678")

def test_async_body_producer_content_length(self):
    """POST a coroutine-produced body with an explicit Content-Length."""
    request_headers = {'Content-Length': '8'}
    resp = self.fetch("/echo_post", method="POST",
                      body_producer=self.async_body_producer,
                      headers=request_headers)
    resp.rethrow()
    self.assertEqual(resp.body, b"12345678")


class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
Expand Down

0 comments on commit 681e51b

Please sign in to comment.