Skip to content

Instantly share code, notes, and snippets.

@Vaysman
Forked from AniFichadia/$KtorReverseProxy.kt
Created January 25, 2025 02:57
Show Gist options
  • Save Vaysman/48ce5b3732576b62fedfc3d99dd3877f to your computer and use it in GitHub Desktop.
Save Vaysman/48ce5b3732576b62fedfc3d99dd3877f to your computer and use it in GitHub Desktop.
Ktor reverse proxy that forwards the entire request contents (including the body). This is based on https://github.com/ktorio/ktor-samples/blob/1.3.0/other/reverse-proxy/src/ReverseProxyApplication.kt
package com.anifichadia.ktorreverseproxy
import io.ktor.application.Application
import io.ktor.application.ApplicationCallPipeline
import io.ktor.application.ApplicationStopping
import io.ktor.application.call
import io.ktor.application.install
import io.ktor.application.log
import io.ktor.client.HttpClient
import io.ktor.client.features.HttpTimeout
import io.ktor.features.CallLogging
import io.ktor.request.httpMethod
import io.ktor.request.uri
import org.slf4j.event.Level
import java.util.concurrent.TimeUnit
/**
 * Reverse proxy built on Ktor that forwards the incoming request (method, headers and body) to another host.
 *
 * Based on: https://github.com/ktorio/ktor-samples/blob/1.3.0/other/reverse-proxy/src/ReverseProxyApplication.kt
 * with a few changes so that the request body is forwarded as well.
 *
 * @property fallback handler invoked when no entry of [mappings] matches the request.
 * @property mappings candidate routes, checked in order; the first one whose [Mapping.matches] returns `true` is used.
 */
class KtorReverseProxy(
    val fallback: Fallback,
    val mappings: List<Mapping> = emptyList(),
) {
    /**
     * Installs the proxy on [application]: enables call logging, creates the shared [HttpClient] and registers an
     * interceptor on [ApplicationCallPipeline.Call] that forwards every call to a mapping or to [fallback].
     *
     * @param application the Ktor application to configure.
     * @return the same [application] instance, to allow chaining.
     */
    fun configure(application: Application): Application = application.apply {
        // CallLogging is a server (application) feature. It used to be declared inside the HttpClient configuration
        // lambda below, where it only resolved through this enclosing Application receiver; installing it here
        // makes it obvious which component it configures.
        install(CallLogging) {
            logger = log
            level = Level.TRACE
        }

        val client = HttpClient {
            // With `expectSuccess` enabled the client's response validation throws on error statuses instead of
            // handing the response over. A gateway has to relay error responses to its callers untouched, so the
            // validation is switched off.
            expectSuccess = false

            install(HttpTimeout) {
                // One 60 second budget shared by the request, connect and socket timeouts.
                val timeoutMillis = TimeUnit.SECONDS.toMillis(60)
                requestTimeoutMillis = timeoutMillis
                connectTimeoutMillis = timeoutMillis
                socketTimeoutMillis = timeoutMillis
            }
        }

        // Close the client when the application is stopping.
        environment.monitor.subscribe(ApplicationStopping) {
            client.close()
        }

        intercept(ApplicationCallPipeline.Call) {
            val requestUri = call.request.uri
            val requestMethodValue = call.request.httpMethod.value

            // `availableForRead` only reflects what can be read from the channel right now, so `hasBody` is a
            // best-effort hint for the log line and not a reliable "request has a body" check.
            val hasBody = call.request.receiveChannel().availableForRead > 0
            log.debug("<-- $requestUri ($requestMethodValue, body: $hasBody)")

            val mapping = mappings.find { it.matches(call.request) }
            if (mapping != null) {
                proxyRequest(client, mapping.redirectTo)
            } else {
                log.debug("$requestUri: Using fallback")
                fallback.handle(this, client)
            }
        }
    }
}
package com.anifichadia.ktorreverseproxy
import com.anifichadia.ktorreverseproxy.Mapping.RegexMapping
import com.anifichadia.ktorreverseproxy.Mapping.SimpleMapping
import io.ktor.server.engine.embeddedServer
import io.ktor.server.netty.Netty
/**
 * Example entry point: runs the reverse proxy on an embedded Netty server listening on port 8080.
 *
 * Routing used by this example:
 * - paths starting with `/my/specific/route` are forwarded to `http://localhost:8082`;
 * - paths matching the regex `/my/regex/matched/route.*` are forwarded to `http://localhost:8083`;
 * - everything else is forwarded to `http://localhost:8081`.
 */
object KtorReverseProxyHost {
    /**
     * Starts the server and blocks the calling thread (`wait = true`).
     *
     * @param args command line arguments; not used.
     */
    @JvmStatic
    fun main(args: Array<String>) {
        embeddedServer(
            Netty,
            port = 8080,
        ) {
            KtorReverseProxy(
                fallback = Fallback.ToUrl("http://localhost:8081"),
                mappings = listOf(
                    SimpleMapping(
                        redirectTo = "http://localhost:8082",
                        // SimpleMapping's constructor parameter is named `matcher`; only RegexMapping's secondary
                        // constructor takes `matcherString`.
                        matcher = "/my/specific/route",
                    ),
                    RegexMapping(
                        redirectTo = "http://localhost:8083",
                        matcherString = """/my/regex/matched/route.*""",
                    ),
                ),
            ).configure(this)
        }.start(wait = true)
    }
}
package com.anifichadia.ktorreverseproxy
import io.ktor.application.ApplicationCall
import io.ktor.application.application
import io.ktor.application.call
import io.ktor.application.log
import io.ktor.client.HttpClient
import io.ktor.client.features.ResponseException
import io.ktor.client.request.header
import io.ktor.client.request.request
import io.ktor.client.statement.HttpResponse
import io.ktor.http.ContentType
import io.ktor.http.ContentType.Companion
import io.ktor.http.Headers
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.content.OutgoingContent.ReadChannelContent
import io.ktor.http.content.OutgoingContent.WriteChannelContent
import io.ktor.request.httpMethod
import io.ktor.request.uri
import io.ktor.response.respond
import io.ktor.util.filter
import io.ktor.util.pipeline.PipelineContext
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.ByteWriteChannel
import io.ktor.utils.io.copyAndClose
/**
 * Names of the headers that are never copied between the caller and the proxied endpoint, in either direction.
 * Every entry is stored lower-cased.
 */
val excludedProxiedHeaders: List<String> = listOf(
    HttpHeaders.ContentType,
    HttpHeaders.ContentLength,
    HttpHeaders.TransferEncoding,
).map { headerName -> headerName.toLowerCase() }
/**
 * Forwards the current call to [redirectTo] and relays the upstream response back to the original caller.
 *
 * The outgoing request reuses the incoming method, request URI, headers (except those listed in
 * [excludedProxiedHeaders]) and body. The upstream status, headers (same exclusions), content type, content length
 * and body are then written to the original call.
 *
 * @param client HTTP client used to reach the upstream server.
 * @param redirectTo base URL of the upstream server; the incoming request URI is appended to it.
 * @throws Throwable any failure raised while calling the upstream server is logged and rethrown unchanged.
 */
suspend fun PipelineContext<Unit, ApplicationCall>.proxyRequest(client: HttpClient, redirectTo: String) {
    val startTime = System.currentTimeMillis()

    val requestUri = call.request.uri
    val requestMethod = call.request.httpMethod
    val requestMethodValue = requestMethod.value
    val requestHeaders = call.request.headers

    // Avoid "http://host/" + "/path" turning into "http://host//path". The base is left untouched when the URI
    // does not start with a slash, so the two parts never get glued together without a separator.
    val baseUrl = if (requestUri.startsWith("/")) redirectTo.removeSuffix("/") else redirectTo
    val completeProxiedUrl = "$baseUrl$requestUri"

    // Header names are case-insensitive. Compare with `ignoreCase` rather than lower-casing with the default
    // locale, which can map 'I' to a dotless 'ı' (e.g. Turkish) and let "TRANSFER-ENCODING" slip through.
    val isForwardedHeader: (String) -> Boolean = { name ->
        excludedProxiedHeaders.none { excluded -> excluded.equals(name, ignoreCase = true) }
    }

    log.debug("$requestUri ($requestMethodValue) -> $completeProxiedUrl")

    val proxiedResponse: HttpResponse = try {
        client.request<HttpResponse>(completeProxiedUrl) {
            // Proxy the original request method to the mapped endpoint
            method = requestMethod

            // Proxy the original request headers (minus the excluded ones) to the mapped endpoint
            requestHeaders
                .filter { key, _ -> isForwardedHeader(key) }
                .forEach { key, value -> header(key, value.joinToString()) }

            // Proxy the original request body to the mapped endpoint
            body = object : ReadChannelContent() {
                override val contentType: ContentType? =
                    requestHeaders[HttpHeaders.ContentType]?.let(ContentType.Companion::parse)

                override fun readFrom(): ByteReadChannel = call.request.receiveChannel()
            }
        }
    } catch (e: ResponseException) {
        // An error status is still a response that has to be relayed to the caller.
        e.response
    } catch (e: Throwable) {
        log.error("--> $requestUri ($requestMethodValue)", e)
        throw e
    }

    val proxyTime = System.currentTimeMillis()

    val proxiedResponseHeaders = proxiedResponse.headers
    call.respond(object : WriteChannelContent() {
        // A malformed upstream Content-Length is treated as "unknown" instead of failing the whole call with a
        // NumberFormatException.
        override val contentLength: Long? = proxiedResponseHeaders[HttpHeaders.ContentLength]?.toLongOrNull()
        override val contentType: ContentType? =
            proxiedResponseHeaders[HttpHeaders.ContentType]?.let(ContentType.Companion::parse)
        override val headers: Headers = Headers.build {
            appendAll(
                proxiedResponseHeaders.filter { key, _ -> isForwardedHeader(key) }
            )
        }
        override val status: HttpStatusCode = proxiedResponse.status

        override suspend fun writeTo(channel: ByteWriteChannel) {
            proxiedResponse.content.copyAndClose(channel)
        }
    })

    val endTime = System.currentTimeMillis()
    val roundTripTime = endTime - startTime
    log.debug("--> $requestUri ($requestMethodValue) -> $completeProxiedUrl (status: ${proxiedResponse.status}) (time: ${roundTripTime}ms (${proxyTime - startTime}ms - ${endTime - proxyTime}ms))")
}
/** The logger of the application that owns the call currently being processed. */
val PipelineContext<*, ApplicationCall>.log
    get() = call.application.log
package com.anifichadia.ktorreverseproxy
import io.ktor.application.ApplicationCall
import io.ktor.application.call
import io.ktor.client.HttpClient
import io.ktor.http.HttpStatusCode
import io.ktor.response.respond
import io.ktor.util.pipeline.PipelineContext
/** Strategy for handling a call through the [handle] hook. */
abstract class Fallback {
    /**
     * Produces the response for the call held by [pipelineContext].
     *
     * @param pipelineContext pipeline context of the call to answer.
     * @param client HTTP client available to implementations that need to reach another server.
     */
    abstract suspend fun handle(pipelineContext: PipelineContext<Unit, ApplicationCall>, client: HttpClient)

    /** Answers every call with the fixed [httpStatusCode]; the client is not used. */
    class RespondWithHttpStatusCode(private val httpStatusCode: HttpStatusCode) : Fallback() {
        override suspend fun handle(pipelineContext: PipelineContext<Unit, ApplicationCall>, client: HttpClient) =
            pipelineContext.call.respond(httpStatusCode)
    }

    /** Proxies every call to [redirectTo] through [proxyRequest]. */
    class ToUrl(private val redirectTo: String) : Fallback() {
        override suspend fun handle(pipelineContext: PipelineContext<Unit, ApplicationCall>, client: HttpClient) =
            pipelineContext.proxyRequest(client, redirectTo)
    }
}
package com.anifichadia.ktorreverseproxy
import io.ktor.request.ApplicationRequest
import io.ktor.request.path
import kotlin.text.RegexOption.IGNORE_CASE
/**
 * A routing rule pairing a request predicate ([matches]) with a target URL.
 *
 * @property redirectTo target URL associated with this rule.
 */
abstract class Mapping(
    val redirectTo: String,
) {
    /** Returns `true` when [request] should be handled by this mapping. */
    abstract fun matches(request: ApplicationRequest): Boolean

    /** Matches every request whose path starts with the given prefix (case-sensitive comparison). */
    class SimpleMapping(
        redirectTo: String,
        private val matcher: String,
    ) : Mapping(redirectTo) {
        override fun matches(request: ApplicationRequest): Boolean {
            val path = request.path()
            return path.startsWith(matcher)
        }
    }

    /** Matches every request whose entire path matches the given regular expression. */
    class RegexMapping(
        redirectTo: String,
        private val matcher: Regex,
    ) : Mapping(redirectTo) {
        /** Convenience constructor that compiles [matcherString] into a case-insensitive [Regex]. */
        constructor(redirectTo: String, matcherString: String) : this(redirectTo, Regex(matcherString, IGNORE_CASE))

        override fun matches(request: ApplicationRequest): Boolean = matcher.matches(request.path())
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment