Last active
April 23, 2024 06:30
-
-
Save n0m0r3pa1n/a7e5283e66a0f446ded47f5cbe7cf60a to your computer and use it in GitHub Desktop.
OkHttpWebSocketClient
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io.reactivex.subjects.PublishSubject | |
import kotlinx.coroutines.CoroutineDispatcher | |
import kotlinx.coroutines.CoroutineScope | |
import kotlinx.coroutines.Dispatchers | |
import kotlinx.coroutines.delay | |
import kotlinx.coroutines.flow.Flow | |
import kotlinx.coroutines.flow.flowOn | |
import kotlinx.coroutines.launch | |
import kotlinx.coroutines.newSingleThreadContext | |
import kotlinx.coroutines.suspendCancellableCoroutine | |
import kotlinx.coroutines.sync.Mutex | |
import kotlinx.coroutines.sync.withLock | |
import okhttp3.Headers.Companion.toHeaders | |
import okhttp3.OkHttpClient | |
import okhttp3.Request | |
import okhttp3.Response | |
import okhttp3.WebSocket | |
import okhttp3.WebSocketListener | |
import okio.ByteString | |
import timber.log.Timber | |
import kotlin.coroutines.CoroutineContext | |
import kotlin.coroutines.resume | |
/**
 * [WebSocketClient] implementation backed by an OkHttp [WebSocket].
 *
 * Incoming frames are published through [readFlow]; outgoing messages are funnelled through a
 * serialized subject and written to the currently open socket. After a connection failure a
 * reconnection is scheduled on [processScope] with an exponential backoff capped at
 * [maximumBackoffMillis].
 *
 * Connection state is shared between callers of [ensureConnected], the reconnection coroutine and
 * the socket listener callbacks, so every state transition goes through [stateLock].
 */
class OkHttpWebSocketClient(
    private val analytics: Analytics,
    private val eventFactory: WebSocketsAnalyticsEventFactory,
    private val webSocketUrlCreator: () -> String,
    private val webSocketHeadersProvider: () -> Map<String, String>,
    private val okHttpClient: OkHttpClient,
    private val flowDispatcher: CoroutineDispatcher = Dispatchers.IO,
    @ProcessCoroutineScope private val processScope: CoroutineScope,
    @Suppress("EXPERIMENTAL_API_USAGE")
    private val connectingStatusContext: CoroutineContext = newSingleThreadContext("connectingStatusContext"),
    private val maximumBackoffMillis: Long = 30000,
) : WebSocketClient {

    // Incremented by the reconnection coroutine and reset by the listener's onOpen callback.
    @Volatile
    private var reconnectionAttempts = 0

    private val _incomingSubject = PublishSubject.create<String>()
    private val _outgoingSubject = PublishSubject.create<String>().toSerialized()

    // Null until the first connection attempt, so send()/close() before connecting cannot crash.
    @Volatile
    private var socket: WebSocket? = null

    // Guards every write to [status] and every access to [connectionWaiters].
    private val stateLock = Any()

    @Volatile
    private var status: ConnectionStatus = ConnectionStatus.Disconnected

    // Callbacks of ensureConnected() callers that are waiting for the in-flight attempt to finish.
    private val connectionWaiters = mutableListOf<(Boolean) -> Unit>()

    private val reconnectionMutex = Mutex()
    private val toSendOnReconnection = mutableListOf<String>()

    init {
        _outgoingSubject
            .doOnNext { message ->
                // An exception escaping this lambda would terminate the subscription and silently
                // drop every later message, so failures are logged and contained per message.
                try {
                    trackEvent(eventFactory.onSendingMessage())
                    Timber.v("Sending $message")
                    val activeSocket = socket
                    if (activeSocket == null) {
                        Timber.w("Message dropped: the WebSocket has not been opened yet")
                    } else {
                        activeSocket.send(message)
                    }
                } catch (failure: Exception) {
                    Timber.e(failure, "Failed to send message over the WebSocket")
                }
            }
            .ignoreElements()
            .onErrorComplete()
            .subscribe()
    }

    /**
     * Queues [message] to be sent automatically after the next successful reconnection.
     *
     * Queued messages are flushed only by the internal reconnection path (not by a direct
     * [ensureConnected] call) and are discarded once they have been sent.
     */
    override suspend fun setConnectionAutoMessage(message: String) {
        reconnectionMutex.withLock {
            toSendOnReconnection.add(message)
        }
    }

    /**
     * Schedules one reconnection attempt on [processScope] after an exponential backoff delay.
     *
     * On success the queued auto messages are sent and the queue is cleared. A failed attempt is
     * only reported here; the listener's `onFailure` callback schedules the next one.
     */
    @Suppress("MagicNumber")
    private fun reconnect() {
        processScope.launch(connectingStatusContext) {
            val delayMillis = exponentialBackoff(
                currentAttempt = reconnectionAttempts,
                min = 1000,
                max = maximumBackoffMillis
            )
            delay(delayMillis)
            reconnectionAttempts++
            trackEvent(eventFactory.onReconnecting(reconnectionAttempts))
            val connected = ensureConnected()
            if (connected) {
                trackEvent(eventFactory.onReconnectionSuccess())
                reconnectionMutex.withLock {
                    toSendOnReconnection.forEach { send(it) }
                    toSendOnReconnection.clear()
                }
            } else {
                trackEvent(eventFactory.onReconnectionFailure())
            }
        }
    }

    /**
     * Connects to the WebSocket and opens the incoming and outgoing message channels for communication.
     *
     * * If the WebSocket is [Disconnected][ConnectionStatus.Disconnected], a connection is attempted
     * * If the WebSocket is already [Connected][ConnectionStatus.Connected], the coroutine ends immediately
     * * If the WebSocket is [Connecting][ConnectionStatus.Connecting], the coroutine waits
     * until the in-flight attempt either succeeds or fails and returns that attempt's result
     *
     * A connection failure schedules a reconnection with an exponential backoff via [reconnect].
     * A regular closure of the socket only marks the client as disconnected.
     *
     * **IMPORTANT**: This behavior may change in the future once we move the re-connection attempts outside this client.
     *
     * If building the request or creating the socket throws, the state is reset to disconnected,
     * other waiting callers receive `false`, and the exception is rethrown to this caller.
     *
     * @return `true` if the connection attempt was successful, `false` otherwise
     */
    override suspend fun ensureConnected(): Boolean = suspendCancellableCoroutine { continuation ->
        // Resuming a continuation that was cancelled in the meantime is ignored by kotlinx.coroutines.
        val waiter: (Boolean) -> Unit = { connected -> continuation.resume(connected) }

        // Read the status and claim the "connecting" role atomically so that two concurrent
        // callers can never both open a socket.
        val previousStatus = synchronized(stateLock) {
            val current = status
            if (current != ConnectionStatus.Connected) connectionWaiters.add(waiter)
            if (current == ConnectionStatus.Disconnected) status = ConnectionStatus.Connecting
            current
        }

        when (previousStatus) {
            ConnectionStatus.Connected -> continuation.resume(true)
            ConnectionStatus.Connecting -> continuation.invokeOnCancellation { removeWaiter(waiter) }
            ConnectionStatus.Disconnected -> {
                continuation.invokeOnCancellation { removeWaiter(waiter) }
                try {
                    openSocket()
                } catch (failure: Exception) {
                    // Do not leave the client stuck in Connecting. This caller gets the exception,
                    // so its own waiter is removed before the remaining ones are released.
                    val otherWaiters = synchronized(stateLock) {
                        status = ConnectionStatus.Disconnected
                        connectionWaiters.remove(waiter)
                        connectionWaiters.toList().also { connectionWaiters.clear() }
                    }
                    otherWaiters.forEach { it(false) }
                    throw failure
                }
            }
        }
    }

    /** Builds the handshake request and starts a new connection attempt. */
    private fun openSocket() {
        trackEvent(eventFactory.onConnecting())
        val request = Request.Builder()
            .url(webSocketUrlCreator.invoke())
            .headers(webSocketHeadersProvider.invoke().toHeaders())
            .build()
        socket = okHttpClient.newWebSocket(request, createListener())
        Timber.v("Connecting...")
    }

    /** Creates the listener that translates socket callbacks into state changes, analytics and messages. */
    private fun createListener(): WebSocketListener = object : WebSocketListener() {
        override fun onOpen(webSocket: WebSocket, response: Response) {
            Timber.v("Connected %s", response)
            val waiters = transitionTo(ConnectionStatus.Connected)
            reconnectionAttempts = 0
            trackEvent(eventFactory.onConnected())
            waiters.forEach { it(true) }
        }

        override fun onMessage(webSocket: WebSocket, text: String) {
            Timber.v("Received $text")
            trackEvent(eventFactory.onReceiveMessage())
            _incomingSubject.onNext(text)
        }

        override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
            Timber.v("Received $bytes")
            trackEvent(eventFactory.onReceiveMessage())
            // ByteString.toString() is a debug rendering ("[hex=…]" / "[text=…]"), not the payload.
            _incomingSubject.onNext(bytes.utf8())
        }

        override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
            Timber.w("Closed. Reason %s", reason)
            trackEvent(eventFactory.onSocketClosed(reason))
            markDisconnected()
        }

        override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
            trackEvent(
                eventFactory.onSocketConnectionFailure(
                    reason = t.toString(),
                    message = response?.message ?: ""
                )
            )
            // A failure while still connecting is a failed attempt; otherwise an open socket broke.
            if (status == ConnectionStatus.Connecting) {
                Timber.e(t, "WebSocket connection failed %s. %s", response, t.message)
            } else {
                Timber.w(t, "WebSocket closed")
            }
            markDisconnected()
            reconnect()
        }
    }

    /** Records the disconnection and releases every caller still waiting in [ensureConnected] with `false`. */
    private fun markDisconnected() {
        trackEvent(eventFactory.onSocketDisconnected())
        transitionTo(ConnectionStatus.Disconnected).forEach { it(false) }
    }

    /**
     * Atomically sets [status] to [newStatus] and hands back the callers that were waiting,
     * so they can be resumed outside the lock.
     */
    private fun transitionTo(newStatus: ConnectionStatus): List<(Boolean) -> Unit> =
        synchronized(stateLock) {
            status = newStatus
            connectionWaiters.toList().also { connectionWaiters.clear() }
        }

    /** Drops a waiter whose coroutine was cancelled so the list does not accumulate dead entries. */
    private fun removeWaiter(waiter: (Boolean) -> Unit) {
        synchronized(stateLock) {
            connectionWaiters.remove(waiter)
        }
    }

    private fun trackEvent(event: AnalyticsEvent.Event) = analytics.track(event)

    /** Stream of messages received from the socket, with upstream work moved to [flowDispatcher]. */
    override fun readFlow(): Flow<String> = _incomingSubject.asFlow().flowOn(flowDispatcher)

    /** `true` while the client is disconnected; `false` while connecting or connected. */
    override fun isClosed(): Boolean = status == ConnectionStatus.Disconnected

    /** Closes the current socket, if one was ever opened; otherwise does nothing. */
    override fun close(code: Int, reason: String?) {
        socket?.close(code, reason)
    }

    /** Hands [data] to the outgoing pipeline, which writes it to the current socket. */
    override suspend fun send(data: String) {
        _outgoingSubject.onNext(data)
    }

    private enum class ConnectionStatus {
        Disconnected,
        Connecting,
        Connected
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment