Skip to content

Instantly share code, notes, and snippets.

@belinwu
Forked from n0m0r3pa1n/OkHttpWebSocketClient.kt
Created April 23, 2024 06:30
Show Gist options
  • Save belinwu/da27d0b885cb6110d742b65ab884d12c to your computer and use it in GitHub Desktop.
Save belinwu/da27d0b885cb6110d742b65ab884d12c to your computer and use it in GitHub Desktop.
OkHttpWebSocketClient
import io.reactivex.subjects.PublishSubject
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import okhttp3.Headers.Companion.toHeaders
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
import okio.ByteString
import timber.log.Timber
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.resume
/**
 * [WebSocketClient] implementation backed by an OkHttp [WebSocket].
 *
 * Incoming text and binary frames are published through [readFlow]; outgoing messages passed to [send]
 * are serialized through an Rx subject and written to the current socket.
 * Connection failures schedule a reconnection with exponential backoff via [reconnect].
 *
 * Connection state is mutated from OkHttp callback threads and from coroutines, so every state
 * transition (and the registration of callers waiting for a connection attempt) happens under [stateLock].
 */
class OkHttpWebSocketClient(
    private val analytics: Analytics,
    private val eventFactory: WebSocketsAnalyticsEventFactory,
    private val webSocketUrlCreator: () -> String,
    private val webSocketHeadersProvider: () -> Map<String, String>,
    private val okHttpClient: OkHttpClient,
    private val flowDispatcher: CoroutineDispatcher = Dispatchers.IO,
    @ProcessCoroutineScope private val processScope: CoroutineScope,
    @Suppress("EXPERIMENTAL_API_USAGE")
    private val connectingStatusContext: CoroutineContext = newSingleThreadContext("connectingStatusContext"),
    private val maximumBackoffMillis: Long = 30000,
) : WebSocketClient {

    private var reconnectionAttempts = 0
    private val _incomingSubject = PublishSubject.create<String>()
    private val _outgoingSubject = PublishSubject.create<String>().toSerialized()

    // Null until the first connection attempt; nullable (instead of lateinit) so that send()/close()
    // before any connection cannot throw UninitializedPropertyAccessException.
    @Volatile
    private var socket: WebSocket? = null

    // Guards [status] transitions and [connectionWaiters].
    private val stateLock = Any()

    @Volatile
    private var status: ConnectionStatus = ConnectionStatus.Disconnected

    // Callbacks of ensureConnected() callers waiting for the in-flight connection attempt to finish.
    private val connectionWaiters = mutableListOf<(Boolean) -> Unit>()

    private val reconnectionMutex = Mutex()
    private val toSendOnReconnection = mutableListOf<String>()

    /**
     * Queues a message to be sent automatically over the WebSocket after a successful reconnection.
     *
     * The queue is flushed (sent, then cleared) by [reconnect] the next time a reconnection succeeds,
     * so each queued message is sent once.
     */
    override suspend fun setConnectionAutoMessage(message: String) {
        reconnectionMutex.withLock {
            toSendOnReconnection.add(message)
        }
    }

    /**
     * Schedules a reconnection attempt on [processScope] after an exponential backoff delay.
     *
     * On success the queued auto-messages are sent and the queue is cleared. On failure nothing else
     * happens here: the socket listener's `onFailure` already schedules the next attempt.
     */
    @Suppress("MagicNumber")
    private fun reconnect() {
        processScope.launch(connectingStatusContext) {
            val delayMillis = exponentialBackoff(
                currentAttempt = reconnectionAttempts,
                min = 1000,
                max = maximumBackoffMillis
            )
            delay(delayMillis)
            reconnectionAttempts++
            trackEvent(eventFactory.onReconnecting(reconnectionAttempts))
            val connected = ensureConnected()
            if (connected) {
                trackEvent(eventFactory.onReconnectionSuccess())
                reconnectionMutex.withLock {
                    toSendOnReconnection.forEach { send(it) }
                    toSendOnReconnection.clear()
                }
            } else {
                trackEvent(eventFactory.onReconnectionFailure())
            }
        }
    }

    /**
     * Connects to the WebSocket and opens the incoming and outgoing message channels for communication.
     *
     * * If the WebSocket is in its [Disconnected][ConnectionStatus.Disconnected] state, a connection will be attempted
     * * If the WebSocket is already [Connected][ConnectionStatus.Connected], the coroutine will end immediately
     * * If the WebSocket is [Connecting][ConnectionStatus.Connecting], the coroutine will wait
     * until the in-flight connection attempt either succeeds or fails
     *
     * Connection failures schedule a reconnection with an exponential backoff, via [reconnect].
     *
     * **IMPORTANT**: This behavior may change in the future once we move the re-connection attempts outside this client.
     *
     * @return `true` if the connection attempt was successful, `false` otherwise
     * @throws Exception whatever building the request or opening the socket throws (e.g. an invalid URL);
     * the client is reset to [Disconnected][ConnectionStatus.Disconnected] before rethrowing
     */
    override suspend fun ensureConnected(): Boolean = suspendCancellableCoroutine { continuation ->
        val waiter: (Boolean) -> Unit = { connected ->
            if (continuation.isActive) continuation.resume(connected)
        }
        // A cancelled caller must not stay registered in the waiter list.
        continuation.invokeOnCancellation {
            synchronized(stateLock) { connectionWaiters.remove(waiter) }
        }

        // Decide and register atomically, so a listener callback cannot drain the waiters
        // between the status check and the registration.
        val observedStatus = synchronized(stateLock) {
            val current = status
            if (current != ConnectionStatus.Connected) connectionWaiters.add(waiter)
            if (current == ConnectionStatus.Disconnected) status = ConnectionStatus.Connecting
            current
        }

        when (observedStatus) {
            ConnectionStatus.Connected -> continuation.resume(true)
            // Resumed by the listener of the in-flight attempt.
            ConnectionStatus.Connecting -> Unit
            ConnectionStatus.Disconnected -> try {
                openSocket()
            } catch (e: Exception) {
                // Do not leave the client stuck in Connecting. This caller gets the exception itself,
                // so its waiter is removed (resuming it as well would be a double completion).
                val otherWaiters = synchronized(stateLock) {
                    status = ConnectionStatus.Disconnected
                    connectionWaiters.remove(waiter)
                    drainWaitersLocked()
                }
                otherWaiters.forEach { notify -> notify(false) }
                throw e
            }
        }
    }

    /** Builds the handshake request and starts a new WebSocket connection attempt. */
    private fun openSocket() {
        trackEvent(eventFactory.onConnecting())
        val request = Request.Builder()
            .url(webSocketUrlCreator.invoke())
            .headers(webSocketHeadersProvider.invoke().toHeaders())
            .build()
        socket = okHttpClient.newWebSocket(request, ConnectionListener())
        Timber.v("Connecting...")
    }

    /** Removes and returns all registered waiters. Must be called while holding [stateLock]. */
    private fun drainWaitersLocked(): List<(Boolean) -> Unit> =
        connectionWaiters.toList().also { connectionWaiters.clear() }

    /** Translates OkHttp socket callbacks into state changes, analytics events and incoming messages. */
    private inner class ConnectionListener : WebSocketListener() {

        override fun onOpen(webSocket: WebSocket, response: Response) {
            Timber.v("Connected %s", response)
            val waiters = synchronized(stateLock) {
                status = ConnectionStatus.Connected
                drainWaitersLocked()
            }
            reconnectionAttempts = 0
            trackEvent(eventFactory.onConnected())
            waiters.forEach { notify -> notify(true) }
        }

        override fun onMessage(webSocket: WebSocket, text: String) {
            Timber.v("Received $text")
            trackEvent(eventFactory.onReceiveMessage())
            _incomingSubject.onNext(text)
        }

        override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
            Timber.v("Received $bytes")
            trackEvent(eventFactory.onReceiveMessage())
            // ByteString.toString() yields a debug representation ("[text=…]" / "[hex=…]"),
            // so decode the payload instead. NOTE(review): assumes binary frames carry UTF-8 text.
            _incomingSubject.onNext(bytes.utf8())
        }

        override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
            Timber.w("Closed. Reason %s", reason)
            trackEvent(eventFactory.onSocketClosed(reason))
            trackEvent(eventFactory.onSocketDisconnected())
            synchronized(stateLock) { status = ConnectionStatus.Disconnected }
        }

        override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
            trackEvent(
                eventFactory.onSocketConnectionFailure(
                    reason = t.toString(),
                    message = response?.message ?: ""
                )
            )
            val (failedWhileConnecting, waiters) = synchronized(stateLock) {
                val wasConnecting = status == ConnectionStatus.Connecting
                status = ConnectionStatus.Disconnected
                wasConnecting to drainWaitersLocked()
            }
            if (failedWhileConnecting) {
                Timber.e(t, "WebSocket connection failed %s. %s", response, t.message)
            } else {
                Timber.w(t, "WebSocket closed")
            }
            waiters.forEach { notify -> notify(false) }
            trackEvent(eventFactory.onSocketDisconnected())
            reconnect()
        }
    }

    private fun trackEvent(event: AnalyticsEvent.Event) = analytics.track(event)

    init {
        _outgoingSubject
            .doOnEach {
                trackEvent(eventFactory.onSendingMessage())
                Timber.v("Sending ${it.value}")
                it.value?.let { message ->
                    val currentSocket = socket
                    if (currentSocket == null) {
                        // Throwing here would terminate the subscription (see onErrorComplete below)
                        // and silently drop every later message, so drop just this one.
                        Timber.w("Dropping outgoing message: WebSocket has not been opened yet")
                    } else {
                        currentSocket.send(message)
                    }
                }
            }
            .ignoreElements()
            .onErrorComplete()
            .subscribe()
    }

    /** Stream of messages received from the socket, with upstream work moved to [flowDispatcher]. */
    override fun readFlow(): Flow<String> = _incomingSubject.asFlow().flowOn(flowDispatcher)

    /** `true` while the client is in the [Disconnected][ConnectionStatus.Disconnected] state. */
    override fun isClosed(): Boolean = status == ConnectionStatus.Disconnected

    /** Initiates a close of the current socket; a no-op if no socket has been opened yet. */
    override fun close(code: Int, reason: String?) {
        socket?.close(code, reason)
    }

    /** Enqueues [data] on the outgoing pipeline, which writes it to the current socket. */
    override suspend fun send(data: String) {
        _outgoingSubject.onNext(data)
    }

    private enum class ConnectionStatus {
        Disconnected,
        Connecting,
        Connected
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment