-
-
Save Vaysman/48ce5b3732576b62fedfc3d99dd3877f to your computer and use it in GitHub Desktop.
Ktor reverse proxy that forwards the entire request, including its body. It is based on https://github.com/ktorio/ktor-samples/blob/1.3.0/other/reverse-proxy/src/ReverseProxyApplication.kt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.anifichadia.ktorreverseproxy | |
import io.ktor.application.Application
import io.ktor.application.ApplicationCallPipeline
import io.ktor.application.ApplicationStopping
import io.ktor.application.call
import io.ktor.application.install
import io.ktor.application.log
import io.ktor.client.HttpClient
import io.ktor.client.features.HttpTimeout
import io.ktor.features.CallLogging
import io.ktor.http.HttpHeaders
import io.ktor.request.httpMethod
import io.ktor.request.uri
import org.slf4j.event.Level
import java.util.concurrent.TimeUnit
/**
 * Reverse proxy that forwards incoming calls, including their request bodies, to upstream servers.
 *
 * Based on: https://github.com/ktorio/ktor-samples/blob/1.3.0/other/reverse-proxy/src/ReverseProxyApplication.kt
 * with changes so that the request body is forwarded as well.
 *
 * @property fallback handles calls that match none of [mappings].
 * @property mappings checked in list order; the first one whose [Mapping.matches] returns true receives the call.
 */
class KtorReverseProxy(
    val fallback: Fallback,
    val mappings: List<Mapping> = emptyList(),
) {
    /**
     * Installs call logging and an [ApplicationCallPipeline.Call] interceptor on [application] that routes every
     * call to the first matching [Mapping] (or to [fallback]).
     *
     * Also creates the [HttpClient] used for the upstream requests. That client is closed when the application
     * raises [ApplicationStopping].
     *
     * @return [application], to allow chaining.
     */
    fun configure(application: Application) = application.apply {
        // CallLogging is a *server* feature, so it is installed on the application itself rather than inside the
        // HttpClient builder (where `install` only resolved by falling through to the outer Application receiver).
        install(CallLogging) {
            logger = log
            level = Level.TRACE
        }

        val client = HttpClient {
            // Error responses must be forwarded, not thrown. HttpClient's `expectSuccess` defaults to true, in which
            // case the response validation step throws an exception instead of returning the response. A gateway
            // expects upstream error responses and should pass them on to its clients.
            expectSuccess = false
            install(HttpTimeout) {
                val timeoutMillis = TimeUnit.SECONDS.toMillis(60)
                requestTimeoutMillis = timeoutMillis
                connectTimeoutMillis = timeoutMillis
                socketTimeoutMillis = timeoutMillis
            }
        }
        environment.monitor.subscribe(ApplicationStopping) {
            client.close()
        }

        intercept(ApplicationCallPipeline.Call) {
            val request = call.request
            val requestUri = request.uri
            // Decide from the headers whether a body was sent (RFC 7230 §3.3.3). `availableForRead` on the body
            // channel only counts bytes already buffered and reports "no body" for bodies that haven't arrived yet.
            val declaredLength = request.headers[HttpHeaders.ContentLength]?.toLongOrNull()
            val hasBody = if (declaredLength != null) {
                declaredLength > 0
            } else {
                request.headers.contains(HttpHeaders.TransferEncoding)
            }
            log.debug("<-- $requestUri (${request.httpMethod.value}, body: $hasBody)")

            val mapping = mappings.find { it.matches(request) }
            if (mapping != null) {
                proxyRequest(client, mapping.redirectTo)
            } else {
                log.debug("$requestUri: Using fallback")
                fallback.handle(this, client)
            }
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.anifichadia.ktorreverseproxy | |
import com.anifichadia.ktorreverseproxy.Mapping.RegexMapping | |
import com.anifichadia.ktorreverseproxy.Mapping.SimpleMapping | |
import io.ktor.server.engine.embeddedServer | |
import io.ktor.server.netty.Netty | |
/**
 * Example entry point: serves the reverse proxy on port 8080.
 *
 * - `/my/specific/route...` (prefix match) goes to localhost:8082.
 * - Paths fully matching `/my/regex/matched/route.*` (case-insensitive) go to localhost:8083.
 * - Everything else falls back to localhost:8081.
 */
object KtorReverseProxyHost {
    @JvmStatic
    fun main(args: Array<String>) {
        embeddedServer(
            Netty,
            port = 8080,
        ) {
            KtorReverseProxy(
                fallback = Fallback.ToUrl("http://localhost:8081"),
                mappings = listOf(
                    // SimpleMapping's parameter is `matcher`; `matcherString` only exists on RegexMapping.
                    SimpleMapping(
                        redirectTo = "http://localhost:8082",
                        matcher = "/my/specific/route",
                    ),
                    RegexMapping(
                        redirectTo = "http://localhost:8083",
                        matcherString = """/my/regex/matched/route.*""",
                    ),
                ),
            ).configure(this)
        }.start(wait = true)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.anifichadia.ktorreverseproxy | |
import io.ktor.application.ApplicationCall | |
import io.ktor.application.application | |
import io.ktor.application.call | |
import io.ktor.application.log | |
import io.ktor.client.HttpClient | |
import io.ktor.client.features.ResponseException | |
import io.ktor.client.request.header | |
import io.ktor.client.request.request | |
import io.ktor.client.statement.HttpResponse | |
import io.ktor.http.ContentType | |
import io.ktor.http.ContentType.Companion | |
import io.ktor.http.Headers | |
import io.ktor.http.HttpHeaders | |
import io.ktor.http.HttpStatusCode | |
import io.ktor.http.content.OutgoingContent.ReadChannelContent | |
import io.ktor.http.content.OutgoingContent.WriteChannelContent | |
import io.ktor.request.httpMethod | |
import io.ktor.request.uri | |
import io.ktor.response.respond | |
import io.ktor.util.filter | |
import io.ktor.util.pipeline.PipelineContext | |
import io.ktor.utils.io.ByteReadChannel | |
import io.ktor.utils.io.ByteWriteChannel | |
import io.ktor.utils.io.copyAndClose | |
/**
 * Lower-cased names of headers that are never copied verbatim between the incoming call and the upstream
 * request, or between the upstream response and the client.
 *
 * - Content-Type / Content-Length / Transfer-Encoding describe the body. They are supplied through the forwarded
 *   OutgoingContent (its `contentType` / `contentLength`) instead of as plain headers.
 * - Upgrade: this proxy does not tunnel protocol upgrades.
 *   - NOTE: Ktor 1.x treats Upgrade as an unsafe client request header.
 *   - Ktor's server also refuses it on non-upgrade responses (some servers advertise e.g. `Upgrade: h2,h2c` on
 *     ordinary responses), so copying it would fail the proxied call. Verify against the Ktor version in use.
 */
val excludedProxiedHeaders = listOf(
    HttpHeaders.ContentType,
    HttpHeaders.ContentLength,
    HttpHeaders.TransferEncoding,
    HttpHeaders.Upgrade,
).map { it.toLowerCase() }
/**
 * Forwards the current call to `"$redirectTo$requestUri"` with [client] and streams the upstream response back.
 *
 * Forwarded to the upstream:
 * - the method;
 * - the headers, except [excludedProxiedHeaders];
 * - the body, when the incoming request declares one.
 *
 * Relayed back to the caller:
 * - the upstream status;
 * - the upstream headers, except [excludedProxiedHeaders];
 * - the upstream body.
 *
 * @param client client used for the upstream request. It is expected to use `expectSuccess = false` so that error
 *   responses are relayed rather than thrown; a [ResponseException] is still unwrapped as a fallback.
 * @param redirectTo base URL (scheme, host, optional port) that the incoming request URI is appended to.
 * @throws Throwable any other failure of the upstream request is logged and rethrown.
 */
suspend fun PipelineContext<Unit, ApplicationCall>.proxyRequest(client: HttpClient, redirectTo: String) {
    val startTime = System.currentTimeMillis()
    val requestUri = call.request.uri
    val requestMethod = call.request.httpMethod
    val requestMethodValue = requestMethod.value
    val requestHeaders = call.request.headers
    val completeProxiedUrl = "$redirectTo$requestUri"

    // Per RFC 7230 §3.3.3 a request has a body only if it declares Content-Length or Transfer-Encoding.
    // - Passing the declared length on keeps the upstream request non-chunked.
    // - Attaching no body when none was sent avoids sending bodies on GET/HEAD, which some client engines reject.
    val requestContentLength = requestHeaders[HttpHeaders.ContentLength]?.toLongOrNull()
    val requestHasBody = requestContentLength != null || requestHeaders.contains(HttpHeaders.TransferEncoding)

    log.debug("$requestUri ($requestMethodValue) -> $completeProxiedUrl")

    val proxiedResponse: HttpResponse = try {
        client.request<HttpResponse>(completeProxiedUrl) {
            // Proxy the original request method to the mapped endpoint
            method = requestMethod
            // Proxy the original request headers (minus the excluded) to the mapped endpoint
            requestHeaders
                .filter { key, _ -> key.toLowerCase() !in excludedProxiedHeaders }
                .forEach { key, value -> header(key, value.joinToString()) }
            // Proxy the original request body (if any) to the mapped endpoint
            if (requestHasBody) {
                body = object : ReadChannelContent() {
                    override val contentLength: Long? = requestContentLength
                    override val contentType: ContentType? =
                        requestHeaders[HttpHeaders.ContentType]?.let { ContentType.parse(it) }

                    override fun readFrom(): ByteReadChannel = call.request.receiveChannel()
                }
            }
        }
    } catch (e: ResponseException) {
        e.response
    } catch (e: Throwable) {
        log.error("--> $requestUri ($requestMethodValue)", e)
        throw e
    }

    val proxyTime = System.currentTimeMillis()
    val proxiedResponseHeaders = proxiedResponse.headers
    call.respond(object : WriteChannelContent() {
        // toLongOrNull: a malformed upstream Content-Length becomes "unknown length" instead of an exception.
        override val contentLength: Long? = proxiedResponseHeaders[HttpHeaders.ContentLength]?.toLongOrNull()
        override val contentType: ContentType? =
            proxiedResponseHeaders[HttpHeaders.ContentType]?.let { ContentType.parse(it) }
        override val headers: Headers = Headers.build {
            appendAll(
                proxiedResponseHeaders.filter { key, _ -> key.toLowerCase() !in excludedProxiedHeaders }
            )
        }
        override val status: HttpStatusCode = proxiedResponse.status

        override suspend fun writeTo(channel: ByteWriteChannel) {
            proxiedResponse.content.copyAndClose(channel)
        }
    })

    val endTime = System.currentTimeMillis()
    val roundTripTime = endTime - startTime
    log.debug("--> $requestUri ($requestMethodValue) -> $completeProxiedUrl (status: ${proxiedResponse.status}) (time: ${roundTripTime}ms (${proxyTime - startTime}ms - ${endTime - proxyTime}ms))")
}
/**
 * Logger of the application that owns the current call.
 * Lets pipeline interceptors and helpers write `log.debug(...)` directly.
 */
val PipelineContext<*, ApplicationCall>.log
    get() = call.application.log
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.anifichadia.ktorreverseproxy | |
import io.ktor.application.ApplicationCall | |
import io.ktor.application.call | |
import io.ktor.client.HttpClient | |
import io.ktor.http.HttpStatusCode | |
import io.ktor.response.respond | |
import io.ktor.util.pipeline.PipelineContext | |
/**
 * Strategy for answering a call that matched none of the configured [Mapping]s.
 */
abstract class Fallback {
    /**
     * Produces the response for the call held by [pipelineContext].
     * [client] is the proxy's shared HTTP client, available to implementations that contact an upstream.
     */
    abstract suspend fun handle(pipelineContext: PipelineContext<Unit, ApplicationCall>, client: HttpClient)

    /** Responds to every unmatched call with [httpStatusCode]; the client is not used. */
    class RespondWithHttpStatusCode(private val httpStatusCode: HttpStatusCode) : Fallback() {
        override suspend fun handle(pipelineContext: PipelineContext<Unit, ApplicationCall>, client: HttpClient) =
            pipelineContext.call.respond(httpStatusCode)
    }

    /** Proxies every unmatched call to [redirectTo] (the request URI is appended) via `proxyRequest`. */
    class ToUrl(private val redirectTo: String) : Fallback() {
        override suspend fun handle(pipelineContext: PipelineContext<Unit, ApplicationCall>, client: HttpClient) =
            pipelineContext.proxyRequest(client, redirectTo)
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.anifichadia.ktorreverseproxy | |
import io.ktor.request.ApplicationRequest | |
import io.ktor.request.path | |
import kotlin.text.RegexOption.IGNORE_CASE | |
/**
 * Routing rule: requests for which [matches] returns true are proxied to [redirectTo].
 *
 * @property redirectTo base URL that the original request URI is appended to when proxying.
 */
abstract class Mapping(
    val redirectTo: String,
) {
    /** Returns true when [request] should be routed to [redirectTo]. */
    abstract fun matches(request: ApplicationRequest): Boolean

    /**
     * Matches when the request path starts with [matcher].
     * This is a raw, case-sensitive string prefix check, so `/api` also matches `/apiv2`.
     */
    class SimpleMapping(
        redirectTo: String,
        private val matcher: String,
    ) : Mapping(redirectTo) {
        override fun matches(request: ApplicationRequest): Boolean {
            val path = request.path()
            return path.startsWith(matcher)
        }
    }

    /**
     * Matches when the *entire* request path matches [matcher].
     * Use a trailing `.*` to match sub-paths.
     */
    class RegexMapping(
        redirectTo: String,
        private val matcher: Regex,
    ) : Mapping(redirectTo) {
        /** Compiles [matcherString] case-insensitively. */
        constructor(redirectTo: String, matcherString: String) : this(
            redirectTo = redirectTo,
            matcher = Regex(matcherString, IGNORE_CASE),
        )

        override fun matches(request: ApplicationRequest): Boolean = matcher.matches(request.path())
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment