Skip to content

Commit

Permalink
Public API for duplex streams in MockWebServer (square#7595)
Browse files Browse the repository at this point in the history
* Public API for duplex streams in MockWebServer

I'm not 100% on the name 'Stream' for the source+sink pair. It's
tempting to use 'Socket', though I think that's an implementation name
and this is an abstraction that uses a different implementation.

I've chosen Stream specifically 'cause it's the word used in the
HTTP/2 spec. My biggest gripe with it is that it's bidirectional
in the HTTP/2 spec, but Java InputStream and OutputStream are not
bidirectional.

* Dump APIs for streams

* Don't include a Content-Length header for chunked bodies

* Convert MockWebServerTest to Kotlin (square#7596)

* Rename .java to .kt

* Convert Java to Kotlin

* Null isn't special for last-write wins

* Attempt to make NonCompletingRequestBody less flaky
  • Loading branch information
squarejesse authored Dec 31, 2022
1 parent c30d3ce commit 34bb125
Show file tree
Hide file tree
Showing 14 changed files with 1,000 additions and 938 deletions.
20 changes: 16 additions & 4 deletions mockwebserver/api/mockwebserver3.api
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ public final class mockwebserver3/MockResponse {
public final fun getBody ()Lmockwebserver3/MockResponseBody;
public final fun getBodyDelayNanos ()J
public final fun getCode ()I
public final fun getDuplexResponseBody ()Lmockwebserver3/internal/duplex/DuplexResponseBody;
public final fun getHeaders ()Lokhttp3/Headers;
public final fun getHeadersDelayNanos ()J
public final fun getHttp2ErrorCode ()I
Expand All @@ -30,11 +29,11 @@ public final class mockwebserver3/MockResponse {
public final fun getSettings ()Lokhttp3/internal/http2/Settings;
public final fun getSocketPolicy ()Lmockwebserver3/SocketPolicy;
public final fun getStatus ()Ljava/lang/String;
public final fun getStreamHandler ()Lmockwebserver3/StreamHandler;
public final fun getThrottleBytesPerPeriod ()J
public final fun getThrottlePeriodNanos ()J
public final fun getTrailers ()Lokhttp3/Headers;
public final fun getWebSocketListener ()Lokhttp3/WebSocketListener;
public final fun isDuplex ()Z
public final fun newBuilder ()Lmockwebserver3/MockResponse$Builder;
public fun toString ()Ljava/lang/String;
}
Expand All @@ -48,7 +47,7 @@ public final class mockwebserver3/MockResponse$Builder : java/lang/Cloneable {
public final fun addInformationalResponse (Lmockwebserver3/MockResponse;)Lmockwebserver3/MockResponse$Builder;
public final fun addPush (Lmockwebserver3/PushPromise;)Lmockwebserver3/MockResponse$Builder;
public final fun body (Ljava/lang/String;)Lmockwebserver3/MockResponse$Builder;
public final fun body (Lmockwebserver3/internal/duplex/DuplexResponseBody;)Lmockwebserver3/MockResponse$Builder;
public final fun body (Lmockwebserver3/MockResponseBody;)Lmockwebserver3/MockResponse$Builder;
public final fun body (Lokio/Buffer;)Lmockwebserver3/MockResponse$Builder;
public final fun bodyDelay (JLjava/util/concurrent/TimeUnit;)Lmockwebserver3/MockResponse$Builder;
public final fun build ()Lmockwebserver3/MockResponse;
Expand All @@ -60,14 +59,14 @@ public final class mockwebserver3/MockResponse$Builder : java/lang/Cloneable {
public final fun code (I)Lmockwebserver3/MockResponse$Builder;
public final fun getBody ()Lmockwebserver3/MockResponseBody;
public final fun getCode ()I
public final fun getDuplexResponseBody ()Lmockwebserver3/internal/duplex/DuplexResponseBody;
public final fun getHttp2ErrorCode ()I
public final fun getInTunnel ()Z
public final fun getInformationalResponses ()Ljava/util/List;
public final fun getPushPromises ()Ljava/util/List;
public final fun getSettings ()Lokhttp3/internal/http2/Settings;
public final fun getSocketPolicy ()Lmockwebserver3/SocketPolicy;
public final fun getStatus ()Ljava/lang/String;
public final fun getStreamHandler ()Lmockwebserver3/StreamHandler;
public final fun getThrottleBytesPerPeriod ()J
public final fun getWebSocketListener ()Lokhttp3/WebSocketListener;
public final fun headers (Lokhttp3/Headers;)Lmockwebserver3/MockResponse$Builder;
Expand All @@ -81,9 +80,12 @@ public final class mockwebserver3/MockResponse$Builder : java/lang/Cloneable {
public final fun setHttp2ErrorCode (I)V
public final fun setSocketPolicy (Lmockwebserver3/SocketPolicy;)V
public final fun setStatus (Ljava/lang/String;)V
public final fun setStreamHandler (Lmockwebserver3/StreamHandler;)V
public final fun setWebSocketListener (Lokhttp3/WebSocketListener;)V
public final fun settings (Lokhttp3/internal/http2/Settings;)Lmockwebserver3/MockResponse$Builder;
public final fun socketPolicy (Lmockwebserver3/SocketPolicy;)Lmockwebserver3/MockResponse$Builder;
public final fun status (Ljava/lang/String;)Lmockwebserver3/MockResponse$Builder;
public final fun streamHandler (Lmockwebserver3/StreamHandler;)Lmockwebserver3/MockResponse$Builder;
public final fun throttleBody (JJLjava/util/concurrent/TimeUnit;)Lmockwebserver3/MockResponse$Builder;
public final fun trailers (Lokhttp3/Headers;)Lmockwebserver3/MockResponse$Builder;
public final fun webSocketUpgrade (Lokhttp3/WebSocketListener;)Lmockwebserver3/MockResponse$Builder;
Expand Down Expand Up @@ -198,3 +200,13 @@ public final class mockwebserver3/SocketPolicy : java/lang/Enum {
public static fun values ()[Lmockwebserver3/SocketPolicy;
}

public abstract interface class mockwebserver3/Stream {
public abstract fun cancel ()V
public abstract fun getRequestBody ()Lokio/BufferedSource;
public abstract fun getResponseBody ()Lokio/BufferedSink;
}

public abstract interface class mockwebserver3/StreamHandler {
public abstract fun handle (Lmockwebserver3/Stream;)V
}

69 changes: 43 additions & 26 deletions mockwebserver/src/main/kotlin/mockwebserver3/MockResponse.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package mockwebserver3

import java.util.concurrent.TimeUnit
import mockwebserver3.internal.duplex.DuplexResponseBody
import mockwebserver3.internal.toMockResponseBody
import okhttp3.Headers
import okhttp3.Headers.Companion.headersOf
Expand Down Expand Up @@ -48,7 +47,10 @@ class MockResponse {
val headers: Headers
val trailers: Headers

// At most one of (body,webSocketListener,streamHandler) is non-null.
val body: MockResponseBody?
val webSocketListener: WebSocketListener?
val streamHandler: StreamHandler?

val inTunnel: Boolean
val informationalResponses: List<MockResponse>
Expand All @@ -72,12 +74,6 @@ class MockResponse {

val settings: Settings

val webSocketListener: WebSocketListener?
val duplexResponseBody: DuplexResponseBody?

val isDuplex: Boolean
get() = duplexResponseBody != null

@JvmOverloads
constructor(
code: Int = 200,
Expand All @@ -102,6 +98,8 @@ class MockResponse {
this.headers = builder.headers.build()
this.trailers = builder.trailers.build()
this.body = builder.body
this.streamHandler = builder.streamHandler
this.webSocketListener = builder.webSocketListener
this.inTunnel = builder.inTunnel
this.informationalResponses = builder.informationalResponses.toList()
this.throttleBytesPerPeriod = builder.throttleBytesPerPeriod
Expand All @@ -114,8 +112,6 @@ class MockResponse {
this.settings = Settings().apply {
merge(builder.settings)
}
this.webSocketListener = builder.webSocketListener
this.duplexResponseBody = builder.duplexResponseBody
}

fun newBuilder(): Builder = Builder(this)
Expand Down Expand Up @@ -152,7 +148,31 @@ class MockResponse {

internal var trailers: Headers.Builder

// At most one of (body,webSocketListener,streamHandler) is non-null.
private var body_: MockResponseBody? = null
private var streamHandler_: StreamHandler? = null
private var webSocketListener_: WebSocketListener? = null
var body: MockResponseBody?
get() = body_
set(value) {
body_ = value
streamHandler_ = null
webSocketListener_ = null
}
var streamHandler: StreamHandler?
get() = streamHandler_
set(value) {
streamHandler_ = value
body_ = null
webSocketListener_ = null
}
var webSocketListener: WebSocketListener?
get() = webSocketListener_
set(value) {
webSocketListener_ = value
body_ = null
streamHandler_ = null
}

var throttleBytesPerPeriod: Long
private set
Expand All @@ -170,16 +190,14 @@ class MockResponse {
val pushPromises: MutableList<PushPromise>

val settings: Settings
var webSocketListener: WebSocketListener?
private set
var duplexResponseBody: DuplexResponseBody?
private set

constructor() {
this.inTunnel = false
this.informationalResponses = mutableListOf()
this.status = "HTTP/1.1 200 OK"
this.body = null
this.body_ = null
this.streamHandler_ = null
this.webSocketListener_ = null
this.headers = Headers.Builder()
.add("Content-Length", "0")
this.trailers = Headers.Builder()
Expand All @@ -191,8 +209,6 @@ class MockResponse {
this.headersDelayNanos = 0L
this.pushPromises = mutableListOf()
this.settings = Settings()
this.webSocketListener = null
this.duplexResponseBody = null
}

internal constructor(mockResponse: MockResponse) {
Expand All @@ -201,7 +217,9 @@ class MockResponse {
this.status = mockResponse.status
this.headers = mockResponse.headers.newBuilder()
this.trailers = mockResponse.trailers.newBuilder()
this.body = mockResponse.body
this.body_ = mockResponse.body
this.streamHandler_ = mockResponse.streamHandler
this.webSocketListener_ = mockResponse.webSocketListener
this.throttleBytesPerPeriod = mockResponse.throttleBytesPerPeriod
this.throttlePeriodNanos = mockResponse.throttlePeriodNanos
this.socketPolicy = mockResponse.socketPolicy
Expand All @@ -212,8 +230,6 @@ class MockResponse {
this.settings = Settings().apply {
merge(mockResponse.settings)
}
this.webSocketListener = mockResponse.webSocketListener
this.duplexResponseBody = mockResponse.duplexResponseBody
}

fun code(code: Int) = apply {
Expand Down Expand Up @@ -269,16 +285,18 @@ class MockResponse {
headers.removeAll(name)
}

fun body(body: Buffer) = apply {
setHeader("Content-Length", body.size)
this.body = body.toMockResponseBody()
fun body(body: Buffer) = body(body.toMockResponseBody())

fun body(body: MockResponseBody) = apply {
setHeader("Content-Length", body.contentLength)
this.body = body
}

/** Sets the response body to the UTF-8 encoded bytes of [body]. */
fun body(body: String): Builder = body(Buffer().writeUtf8(body))

fun body(duplexResponseBody: DuplexResponseBody) = apply {
this.duplexResponseBody = duplexResponseBody
fun streamHandler(streamHandler: StreamHandler) = apply {
this.streamHandler = streamHandler
}

/**
Expand Down Expand Up @@ -367,13 +385,12 @@ class MockResponse {

/**
* Attempts to perform a web socket upgrade on the connection.
* This will overwrite any previously set status or body.
* This will overwrite any previously set status, body, or streamHandler.
*/
fun webSocketUpgrade(listener: WebSocketListener) = apply {
status = "HTTP/1.1 101 Switching Protocols"
setHeader("Connection", "Upgrade")
setHeader("Upgrade", "websocket")
body = null
webSocketListener = listener
}

Expand Down
16 changes: 8 additions & 8 deletions mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ import mockwebserver3.SocketPolicy.SHUTDOWN_SERVER_AFTER_RESPONSE
import mockwebserver3.SocketPolicy.STALL_SOCKET_AT_START
import mockwebserver3.internal.ThrottledSink
import mockwebserver3.internal.TriggerSink
import mockwebserver3.internal.duplex.DuplexResponseBody
import mockwebserver3.internal.duplex.RealStream
import mockwebserver3.internal.sleepNanos
import okhttp3.Headers
import okhttp3.Headers.Companion.headersOf
Expand Down Expand Up @@ -986,7 +986,7 @@ class MockWebServer : Closeable {
val body = Buffer()
val requestLine = "$method $path HTTP/1.1"
var exception: IOException? = null
if (readBody && !peek.isDuplex && peek.socketPolicy != DO_NOT_READ_REQUEST_BODY) {
if (readBody && peek.streamHandler == null && peek.socketPolicy != DO_NOT_READ_REQUEST_BODY) {
try {
val contentLengthString = headers["content-length"]
val requestBodySink = body.withThrottlingAndSocketPolicy(
Expand Down Expand Up @@ -1040,9 +1040,10 @@ class MockWebServer : Closeable {
val bodyDelayNanos = response.bodyDelayNanos
val trailers = response.trailers
val body = response.body
val streamHandler = response.streamHandler
val outFinished = (body == null &&
response.pushPromises.isEmpty() &&
!response.isDuplex)
streamHandler == null)
val flushHeaders = body == null || bodyDelayNanos != 0L
require(!outFinished || trailers.size == 0) {
"unsupported: no body and non-empty trailers $trailers"
Expand All @@ -1066,9 +1067,8 @@ class MockWebServer : Closeable {
responseBodySink.use {
body.writeTo(responseBodySink)
}
} else if (response.isDuplex) {
val duplexResponseBody = response.duplexResponseBody!!
duplexResponseBody.onRequest(request, stream)
} else if (streamHandler != null) {
streamHandler.handle(RealStream(stream))
} else if (!outFinished) {
stream.close(ErrorCode.NO_ERROR, null)
}
Expand Down Expand Up @@ -1114,9 +1114,9 @@ class MockWebServer : Closeable {
MwsDuplexAccess.instance = object : MwsDuplexAccess() {
override fun setBody(
mockResponseBuilder: MockResponse.Builder,
duplexResponseBody: DuplexResponseBody,
duplexResponseBody: StreamHandler,
) {
mockResponseBuilder.body(duplexResponseBody)
mockResponseBuilder.streamHandler(duplexResponseBody)
}
}
}
Expand Down
39 changes: 39 additions & 0 deletions mockwebserver/src/main/kotlin/mockwebserver3/Stream.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (C) 2022 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mockwebserver3

import okio.BufferedSink
import okio.BufferedSource

/**
* A bidirectional sequence of data frames exchanged between client and server.
*/
interface Stream {
val requestBody: BufferedSource
val responseBody: BufferedSink

/**
* Terminate the stream so that no further data is transmitted or received. Note that
* [requestBody] may return data after this call; that is the buffered data received before this
* stream was canceled.
*
* This does nothing if [requestBody] and [responseBody] are already closed.
*
* For HTTP/2 this sends the [CANCEL](https://datatracker.ietf.org/doc/html/rfc7540#section-7)
* error code.
*/
fun cancel()
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (C) 2018 Square, Inc.
* Copyright (C) 2022 Block, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,13 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mockwebserver3.internal.duplex
package mockwebserver3

import java.io.IOException
import okhttp3.internal.http2.Http2Stream
import mockwebserver3.RecordedRequest

fun interface DuplexResponseBody {
@Throws(IOException::class)
fun onRequest(request: RecordedRequest, http2Stream: Http2Stream)
/**
* Handles a call's stream directly. Use this instead of [MockResponseBody] to begin sending
* response data before all request data has been received.
*
* See [okhttp3.RequestBody.isDuplex].
*/
interface StreamHandler {
fun handle(stream: Stream)
}
Loading

0 comments on commit 34bb125

Please sign in to comment.