Skip to content

Commit

Permalink
Sanitize response handling and clear out queued requests on disconnect (
Browse files Browse the repository at this point in the history
livekit#387)

* Sanitize response handling and clear out queued requests upon disconnection

* cleanup test code

* spotless

* More sanitization

* Fix tests
  • Loading branch information
davidliu authored Mar 13, 2024
1 parent 730c8b8 commit d21a385
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,19 @@ constructor(
>? = null
private lateinit var coroutineScope: CloseableCoroutineScope

/**
* @see [startRequestQueue]
*/
private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE)
private val requestFlowJobLock = Object()
private var requestFlowJob: Job? = null
private val requestFlow = MutableSharedFlow<LivekitRtc.SignalRequest>(Int.MAX_VALUE)

/**
* @see [onReadyForResponses]
*/
private val responseFlow = MutableSharedFlow<Pair<WebSocket, LivekitRtc.SignalResponse>>(Int.MAX_VALUE)
private val responseFlowJobLock = Object()
private var responseFlowJob: Job? = null
private val responseFlow = MutableSharedFlow<LivekitRtc.SignalResponse>(Int.MAX_VALUE)

private var pingJob: Job? = null
private var pongJob: Job? = null
Expand Down Expand Up @@ -137,7 +143,7 @@ constructor(
roomOptions: RoomOptions,
): Either<JoinResponse, Either<ReconnectResponse, Unit>> {
// Clean up any pre-existing connection.
close(reason = "Starting new connection")
close(reason = "Starting new connection", shouldClearQueuedRequests = false)

val wsUrlString = "$url/rtc" + createConnectionParams(token, getClientInfo(), options, roomOptions)
isReconnecting = options.reconnect
Expand Down Expand Up @@ -210,9 +216,9 @@ constructor(
synchronized(responseFlowJobLock) {
if (responseFlowJob == null) {
responseFlowJob = coroutineScope.launch {
responseFlow.collect {
responseFlow.collect { (ws, response) ->
responseFlow.resetReplayCache()
handleSignalResponseImpl(it)
handleSignalResponseImpl(ws, response)
}
}
}
Expand Down Expand Up @@ -246,19 +252,31 @@ constructor(

// --------------------------------- WebSocket Listener --------------------------------------//
override fun onMessage(webSocket: WebSocket, text: String) {
if (webSocket != currentWs) {
// Possibly message from old websocket, discard.
return
}

LKLog.w { "received JSON message, unsupported in this version." }
}

override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
if (webSocket != currentWs) {
// Possibly message from old websocket, discard.
return
}
val byteArray = bytes.toByteArray()
val signalResponseBuilder = LivekitRtc.SignalResponse.newBuilder()
.mergeFrom(byteArray)
val response = signalResponseBuilder.build()

handleSignalResponse(response)
handleSignalResponse(webSocket, response)
}

override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
if (webSocket != currentWs) {
return
}
handleWebSocketClose(reason, code)
}

Expand All @@ -267,6 +285,9 @@ constructor(
}

override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
if (webSocket != currentWs) {
return
}
var reason: String? = null
try {
lastUrl?.let {
Expand Down Expand Up @@ -553,7 +574,11 @@ constructor(
}
}

private fun handleSignalResponse(response: LivekitRtc.SignalResponse) {
private fun handleSignalResponse(ws: WebSocket, response: LivekitRtc.SignalResponse) {
if (ws != currentWs) {
return
}

LKLog.v { "response: $response" }

if (!isConnected) {
Expand All @@ -574,7 +599,7 @@ constructor(
joinContinuation?.resumeWith(Result.success(Either.Left(response.join)))
} else if (response.hasLeave()) {
// Some reconnects may immediately send leave back without a join response first.
handleSignalResponseImpl(response)
handleSignalResponseImpl(ws, response)
} else if (isReconnecting) {
// When reconnecting, any message received means signal reconnected.
// Newer servers will send a reconnect response first
Expand All @@ -598,10 +623,15 @@ constructor(
return
}
}
responseFlow.tryEmit(response)
responseFlow.tryEmit(ws to response)
}

private fun handleSignalResponseImpl(response: LivekitRtc.SignalResponse) {
private fun handleSignalResponseImpl(ws: WebSocket, response: LivekitRtc.SignalResponse) {
if (ws != currentWs) {
LKLog.v { "received message from old websocket, discarding." }
return
}

when (response.messageCase) {
LivekitRtc.SignalResponse.MessageCase.ANSWER -> {
val sd = fromProtoSessionDescription(response.answer)
Expand Down Expand Up @@ -738,7 +768,7 @@ constructor(
*
* Can be reused afterwards.
*/
fun close(code: Int = CLOSE_REASON_NORMAL_CLOSURE, reason: String = "Normal Closure") {
fun close(code: Int = CLOSE_REASON_NORMAL_CLOSURE, reason: String = "Normal Closure", shouldClearQueuedRequests: Boolean = true) {
LKLog.v(Exception()) { "Closing SignalClient: code = $code, reason = $reason" }
isConnected = false
isReconnecting = false
Expand All @@ -757,8 +787,9 @@ constructor(
currentWs = null
joinContinuation?.cancel()
joinContinuation = null
// TODO: support calling this from connect without wiping any queued requests.
// requestFlow.resetReplayCache()
if (shouldClearQueuedRequests) {
requestFlow.resetReplayCache()
}
responseFlow.resetReplayCache()
lastUrl = null
lastOptions = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import okio.ByteString
import org.junit.After
import org.junit.Before
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
Expand All @@ -60,6 +61,11 @@ abstract class MockE2ETest : BaseTest() {
wsFactory = component.websocketFactory()
}

@After
fun tearDown() {
room.release()
}

suspend fun connect(joinResponse: LivekitRtc.SignalResponse = SignalClientTest.JOIN) {
connectSignal(joinResponse)
connectPeerConnection()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 LiveKit, Inc.
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -52,7 +52,6 @@ class MockWebSocketFactory : WebSocket.Factory {
this.listener = listener
this.request = request

onOpen?.invoke(this)
return ws
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 LiveKit, Inc.
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,14 +20,11 @@ import dagger.Module
import dagger.Provides
import io.livekit.android.dagger.InjectionNames
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestCoroutineDispatcher
import javax.inject.Named

@Module
class TestCoroutinesModule(
@OptIn(ExperimentalCoroutinesApi::class)
val coroutineDispatcher: CoroutineDispatcher = TestCoroutineDispatcher(),
private val coroutineDispatcher: CoroutineDispatcher,
) {

@Provides
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 LiveKit, Inc.
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,7 +47,7 @@ internal interface TestLiveKitComponent : LiveKitComponent {
interface Factory {
fun create(
@BindsInstance appContext: Context,
coroutinesModule: TestCoroutinesModule = TestCoroutinesModule(),
coroutinesModule: TestCoroutinesModule,
): TestLiveKitComponent
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,17 @@ import org.robolectric.RobolectricTestRunner
@RunWith(RobolectricTestRunner::class)
class RoomReconnectionMockE2ETest : MockE2ETest() {

private fun prepareForReconnect() {
wsFactory.onOpen = {
wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
val softReconnectParam = wsFactory.request.url
.queryParameter(SignalClient.CONNECT_QUERY_RECONNECT)
?.toIntOrNull()
?: 0

if (softReconnectParam == 0) {
simulateMessageFromServer(SignalClientTest.JOIN)
} else {
simulateMessageFromServer(SignalClientTest.RECONNECT)
}
private fun reconnectWebsocket() {
wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
val softReconnectParam = wsFactory.request.url
.queryParameter(SignalClient.CONNECT_QUERY_RECONNECT)
?.toIntOrNull()
?: 0

if (softReconnectParam == 0) {
simulateMessageFromServer(SignalClientTest.JOIN)
} else {
simulateMessageFromServer(SignalClientTest.RECONNECT)
}
}

Expand All @@ -59,10 +57,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT)

connect()
prepareForReconnect()
disconnectPeerConnection()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
reconnectWebsocket()
connectPeerConnection()

testScheduler.advanceUntilIdle()
Expand All @@ -82,10 +80,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
fun softReconnectConfiguration() = runTest {
room.setReconnectionType(ReconnectType.FORCE_SOFT_RECONNECT)
connect()
prepareForReconnect()
disconnectPeerConnection()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
reconnectWebsocket()
connectPeerConnection()

val rtcConfig = getSubscriberPeerConnection().rtcConfig
Expand All @@ -109,10 +107,10 @@ class RoomReconnectionMockE2ETest : MockE2ETest() {
),
)

prepareForReconnect()
disconnectPeerConnection()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
reconnectWebsocket()
connectPeerConnection()

testScheduler.advanceUntilIdle()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,17 @@ class RoomReconnectionTypesMockE2ETest(
)
}

private fun prepareForReconnect() {
wsFactory.onOpen = {
wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
val softReconnectParam = wsFactory.request.url
.queryParameter(SignalClient.CONNECT_QUERY_RECONNECT)
?.toIntOrNull()
?: 0

if (softReconnectParam == 0) {
simulateMessageFromServer(SignalClientTest.JOIN)
} else {
simulateMessageFromServer(SignalClientTest.RECONNECT)
}
private fun reconnectWebsocket() {
wsFactory.listener.onOpen(wsFactory.ws, createOpenResponse(wsFactory.request))
val softReconnectParam = wsFactory.request.url
.queryParameter(SignalClient.CONNECT_QUERY_RECONNECT)
?.toIntOrNull()
?: 0

if (softReconnectParam == 0) {
simulateMessageFromServer(SignalClientTest.JOIN)
} else {
simulateMessageFromServer(SignalClientTest.RECONNECT)
}
}

Expand Down Expand Up @@ -111,10 +109,10 @@ class RoomReconnectionTypesMockE2ETest(

val eventCollector = EventCollector(room.events, coroutineRule.scope)
val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope)
prepareForReconnect()
disconnectPeerConnection()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
reconnectWebsocket()
connectPeerConnection()

testScheduler.advanceUntilIdle()
Expand All @@ -138,10 +136,10 @@ class RoomReconnectionTypesMockE2ETest(

val eventCollector = EventCollector(room.events, coroutineRule.scope)
val stateCollector = FlowCollector(room::state.flow, coroutineRule.scope)
prepareForReconnect()
wsFactory.ws.cancel()
// Wait so that the reconnect job properly starts first.
testScheduler.advanceTimeBy(1000)
reconnectWebsocket()
connectPeerConnection()

testScheduler.advanceUntilIdle()
Expand Down

0 comments on commit d21a385

Please sign in to comment.