Skip to content

Commit

Permalink
=htc akka#15799 implement client-side Expect: 100-continue support
Browse files Browse the repository at this point in the history
  • Loading branch information
sirthias committed Mar 9, 2016
1 parent 47925e1 commit e3ee285
Show file tree
Hide file tree
Showing 6 changed files with 308 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@ package akka.http.impl.engine.client

import akka.NotUsed
import akka.http.scaladsl.settings.{ ClientConnectionSettings, ParserSettings }
import akka.stream.impl.ConstantFun
import language.existentials
import scala.annotation.tailrec
import scala.concurrent.Promise
import scala.collection.mutable.ListBuffer
import akka.stream.TLSProtocol._
import akka.util.ByteString
import akka.event.LoggingAdapter
import akka.stream._
import akka.stream.scaladsl._
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.headers.Host
import akka.http.scaladsl.model.{ IllegalResponseException, HttpMethod, HttpRequest, HttpResponse, ResponseEntity }
import akka.http.scaladsl.model.headers
import akka.http.scaladsl.model.{ IllegalResponseException, HttpRequest, HttpResponse, ResponseEntity }
import akka.http.impl.engine.rendering.{ RequestRenderingContext, HttpRequestRendererFactory }
import akka.http.impl.engine.parsing._
import akka.http.impl.util._
Expand All @@ -28,6 +30,9 @@ import akka.stream.stage.{ InHandler, OutHandler }
* INTERNAL API
*/
private[http] object OutgoingConnectionBlueprint {

type BypassData = HttpResponseParser.ResponseContext

/*
Stream Setup
============
Expand All @@ -38,62 +43,75 @@ private[http] object OutgoingConnectionBlueprint {
+-------------------------------------->| Merge | |
| Termination Backchannel | +----------+ | TCP-
| | | level
| | Method | client
| +------------+ | Bypass | flow
| | BypassData | client
| +------------+ | | flow
responseOut | responsePrep | Response |<---+ |
<------------+----------------| Parsing | |
| Merge |<------------------------------------------ V
+------------+
*/
def apply(hostHeader: Host,
def apply(hostHeader: headers.Host,
settings: ClientConnectionSettings,
log: LoggingAdapter): Http.ClientLayer = {
import settings._

// the initial header parser we initially use for every connection,
// will not be mutated, all "shared copy" parsers copy on first-write into the header cache
val rootParser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings) { info
if (parserSettings.illegalHeaderWarnings)
logParsingError(info withSummaryPrepended "Illegal response header", log, parserSettings.errorLoggingVerbosity)
})
val core = BidiFlow.fromGraph(GraphDSL.create() { implicit b
import GraphDSL.Implicits._

val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log)
val renderingContextCreation = b.add {
Flow[HttpRequest] map { request
val sendEntityTrigger =
request.headers collectFirst { case headers.Expect.`100-continue` Promise[NotUsed]().future }
RequestRenderingContext(request, hostHeader, sendEntityTrigger)
}
}

val requestRendering: Flow[HttpRequest, ByteString, NotUsed] = Flow[HttpRequest]
.map(RequestRenderingContext(_, hostHeader))
.via(Flow[RequestRenderingContext].flatMapConcat(requestRendererFactory.renderToSource).named("renderer"))
val bypassFanout = b.add(Broadcast[RequestRenderingContext](2, eagerCancel = true))

val methodBypass = Flow[HttpRequest].map(_.method)
val terminationMerge = b.add(TerminationMerge)

import ParserOutput._
val responsePrep = Flow[List[ResponseOutput]]
.mapConcat(conforms)
.via(new PrepareResponse(parserSettings))
val requestRendering: Flow[RequestRenderingContext, ByteString, NotUsed] = {
val requestRendererFactory = new HttpRequestRendererFactory(userAgentHeader, requestHeaderSizeHint, log)
Flow[RequestRenderingContext].flatMapConcat(requestRendererFactory.renderToSource).named("renderer")
}

val core = BidiFlow.fromGraph(GraphDSL.create() { implicit b
import GraphDSL.Implicits._
val methodBypassFanout = b.add(Broadcast[HttpRequest](2, eagerCancel = true))
val responseParsingMerge = b.add(new ResponseParsingMerge(rootParser))
val bypass = Flow[RequestRenderingContext] map { ctx
HttpResponseParser.ResponseContext(ctx.request.method, ctx.sendEntityTrigger.map(_.asInstanceOf[Promise[Unit]]))
}

val responseParsingMerge = b.add {
// the initial header parser we initially use for every connection,
// will not be mutated, all "shared copy" parsers copy on first-write into the header cache
val rootParser = new HttpResponseParser(parserSettings, HttpHeaderParser(parserSettings) { info
if (parserSettings.illegalHeaderWarnings)
logParsingError(info withSummaryPrepended "Illegal response header", log, parserSettings.errorLoggingVerbosity)
})
new ResponseParsingMerge(rootParser)
}

val responsePrep = Flow[List[ParserOutput.ResponseOutput]]
.mapConcat(ConstantFun.scalaIdentityFunction)
.via(new PrepareResponse(parserSettings))

val terminationFanout = b.add(Broadcast[HttpResponse](2))
val terminationMerge = b.add(TerminationMerge)

val logger = b.add(MapError[ByteString] { case t log.error(t, "Outgoing request stream error"); t }.named("errorLogger"))
val wrapTls = b.add(Flow[ByteString].map(SendBytes))
terminationMerge.out ~> requestRendering ~> logger ~> wrapTls

val collectSessionBytes = b.add(Flow[SslTlsInbound].collect { case s: SessionBytes s })
collectSessionBytes ~> responseParsingMerge.in0

methodBypassFanout.out(0) ~> terminationMerge.in0
renderingContextCreation.out ~> bypassFanout.in
bypassFanout.out(0) ~> terminationMerge.in0
terminationMerge.out ~> requestRendering ~> logger ~> wrapTls

methodBypassFanout.out(1) ~> methodBypass ~> responseParsingMerge.in1
bypassFanout.out(1) ~> bypass ~> responseParsingMerge.in1
collectSessionBytes ~> responseParsingMerge.in0

responseParsingMerge.out ~> responsePrep ~> terminationFanout.in
terminationFanout.out(0) ~> terminationMerge.in1

BidiShape(
methodBypassFanout.in,
renderingContextCreation.in,
wrapTls.out,
collectSessionBytes.in,
terminationFanout.out(1))
Expand All @@ -104,10 +122,10 @@ private[http] object OutgoingConnectionBlueprint {

// a simple merge stage that simply forwards its first input and ignores its second input
// (the terminationBackchannelInput), but applies a special completion handling
private object TerminationMerge extends GraphStage[FanInShape2[HttpRequest, HttpResponse, HttpRequest]] {
private val requests = Inlet[HttpRequest]("requests")
private object TerminationMerge extends GraphStage[FanInShape2[RequestRenderingContext, HttpResponse, RequestRenderingContext]] {
private val requests = Inlet[RequestRenderingContext]("requests")
private val responses = Inlet[HttpResponse]("responses")
private val out = Outlet[HttpRequest]("out")
private val out = Outlet[RequestRenderingContext]("out")

override def initialAttributes = Attributes.name("TerminationMerge")

Expand Down Expand Up @@ -162,9 +180,10 @@ private[http] object OutgoingConnectionBlueprint {
}

def onPush(): Unit = grab(in) match {
case ResponseStart(statusCode, protocol, headers, entityCreator, _)
case ResponseStart(statusCode, protocol, headers, entityCreator, closeRequested)
val entity = createEntity(entityCreator) withSizeLimit parserSettings.maxContentLength
push(out, HttpResponse(statusCode, headers, entity, protocol))
if (closeRequested) completeStage()

case MessageStartError(_, info)
throw IllegalResponseException(info)
Expand Down Expand Up @@ -259,25 +278,26 @@ private[http] object OutgoingConnectionBlueprint {
* 2. Read from the dataInput until exactly one response has been fully received
* 3. Go back to 1.
*/
class ResponseParsingMerge(rootParser: HttpResponseParser) extends GraphStage[FanInShape2[SessionBytes, HttpMethod, List[ResponseOutput]]] {
private class ResponseParsingMerge(rootParser: HttpResponseParser)
extends GraphStage[FanInShape2[SessionBytes, BypassData, List[ResponseOutput]]] {
private val dataInput = Inlet[SessionBytes]("data")
private val methodBypassInput = Inlet[HttpMethod]("method")
private val bypassInput = Inlet[BypassData]("request")
private val out = Outlet[List[ResponseOutput]]("out")

override def initialAttributes = Attributes.name("ResponseParsingMerge")

val shape = new FanInShape2(dataInput, methodBypassInput, out)
val shape = new FanInShape2(dataInput, bypassInput, out)

override def createLogic(effectiveAttributes: Attributes) = new GraphStageLogic(shape) {
// each connection uses a single (private) response parser instance for all its responses
// which builds a cache of all header instances seen on that connection
val parser = rootParser.createShallowCopy()
var waitingForMethod = true

setHandler(methodBypassInput, new InHandler {
setHandler(bypassInput, new InHandler {
override def onPush(): Unit = {
val method = grab(methodBypassInput)
parser.setRequestMethodForNextResponse(method)
val responseContext = grab(bypassInput)
parser.setContextForNextResponse(responseContext)
val output = parser.parseBytes(ByteString.empty)
drainParser(output)
}
Expand Down Expand Up @@ -306,8 +326,8 @@ private[http] object OutgoingConnectionBlueprint {

val getNextMethod = () {
waitingForMethod = true
if (isClosed(methodBypassInput)) completeStage()
else pull(methodBypassInput)
if (isClosed(bypassInput)) completeStage()
else pull(bypassInput)
}

val getNextData = () {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package akka.http.impl.engine.parsing

import scala.annotation.tailrec
import scala.concurrent.Promise
import scala.util.control.NoStackTrace
import akka.http.scaladsl.settings.ParserSettings
import akka.http.impl.model.parser.CharacterClasses
import akka.util.ByteString
Expand All @@ -17,19 +19,20 @@ import ParserOutput._
*/
private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser: HttpHeaderParser)
extends HttpMessageParser[ResponseOutput](_settings, _headerParser) {
import HttpResponseParser._
import HttpMessageParser._
import settings._

private[this] var requestMethodForCurrentResponse: Option[HttpMethod] = None
private[this] var contextForCurrentResponse: Option[ResponseContext] = None
private[this] var statusCode: StatusCode = StatusCodes.OK

def createShallowCopy(): HttpResponseParser = new HttpResponseParser(settings, headerParser.createShallowCopy())

def setRequestMethodForNextResponse(requestMethod: HttpMethod): Unit =
if (requestMethodForCurrentResponse.isEmpty) requestMethodForCurrentResponse = Some(requestMethod)
def setContextForNextResponse(responseContext: ResponseContext): Unit =
if (contextForCurrentResponse.isEmpty) contextForCurrentResponse = Some(responseContext)

protected def parseMessage(input: ByteString, offset: Int): StateResult =
if (requestMethodForCurrentResponse.isDefined) {
if (contextForCurrentResponse.isDefined) {
var cursor = parseProtocol(input, offset)
if (byteChar(input, cursor) == ' ') {
cursor = parseStatus(input, cursor + 1)
Expand All @@ -41,7 +44,7 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser:
}

override def emit(output: ResponseOutput): Unit = {
if (output == MessageEnd) requestMethodForCurrentResponse = None
if (output == MessageEnd) contextForCurrentResponse = None
super.emit(output)
}

Expand Down Expand Up @@ -78,21 +81,47 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser:
} else badStatusCode
}

def handleInformationalResponses: Boolean = true

// http://tools.ietf.org/html/rfc7230#section-3.3
def parseEntity(headers: List[HttpHeader], protocol: HttpProtocol, input: ByteString, bodyStart: Int,
clh: Option[`Content-Length`], cth: Option[`Content-Type`], teh: Option[`Transfer-Encoding`],
expect100continue: Boolean, hostHeaderPresent: Boolean, closeAfterResponseCompletion: Boolean): StateResult = {

def emitResponseStart(createEntity: EntityCreator[ResponseOutput, ResponseEntity],
headers: List[HttpHeader] = headers) =
emit(ResponseStart(statusCode, protocol, headers, createEntity, closeAfterResponseCompletion))
def finishEmptyResponse() = {
emitResponseStart(emptyEntity(cth))
setCompletionHandling(HttpMessageParser.CompletionOk)
emit(MessageEnd)
startNewMessage(input, bodyStart)
headers: List[HttpHeader] = headers) = {
val close =
contextForCurrentResponse.get.oneHundredContinueTrigger match {
case None closeAfterResponseCompletion
case Some(trigger) if statusCode.isSuccess
trigger.trySuccess(())
closeAfterResponseCompletion
case Some(trigger)
trigger.tryFailure(OneHundredContinueError)
true
}
emit(ResponseStart(statusCode, protocol, headers, createEntity, close))
}

if (statusCode.allowsEntity && (requestMethodForCurrentResponse.get != HttpMethods.HEAD)) {
def finishEmptyResponse() =
statusCode match {
case _: StatusCodes.Informational if handleInformationalResponses
if (statusCode == StatusCodes.Continue)
contextForCurrentResponse.get.oneHundredContinueTrigger.foreach(_.trySuccess(()))

// http://tools.ietf.org/html/rfc7231#section-6.2 says:
// "A client MUST be able to parse one or more 1xx responses received prior to a final response,
// even if the client does not expect one."
// so we simply drop this interim response and start parsing the next one
startNewMessage(input, bodyStart)
case _
emitResponseStart(emptyEntity(cth))
setCompletionHandling(HttpMessageParser.CompletionOk)
emit(MessageEnd)
startNewMessage(input, bodyStart)
}

if (statusCode.allowsEntity && (contextForCurrentResponse.get.requestMethod != HttpMethods.HEAD)) {
teh match {
case None clh match {
case Some(`Content-Length`(contentLength))
Expand Down Expand Up @@ -137,4 +166,19 @@ private[http] class HttpResponseParser(_settings: ParserSettings, _headerParser:
emit(EntityPart(input.drop(bodyStart).compact))
continue(parseToCloseBody(_, _, newTotalBytes))
}
}

private[http] object HttpResponseParser {
/**
* @param requestMethod the request's HTTP method
* @param oneHundredContinueTrigger if the request contains an `Expect: 100-continue` header this option contains
* a promise whose completion either triggers the sending of the (suspended)
* request entity or the closing of the connection (for error completion)
*/
private[http] final case class ResponseContext(requestMethod: HttpMethod,
oneHundredContinueTrigger: Option[Promise[Unit]])

private[http] object OneHundredContinueError
extends RuntimeException("Received error response for request with `Expect: 100-continue` header")
with NoStackTrace
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

package akka.http.impl.engine.rendering

import akka.NotUsed
import akka.http.impl.engine.parsing.HttpResponseParser
import akka.http.scaladsl.settings.ClientConnectionSettings
import akka.http.scaladsl.model.RequestEntityAcceptance._

import scala.concurrent.Future
import scala.annotation.tailrec
import akka.event.LoggingAdapter
import akka.util.ByteString
Expand Down Expand Up @@ -102,8 +105,16 @@ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.`
def renderContentLength(contentLength: Long) =
if (method.isEntityAccepted && (contentLength > 0 || method.requestEntityAcceptance == Expected)) r ~~ `Content-Length` ~~ contentLength ~~ CrLf else r

def renderStreamed(body: Source[ByteString, Any]): RequestRenderingOutput =
RequestRenderingOutput.Streamed(renderByteStrings(r, body))
def renderStreamed(body: Source[ByteString, Any]): RequestRenderingOutput = {
val headerPart = Source.single(r.get)
val stream = ctx.sendEntityTrigger match {
case None headerPart ++ body
case Some(future)
val barrier = Source.fromFuture(future).drop(1).asInstanceOf[Source[ByteString, Any]]
(headerPart ++ barrier ++ body).recoverWith { case HttpResponseParser.OneHundredContinueError Source.empty }
}
RequestRenderingOutput.Streamed(stream)
}

def completeRequestRendering(): RequestRenderingOutput =
entity match {
Expand All @@ -113,7 +124,8 @@ private[http] class HttpRequestRendererFactory(userAgentHeader: Option[headers.`

case HttpEntity.Strict(_, data)
renderContentLength(data.length) ~~ CrLf
RequestRenderingOutput.Strict(r.get ++ data)
if (ctx.sendEntityTrigger.isDefined) renderStreamed(Source.single(data))
else RequestRenderingOutput.Strict(r.get ++ data)

case HttpEntity.Default(_, contentLength, data)
renderContentLength(contentLength) ~~ CrLf
Expand Down Expand Up @@ -155,5 +167,14 @@ private[http] object HttpRequestRendererFactory {

/**
* INTERNAL API
*
* @param request the request to be rendered
* @param hostHeader the host header to render (not necessarily contained in the request.headers)
* @param sendEntityTrigger defined when the request has a `Expect: 100-continue` header; in this case the future will
* be completed successfully when the request entity is allowed to go out onto the wire;
* if the future is completed with an error the connection is to be closed.
*/
private[http] final case class RequestRenderingContext(request: HttpRequest, hostHeader: Host)
private[http] final case class RequestRenderingContext(
request: HttpRequest,
hostHeader: Host,
sendEntityTrigger: Option[Future[NotUsed]] = None)
Loading

0 comments on commit e3ee285

Please sign in to comment.