Skip to content

Instantly share code, notes, and snippets.

@chachako
Created January 14, 2024 08:08
Show Gist options
  • Save chachako/58598a31f8f33702d9926e474ba291f1 to your computer and use it in GitHub Desktop.
Save chachako/58598a31f8f33702d9926e474ba291f1 to your computer and use it in GitHub Desktop.
Ktor SSE Post
@file:OptIn(InternalAPI::class)
package chachako.network.sse
import io.ktor.client.plugins.sse.ClientSSESession
import io.ktor.client.request.HttpResponseData
import io.ktor.sse.COLON
import io.ktor.sse.END_OF_LINE
import io.ktor.sse.SPACE
import io.ktor.sse.ServerSentEvent
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.InternalAPI
import io.ktor.utils.io.charsets.MalformedInputException
import io.ktor.utils.io.readUTF8Line
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import chachako.common.debug
import kotlin.coroutines.CoroutineContext
/**
 * [ClientSSESession] that parses a `text/event-stream` body read from [input] into [ServerSentEvent]s.
 *
 * [incoming] is a cold flow built with `channelFlow`: every collection runs the parser against the
 * same [input] channel, so collecting it a second time continues from wherever the first stopped.
 *
 * @param errorResponse when non-null, collecting [incoming] reads the remaining body as text and
 *   throws [SSEResponseException] with this response's status code instead of emitting events.
 * @param reconnectionTimeMillis initial reconnection delay; overwritten whenever a valid `retry`
 *   field is parsed (this class does not otherwise read it).
 * @param showCommentEvents whether events that consist only of comments are emitted.
 * @param showRetryEvents whether events that consist only of a `retry` field are emitted.
 * @param input raw response body to parse.
 * @param coroutineContext context this session exposes as its [CoroutineContext].
 */
internal class DefaultClientSSESession(
    private val errorResponse: HttpResponseData?,
    private var reconnectionTimeMillis: Long,
    private val showCommentEvents: Boolean,
    private val showRetryEvents: Boolean,
    private val input: ByteReadChannel,
    override val coroutineContext: CoroutineContext,
) : ClientSSESession {

    /** `id` of the most recently dispatched event; it is carried into following events that omit `id`. */
    private var lastEventId: String? = null

    private val _incoming = channelFlow {
        if (errorResponse != null) {
            val responseText = try {
                input.readRemaining().readText()
            } catch (_: MalformedInputException) {
                "<body failed decoding>"
            }
            debug { "Received SSE response error: $responseText" }
            throw SSEResponseException(errorResponse.statusCode, responseText)
        }
        while (true) {
            val event = input.parseEvent() ?: break
            if (event.isCommentsEvent() && !showCommentEvents) continue
            if (event.isRetryEvent() && !showRetryEvents) continue
            debug { "Received SSE: $event" }
            // Ignore the end mark used by APIs such as OpenAI's. It is meaningless to the client,
            // so the stream is terminated here instead of forwarding it.
            if (event.data == "[DONE]") break
            send(event)
        }
    }

    override val incoming: Flow<ServerSentEvent>
        get() = _incoming

    /**
     * Reads lines until a blank line terminates a non-empty event, and returns that event.
     *
     * Leading blank lines are skipped. Returns `null` when the channel is exhausted; any partially
     * read event is then discarded. Updates [lastEventId] on every dispatch and
     * [reconnectionTimeMillis] on every valid `retry` field.
     */
    private suspend fun ByteReadChannel.parseEvent(): ServerSentEvent? {
        val data = StringBuilder()
        val comments = StringBuilder()
        var eventType: String? = null
        var curRetry: Long? = null
        var lastEventId: String? = this@DefaultClientSSESession.lastEventId
        var wasData = false
        var wasComments = false

        var line: String = readUTF8Line() ?: return null
        while (line.isBlank()) {
            line = readUTF8Line() ?: return null
        }
        while (true) {
            when {
                // A blank line dispatches the event accumulated so far.
                line.isBlank() -> {
                    this@DefaultClientSSESession.lastEventId = lastEventId
                    val event = ServerSentEvent(
                        data = if (wasData) data.toText() else null,
                        event = eventType,
                        id = lastEventId,
                        retry = curRetry,
                        comments = if (wasComments) comments.toText() else null,
                    )
                    if (!event.isEmpty()) {
                        return event
                    }
                }

                line.startsWith(COLON) -> {
                    wasComments = true
                    comments.appendComment(line)
                }

                else -> {
                    val field = line.substringBefore(COLON)
                    val value = line.substringAfter(COLON, missingDelimiterValue = "").removePrefix(SPACE)
                    when (field) {
                        "event" -> eventType = value
                        "data" -> {
                            wasData = true
                            data.append(value).append(END_OF_LINE)
                        }
                        // The SSE spec only honours `retry` values made of ASCII digits. Signs or other
                        // characters must be ignored rather than parsed, so a negative delay cannot get in.
                        "retry" -> value
                            .takeIf { it.isNotEmpty() && it.all { ch -> ch in '0'..'9' } }
                            ?.toLongOrNull()
                            ?.let {
                                reconnectionTimeMillis = it
                                curRetry = it
                            }
                        // Per the SSE spec, an `id` containing U+0000 is ignored.
                        "id" -> if (!value.contains(NULL)) {
                            lastEventId = value
                        }
                    }
                }
            }
            line = readUTF8Line() ?: return null
        }
    }

    /** Appends a comment line without its leading colon and optional single space. */
    private fun StringBuilder.appendComment(comment: String) {
        append(comment.removePrefix(COLON).removePrefix(SPACE)).append(END_OF_LINE)
    }

    /** Accumulated text without the trailing line terminator added by the last append. */
    private fun StringBuilder.toText() = toString().removeSuffix(END_OF_LINE)

    private fun ServerSentEvent.isEmpty() =
        data == null && id == null && event == null && retry == null && comments == null

    private fun ServerSentEvent.isCommentsEvent() =
        data == null && event == null && id == null && retry == null && comments != null

    private fun ServerSentEvent.isRetryEvent() =
        data == null && event == null && id == null && comments == null && retry != null
}
/** The U+0000 character; an SSE `id` field whose value contains it is ignored by the parser. */
private const val NULL: String = "\u0000"
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
@file:OptIn(InternalAPI::class)
package chachako.network.engine
import io.ktor.client.engine.HttpClientEngineBase
import io.ktor.client.engine.HttpClientEngineCapability
import io.ktor.client.engine.okhttp.OkHttpConfig
import io.ktor.client.engine.okhttp.OkHttpEngine
import io.ktor.client.plugins.sse.reconnectionTimeAttr
import io.ktor.client.plugins.sse.showCommentEventsAttr
import io.ktor.client.plugins.sse.showRetryEventsAttr
import io.ktor.client.request.HttpRequestData
import io.ktor.client.request.HttpResponseData
import io.ktor.http.HttpStatusCode
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.InternalAPI
import kotlinx.coroutines.CoroutineDispatcher
import kotlin.coroutines.CoroutineContext
/**
 * Engine that delegates all work to an [OkHttpEngine] built from the same [config].
 *
 * For requests marked by the SSE plugin ([isSseRequest]), the raw response body is wrapped in a
 * [DefaultClientSSESession]. Any status other than 200 OK is passed to that session as its error
 * response.
 *
 * @param config OkHttp configuration, shared with the wrapped engine.
 */
class OkHttpClientEngineWrapper(override val config: OkHttpConfig) : HttpClientEngineBase("okhttp-wrapper") {
    private val okhttp = OkHttpEngine(config)

    override val supportedCapabilities: Set<HttpClientEngineCapability<*>> = okhttp.supportedCapabilities
    override val coroutineContext: CoroutineContext get() = okhttp.coroutineContext
    override val dispatcher: CoroutineDispatcher get() = okhttp.dispatcher

    override fun close() {
        super.close()
        okhttp.close()
    }

    override suspend fun execute(data: HttpRequestData): HttpResponseData {
        val responseData = okhttp.execute(data)
        if (!data.isSseRequest()) return responseData

        // Replace the body with an SSE session that parses the raw stream.
        val isError = responseData.statusCode != HttpStatusCode.OK
        return HttpResponseData(
            statusCode = responseData.statusCode,
            // Keep the original request timestamp. Passing `responseTime` here would overwrite it.
            requestTime = responseData.requestTime,
            headers = responseData.headers,
            version = responseData.version,
            body = DefaultClientSSESession(
                errorResponse = if (isError) responseData else null,
                reconnectionTimeMillis = data.attributes[reconnectionTimeAttr].inWholeMilliseconds,
                showCommentEvents = data.attributes[showCommentEventsAttr],
                showRetryEvents = data.attributes[showRetryEventsAttr],
                input = responseData.body as ByteReadChannel,
                coroutineContext = responseData.callContext,
            ),
            callContext = responseData.callContext,
        )
    }
}
package chachako.network.engine
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.HttpClientEngineFactory
import io.ktor.client.engine.okhttp.OkHttpConfig
/**
 * [HttpClientEngineFactory] producing [OkHttpClientEngineWrapper] engines.
 *
 * The configuration block runs against a fresh [OkHttpConfig] before the engine is constructed.
 */
object OkHttpClientWrapper : HttpClientEngineFactory<OkHttpConfig> {
    override fun create(block: OkHttpConfig.() -> Unit): HttpClientEngine {
        val engineConfig = OkHttpConfig()
        engineConfig.block()
        return OkHttpClientEngineWrapper(engineConfig)
    }
}
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package chachako.network.sse
import io.ktor.client.plugins.HttpTimeoutConfig.Companion.INFINITE_TIMEOUT_MS
import io.ktor.client.plugins.api.ClientPlugin
import io.ktor.client.plugins.api.createClientPlugin
import io.ktor.client.plugins.sse.ClientSSESession
import io.ktor.client.plugins.sse.SSECapability
import io.ktor.client.plugins.sse.SSEConfig
import io.ktor.client.plugins.sse.reconnectionTimeAttr
import io.ktor.client.plugins.sse.showCommentEventsAttr
import io.ktor.client.plugins.sse.showRetryEventsAttr
import io.ktor.client.plugins.timeout
import io.ktor.client.request.HttpRequest
import io.ktor.client.request.HttpRequestData
import io.ktor.client.request.headers
import io.ktor.client.statement.HttpResponseContainer
import io.ktor.client.statement.HttpResponsePipeline
import io.ktor.client.statement.request
import io.ktor.client.utils.EmptyContent
import io.ktor.http.ContentType
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpMethod
import io.ktor.http.HttpStatusCode
import io.ktor.http.append
import io.ktor.sse.SSEException
import io.ktor.util.AttributeKey
import chachako.common.debug
/** Marker attribute that the [SSE] plugin puts on every request; read back by `isSseRequest()`. */
internal val sseRequestAttr: AttributeKey<Boolean> = AttributeKey<Boolean>("SSERequest")
/**
 * Client Server-sent events plugin that allows you to establish an SSE connection to a server
 * and receive Server-sent events from it.
 *
 * Every request made by a client with this plugin installed is prepared as an SSE request:
 * - `Accept: text/event-stream` and `Cache-Control: no-store` headers are appended;
 * - request, connect and socket timeouts are set to infinite;
 * - a `GET` request that carries a body is turned into a `POST`;
 * - the request is marked with [sseRequestAttr]. The reconnection-time, show-comment-events and
 *   show-retry-events attributes are filled from [SSEConfig] unless the request already set them.
 *
 * On the response side, the body of a marked request must already be a [ClientSSESession]
 * (for example one produced by `OkHttpClientEngineWrapper`); otherwise an [SSEException] is thrown.
 *
 * ```kotlin
 * val client = HttpClient {
 *     install(SSE)
 * }
 * client.sse {
 *     // `incoming` is a Flow<ServerSentEvent>, so events are collected rather than received.
 *     incoming.collect { event -> println(event) }
 * }
 * ```
 */
