Created January 14, 2024 08:08
Save chachako/58598a31f8f33702d9926e474ba291f1 to your computer and use it in GitHub Desktop.
Ktor SSE Post
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@file:OptIn(InternalAPI::class) | |
package chachako.network.sse | |
import io.ktor.client.plugins.sse.ClientSSESession | |
import io.ktor.client.request.HttpResponseData | |
import io.ktor.sse.COLON | |
import io.ktor.sse.END_OF_LINE | |
import io.ktor.sse.SPACE | |
import io.ktor.sse.ServerSentEvent | |
import io.ktor.utils.io.ByteReadChannel | |
import io.ktor.utils.io.InternalAPI | |
import io.ktor.utils.io.charsets.MalformedInputException | |
import io.ktor.utils.io.readUTF8Line | |
import kotlinx.coroutines.flow.Flow | |
import kotlinx.coroutines.flow.channelFlow | |
import chachako.common.debug | |
import kotlin.coroutines.CoroutineContext | |
/**
 * [ClientSSESession] that parses a `text/event-stream` body read from [input] into [ServerSentEvent]s.
 *
 * Events are exposed through [incoming], a cold [Flow] built with `channelFlow`. Parsing starts when the
 * flow is collected and reads directly from [input]. Collecting it a second time reads from the same,
 * already-advanced channel.
 *
 * Collection ends in one of two ways:
 * - [input] is exhausted.
 * - An event whose data is exactly `[DONE]` arrives. This is the end marker used by OpenAI-style APIs,
 *   and the marker event itself is not emitted.
 *
 * @param errorResponse non-null marks the response as an error. The OkHttp engine wrapper passes it for
 *   any status other than `200 OK`. Collecting [incoming] then reads the whole body as text and throws
 *   [SSEResponseException] instead of emitting events.
 * @param reconnectionTimeMillis initial reconnection delay. It is overwritten by every numeric `retry:`
 *   field and is not read anywhere else in this class.
 * @param showCommentEvents whether events carrying only comments are emitted.
 * @param showRetryEvents whether events carrying only a `retry:` value are emitted.
 * @param input raw response body channel the events are parsed from.
 * @param coroutineContext context of this session (the engine wrapper passes the response's call context).
 */
internal class DefaultClientSSESession(
    private val errorResponse: HttpResponseData?,
    private var reconnectionTimeMillis: Long,
    private val showCommentEvents: Boolean,
    private val showRetryEvents: Boolean,
    private val input: ByteReadChannel,
    override val coroutineContext: CoroutineContext,
) : ClientSSESession {
    /** Last event ID seen on this stream. It persists across events until a new valid `id:` field arrives. */
    private var lastEventId: String? = null

    private val _incoming = channelFlow {
        if (errorResponse != null) {
            val responseText = try {
                input.readRemaining().readText()
            } catch (_: MalformedInputException) {
                "<body failed decoding>"
            }
            debug { "Received SSE response error: $responseText" }
            throw SSEResponseException(errorResponse.statusCode, responseText)
        }
        while (true) {
            val event = input.parseEvent() ?: break
            if (event.isCommentsEvent() && !showCommentEvents) continue
            if (event.isRetryEvent() && !showRetryEvents) continue
            debug { "Received SSE: $event" }
            // Ignore end markers like the OpenAI API's `[DONE]`. We don't want to send this kind of
            // meaningless event to the client.
            if (event.data == "[DONE]") break
            send(event)
        }
    }

    override val incoming: Flow<ServerSentEvent>
        get() = _incoming

    /**
     * Reads lines until a blank line dispatches one non-empty event.
     *
     * @return the parsed event, or `null` once the channel ends. An event cut off by the end of input
     *   (no terminating blank line) is discarded.
     */
    private suspend fun ByteReadChannel.parseEvent(): ServerSentEvent? {
        val data = StringBuilder()
        val comments = StringBuilder()
        var eventType: String? = null
        var curRetry: Long? = null
        // Seed from the session so an event without an `id:` field inherits the previous ID.
        var lastEventId: String? = this@DefaultClientSSESession.lastEventId
        var wasData = false
        var wasComments = false

        // Skip blank lines left over between events.
        var line: String = readUTF8Line() ?: return null
        while (line.isBlank()) {
            line = readUTF8Line() ?: return null
        }

        while (true) {
            when {
                // A blank line dispatches whatever has been accumulated so far.
                line.isBlank() -> {
                    this@DefaultClientSSESession.lastEventId = lastEventId
                    val event = ServerSentEvent(
                        if (wasData) data.toText() else null,
                        eventType,
                        lastEventId,
                        curRetry,
                        if (wasComments) comments.toText() else null,
                    )
                    // Events with no fields at all are swallowed and reading continues.
                    if (!event.isEmpty()) {
                        return event
                    }
                }
                line.startsWith(COLON) -> {
                    wasComments = true
                    comments.appendComment(line)
                }
                else -> {
                    // `field: value`. A line without a colon is a field whose value is empty.
                    val field = line.substringBefore(COLON)
                    val value = line.substringAfter(COLON, missingDelimiterValue = "").removePrefix(SPACE)
                    when (field) {
                        "event" -> eventType = value
                        "data" -> {
                            wasData = true
                            data.append(value).append(END_OF_LINE)
                        }
                        // Non-numeric retry values are ignored.
                        "retry" -> value.toLongOrNull()?.let {
                            reconnectionTimeMillis = it
                            curRetry = it
                        }
                        // IDs containing a NUL character are ignored.
                        "id" -> if (!value.contains(NULL)) {
                            lastEventId = value
                        }
                        // Any other field name is ignored.
                    }
                }
            }
            line = readUTF8Line() ?: return null
        }
    }

    /** Appends a comment line without its leading `:` and at most one following space. */
    private fun StringBuilder.appendComment(comment: String) {
        append(comment.removePrefix(COLON).removePrefix(SPACE)).append(END_OF_LINE)
    }

    /** Returns the accumulated text without the line terminator appended after its last line. */
    private fun StringBuilder.toText() = toString().removeSuffix(END_OF_LINE)

    /** True when the event carries no field at all. */
    private fun ServerSentEvent.isEmpty() =
        data == null && id == null && event == null && retry == null && comments == null

    /**
     * True when the event carries comments and nothing else.
     *
     * NOTE(review): [parseEvent] seeds `id` from the session's last event ID. As a result, once any `id:`
     * field has been seen, comment-only events carry a non-null id and are not classified here.
     */
    private fun ServerSentEvent.isCommentsEvent() =
        data == null && event == null && id == null && retry == null && comments != null

    /** True when the event carries a `retry:` value and nothing else (subject to the same id caveat). */
    private fun ServerSentEvent.isRetryEvent() =
        data == null && event == null && id == null && comments == null && retry != null
}
/** The NUL character (U+0000). An `id:` field value containing it is ignored by the SSE parser. */
private const val NULL: String = "\u0000"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") | |
@file:OptIn(InternalAPI::class) | |
package chachako.network.engine | |
import io.ktor.client.engine.HttpClientEngineBase | |
import io.ktor.client.engine.HttpClientEngineCapability | |
import io.ktor.client.engine.okhttp.OkHttpConfig | |
import io.ktor.client.engine.okhttp.OkHttpEngine | |
import io.ktor.client.plugins.sse.reconnectionTimeAttr | |
import io.ktor.client.plugins.sse.showCommentEventsAttr | |
import io.ktor.client.plugins.sse.showRetryEventsAttr | |
import io.ktor.client.request.HttpRequestData | |
import io.ktor.client.request.HttpResponseData | |
import io.ktor.http.HttpStatusCode | |
import io.ktor.utils.io.ByteReadChannel | |
import io.ktor.utils.io.InternalAPI | |
import kotlinx.coroutines.CoroutineDispatcher | |
import kotlin.coroutines.CoroutineContext | |
/**
 * Engine that delegates every call to Ktor's [OkHttpEngine].
 *
 * Requests tagged as SSE requests (see `isSseRequest`) get special handling. Their response body is
 * replaced with a [DefaultClientSSESession] that parses the OkHttp body channel into server-sent events.
 * All other responses are returned unchanged.
 *
 * @param config OkHttp engine configuration, forwarded to the wrapped [OkHttpEngine].
 */
class OkHttpClientEngineWrapper(override val config: OkHttpConfig) : HttpClientEngineBase("okhttp-wrapper") {
    private val okhttp = OkHttpEngine(config)

    override val supportedCapabilities: Set<HttpClientEngineCapability<*>> = okhttp.supportedCapabilities
    override val coroutineContext: CoroutineContext get() = okhttp.coroutineContext
    override val dispatcher: CoroutineDispatcher get() = okhttp.dispatcher

    override fun close() {
        super.close()
        okhttp.close()
    }

    override suspend fun execute(data: HttpRequestData): HttpResponseData {
        val responseData = okhttp.execute(data)
        // Non-SSE requests pass through untouched.
        if (!data.isSseRequest()) return responseData

        val session = DefaultClientSSESession(
            // Any status other than 200 OK is surfaced as an SSEResponseException when the session is collected.
            errorResponse = if (responseData.statusCode != HttpStatusCode.OK) responseData else null,
            reconnectionTimeMillis = data.attributes[reconnectionTimeAttr].inWholeMilliseconds,
            showCommentEvents = data.attributes[showCommentEventsAttr],
            showRetryEvents = data.attributes[showRetryEventsAttr],
            // Assumes the OkHttp engine delivers the body as a ByteReadChannel. Any other body type fails
            // here with a ClassCastException.
            input = responseData.body as ByteReadChannel,
            coroutineContext = responseData.callContext,
        )
        // Re-wrap the response so the SSE session becomes its body. Keep the request's own timestamp:
        // passing `responseTime` here would misreport when the request was sent.
        return HttpResponseData(
            statusCode = responseData.statusCode,
            requestTime = responseData.requestTime,
            headers = responseData.headers,
            version = responseData.version,
            body = session,
            callContext = responseData.callContext,
        )
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package chachako.network.engine | |
import io.ktor.client.engine.HttpClientEngine | |
import io.ktor.client.engine.HttpClientEngineFactory | |
import io.ktor.client.engine.okhttp.OkHttpConfig | |
/**
 * Engine factory for [OkHttpClientEngineWrapper].
 *
 * Builds a fresh [OkHttpConfig], applies the caller's configuration block to it, and wraps it.
 */
object OkHttpClientWrapper : HttpClientEngineFactory<OkHttpConfig> {
    override fun create(block: OkHttpConfig.() -> Unit): HttpClientEngine {
        val engineConfig = OkHttpConfig()
        engineConfig.block()
        return OkHttpClientEngineWrapper(engineConfig)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") | |
package chachako.network.sse | |
import io.ktor.client.plugins.HttpTimeoutConfig.Companion.INFINITE_TIMEOUT_MS | |
import io.ktor.client.plugins.api.ClientPlugin | |
import io.ktor.client.plugins.api.createClientPlugin | |
import io.ktor.client.plugins.sse.ClientSSESession | |
import io.ktor.client.plugins.sse.SSECapability | |
import io.ktor.client.plugins.sse.SSEConfig | |
import io.ktor.client.plugins.sse.reconnectionTimeAttr | |
import io.ktor.client.plugins.sse.showCommentEventsAttr | |
import io.ktor.client.plugins.sse.showRetryEventsAttr | |
import io.ktor.client.plugins.timeout | |
import io.ktor.client.request.HttpRequest | |
import io.ktor.client.request.HttpRequestData | |
import io.ktor.client.request.headers | |
import io.ktor.client.statement.HttpResponseContainer | |
import io.ktor.client.statement.HttpResponsePipeline | |
import io.ktor.client.statement.request | |
import io.ktor.client.utils.EmptyContent | |
import io.ktor.http.ContentType | |
import io.ktor.http.HttpHeaders | |
import io.ktor.http.HttpMethod | |
import io.ktor.http.HttpStatusCode | |
import io.ktor.http.append | |
import io.ktor.sse.SSEException | |
import io.ktor.util.AttributeKey | |
import chachako.common.debug | |
/** Request attribute set to `true` by the [SSE] plugin to mark a request as a server-sent-events request. */
internal val sseRequestAttr: AttributeKey<Boolean> = AttributeKey("SSERequest")
/**
 * Client Server-sent events plugin that allows you to establish an SSE connection to a server
 * and receive Server-sent events from it.
 *
 * Every request sent through a client with this plugin installed is treated as an SSE request. The
 * plugin applies the following to each one:
 * - It adds `Accept: text/event-stream` and `Cache-Control: no-store` headers.
 * - It sets request, connect and socket timeouts to infinite.
 * - It requires the [SSECapability] engine capability.
 * - It tags the request with [sseRequestAttr].
 * - It upgrades a `GET` request that carries a body to `POST`.
 *
 * Per-request reconnection time and comment/retry visibility attributes take precedence. The values from
 * [SSEConfig] are only filled in when a request does not already carry them.
 *
 * The response body is expected to already be a [ClientSSESession], produced by the engine (for example
 * `OkHttpClientEngineWrapper`). Any other body makes the response pipeline throw [SSEException].
 *
 * ```kotlin
 * val client = HttpClient(OkHttpClientWrapper) {
 *     install(SSE)
 * }
 * ```
 */
val SSE: ClientPlugin<SSEConfig> = createClientPlugin(
    name = "SSE",
    createConfiguration = ::SSEConfig,
) {
    // Plugin-wide defaults, used only for requests that do not set their own value.
    val defaultReconnectionTime = pluginConfig.reconnectionTime
    val defaultShowCommentEvents = pluginConfig.showCommentEvents
    val defaultShowRetryEvents = pluginConfig.showRetryEvents

    onRequest { request, _ ->
        request.headers {
            append(HttpHeaders.Accept, ContentType.Text.EventStream)
            append(HttpHeaders.CacheControl, "no-store")
        }
        // A long-lived event stream must not be cut off by client-side timeouts.
        request.timeout {
            requestTimeoutMillis = INFINITE_TIMEOUT_MS
            connectTimeoutMillis = INFINITE_TIMEOUT_MS
            socketTimeoutMillis = INFINITE_TIMEOUT_MS
        }
        request.setCapability(SSECapability, Unit)

        // SSE endpoints that take a payload are called with POST; a GET cannot carry the body.
        val carriesBody = request.body !is EmptyContent
        if (carriesBody && request.method == HttpMethod.Get) {
            request.method = HttpMethod.Post
        }

        with(request.attributes) {
            // NOTE(review): appears intended to stop Ktor from buffering the endless stream body. This relies
            // on Ktor internally matching this key by name; re-verify on Ktor upgrades.
            put(AttributeKey<Unit>("ResponseBodySaved"), Unit)
            put(sseRequestAttr, true)
            if (!contains(reconnectionTimeAttr)) put(reconnectionTimeAttr, defaultReconnectionTime)
            if (!contains(showCommentEventsAttr)) put(showCommentEventsAttr, defaultShowCommentEvents)
            if (!contains(showRetryEventsAttr)) put(showRetryEventsAttr, defaultShowRetryEvents)
        }
    }

    client.responsePipeline.intercept(HttpResponsePipeline.Transform) { (info, body) ->
        val response = context.response
        if (response.request.isSseRequest()) {
            if (body !is ClientSSESession) {
                throw SSEException("Expected `ClientSSESession` content but was: $body")
            }
            debug { "Received SSE response from ${response.request.url}: $body" }
            proceedWith(HttpResponseContainer(info, body))
        } else {
            debug { "Skipping non SSE response from ${response.request.url}" }
        }
    }
}
/**
 * Thrown when collecting SSE events from a response that the engine marked as an error response.
 *
 * @property statusCode HTTP status of the failed response.
 * @property responseText full response body decoded as text, or `<body failed decoding>` if it could not
 *   be decoded.
 */
class SSEResponseException(
    val statusCode: HttpStatusCode,
    val responseText: String,
) : IllegalStateException(
    "Bad response: statusCode=${statusCode.value}. Text: \"$responseText\"",
)
/** Whether the [SSE] plugin tagged this request data as an SSE request; untagged requests yield `false`. */
internal fun HttpRequestData.isSseRequest(): Boolean =
    attributes.getOrNull(sseRequestAttr) ?: false
/** Whether the [SSE] plugin tagged this request as an SSE request; untagged requests yield `false`. */
internal fun HttpRequest.isSseRequest(): Boolean =
    attributes.getOrNull(sseRequestAttr) ?: false
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.