val SSE: ClientPlugin<SSEConfig> = createClientPlugin(
    name = "SSE",
    createConfiguration = ::SSEConfig
) {
    val reconnectionTime = pluginConfig.reconnectionTime
    val showCommentEvents = pluginConfig.showCommentEvents
    val showRetryEvents = pluginConfig.showRetryEvents

    onRequest { request, _ ->
        request.headers {
            append(HttpHeaders.Accept, ContentType.Text.EventStream)
            append(HttpHeaders.CacheControl, "no-store")
        }
        // An event stream stays open indefinitely, so no timeout may cut it off.
        request.timeout {
            requestTimeoutMillis = INFINITE_TIMEOUT_MS
            connectTimeoutMillis = INFINITE_TIMEOUT_MS
            socketTimeoutMillis = INFINITE_TIMEOUT_MS
        }
        request.setCapability(SSECapability, Unit)

        // Capture per-request overrides before the defaults below are written.
        val hasReconnectionTime = request.attributes.contains(reconnectionTimeAttr)
        val hasShowCommentEvents = request.attributes.contains(showCommentEventsAttr)
        val hasShowRetryEvents = request.attributes.contains(showRetryEventsAttr)

        // Allows "SSE over POST": a GET that carries a body is sent as a POST.
        if (request.body !is EmptyContent && request.method == HttpMethod.Get) {
            request.method = HttpMethod.Post
        }

        request.attributes.apply {
            // NOTE(review): presumably mirrors the Ktor attribute used to skip response-body
            // saving for streaming bodies; confirm the key matches the Ktor version in use.
            put(AttributeKey<Unit>("ResponseBodySaved"), Unit)
            put(sseRequestAttr, true)
            if (!hasReconnectionTime) put(reconnectionTimeAttr, reconnectionTime)
            if (!hasShowCommentEvents) put(showCommentEventsAttr, showCommentEvents)
            if (!hasShowRetryEvents) put(showRetryEventsAttr, showRetryEvents)
        }
    }

    client.responsePipeline.intercept(HttpResponsePipeline.Transform) { (info, session) ->
        val response = context.response
        if (!response.request.isSseRequest()) {
            debug { "Skipping non SSE response from ${response.request.url}" }
            return@intercept
        }
        if (session !is ClientSSESession) {
            throw SSEException("Expected `ClientSSESession` content but was: $session")
        }
        debug { "Received SSE response from ${response.request.url}: $session" }
        proceedWith(HttpResponseContainer(info, session))
    }
}
/**
 * Error surfaced when collecting an SSE session whose HTTP response did not have status 200 OK.
 */
class SSEResponseException(
    /** Status code of the failed response. */
    val statusCode: HttpStatusCode,
    /** Response body decoded as text, or a placeholder when decoding failed. */
    val responseText: String,
) : IllegalStateException("Bad response: statusCode=${statusCode.value}. Text: \"$responseText\"")
/** `true` only when [sseRequestAttr] is present on this request data and set to `true`. */
internal fun HttpRequestData.isSseRequest(): Boolean {
    val marker: Boolean? = attributes.getOrNull(sseRequestAttr)
    return marker ?: false
}
/** `true` only when [sseRequestAttr] is present on this request and set to `true`. */
internal fun HttpRequest.isSseRequest(): Boolean {
    val marker: Boolean? = attributes.getOrNull(sseRequestAttr)
    return marker ?: false
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment