Created
October 25, 2022 19:35
-
-
Save rockwotj/df6801745cb309d8a958abf201d9e3e3 to your computer and use it in GitHub Desktop.
A native JVM transport for Elasticsearch that uses the JDK's built-in HttpClient instead of the Apache HTTP client
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * An [ElasticsearchTransport] backed by the JDK's built-in [HttpClient] instead of the Apache HTTP client.
 *
 * - Requests use HTTP/1.1 with a 1 second connect timeout and a 1 minute per-request timeout.
 * - `Content-Type` and `Accept` use the versioned `compatible-with` media type when the client [Version] is
 *   known, and plain `application/json` otherwise. An `ApiKey` authorization header is added when
 *   [SearchClient.Config.apiKey] is set.
 * - The scheme is `http` only when the configured host is exactly `localhost`; every other host uses `https`.
 *
 * The transport owns the thread pool used by its HTTP client; call [close] to release it.
 */
private class NativeJvmTransport(private val config: SearchClient.Config) : ElasticsearchTransport {
    /** Options used for any request that does not supply its own [TransportOptions]. */
    private val options = Options(
        headers = ImmutableMultimap.builder<String, String>().apply {
            val mimeType = Version.VERSION
                ?.let { version -> "application/vnd.elasticsearch+json; compatible-with=${version.major()}" }
                ?: "application/json"
            put("Content-Type", mimeType)
            put("Accept", mimeType)
            config.apiKey?.let { put("Authorization", "ApiKey $it") }
        }.build(),
    )

    private val mapper = JacksonJsonpMapper()

    /** Runs the HTTP client's asynchronous work. This transport owns it and shuts it down in [close]. */
    private val httpExecutor = Executors.newCachedThreadPool()

    private val client = HttpClient.newBuilder().apply {
        connectTimeout(Duration.ofSeconds(1))
        executor(httpExecutor)
        version(HttpClient.Version.HTTP_1_1)
        config.sslContext()?.let(this::sslContext)
    }.build()

    /**
     * Blocking variant of [performRequestAsync].
     *
     * Failures are rethrown as the original exception (for example [ElasticsearchException],
     * [TransportException] or an I/O error) rather than as the `ExecutionException` wrapper
     * that `Future.get()` adds.
     */
    override fun <RequestT : Any?, ResponseT : Any?, ErrorT : Any?> performRequest(
        request: RequestT,
        endpoint: Endpoint<RequestT, ResponseT, ErrorT>,
        options: TransportOptions?,
    ): ResponseT? = try {
        performRequestAsync(request, endpoint, options).get()
    } catch (e: java.util.concurrent.ExecutionException) {
        throw e.cause ?: e
    }

    /**
     * Sends [request] to [endpoint] and completes with the decoded response.
     *
     * @param options per-request options; when null, this transport's default [options] are used.
     * @return a future that completes with the decoded response. It fails with [ElasticsearchException]
     *   when the server returns a decodable error, and with [TransportException] for other failures.
     */
    override fun <RequestT : Any?, ResponseT : Any?, ErrorT : Any?> performRequestAsync(
        request: RequestT,
        endpoint: Endpoint<RequestT, ResponseT, ErrorT>,
        options: TransportOptions?,
    ): CompletableFuture<ResponseT?> {
        val opts = options ?: this.options
        val uri = buildUri(endpoint.requestUrl(request), opts.queryParameters(), endpoint.queryParameters(request))
        val body = if (endpoint.hasRequestBody()) asBodyPublisher(request) else BodyPublishers.noBody()
        val req = HttpRequest.newBuilder(uri)
            .apply {
                for ((name, value) in opts.headers()) {
                    header(name, value)
                }
            }
            .method(endpoint.method(request), body)
            // Give a long timeout for slow delete requests.
            // TODO: Can we set this from the request?
            .timeout(Duration.ofMinutes(1))
            .build()
        return client.sendAsync(req, BodyHandlers.ofInputStream())
            .thenApply { resp -> processResponse(resp, endpoint) }
    }

    override fun jsonpMapper(): JsonpMapper = mapper

    override fun options(): TransportOptions = options

    /** Shuts down the executor backing the HTTP client; requests already queued are allowed to finish. */
    override fun close() {
        httpExecutor.shutdown()
    }

    /**
     * Builds the absolute request URI from [config]'s host and port, the endpoint [path], and query parameters.
     *
     * [path] is appended verbatim. NOTE(review): this assumes endpoint request URLs are already
     * percent-encoded (confirm for the client version in use). Re-encoding them, as `URIBuilder.setPath`
     * does, turns `%2C` into `%252C`. Query names and values are form-encoded in UTF-8, and parameters
     * from all maps are kept in order, including duplicates.
     */
    private fun buildUri(path: String, vararg parameterMaps: Map<String, String>): java.net.URI {
        val scheme = if (config.host == "localhost") "http" else "https"
        val host: String? = config.host
        // IPv6 literals must be bracketed inside a URI authority.
        val authorityHost = if (host != null && ':' in host && !host.startsWith("[")) "[$host]" else host
        val query = parameterMaps.asSequence()
            .flatMap { it.entries.asSequence() }
            .joinToString("&") { (name, value) -> "${formEncode(name)}=${formEncode(value)}" }
        val uri = buildString {
            append(scheme).append("://").append(authorityHost)
            // A negative port means "unset", matching URIBuilder's handling of the port.
            if (config.port >= 0) append(':').append(config.port)
            if (!path.startsWith("/")) append('/')
            append(path)
            if (query.isNotEmpty()) append('?').append(query)
        }
        return java.net.URI.create(uri)
    }

    /** Form-encodes one query-string component as UTF-8. */
    private fun formEncode(component: String): String = java.net.URLEncoder.encode(component, Charsets.UTF_8)

    /** Decodes [resp] according to its status code and always closes the response body stream. */
    private fun <ResponseT : Any?> processResponse(
        resp: HttpResponse<InputStream>,
        endpoint: Endpoint<*, ResponseT, *>,
    ): ResponseT? = try {
        val statusCode = resp.statusCode()
        if (endpoint.isError(statusCode)) {
            decodeErrorResponse(statusCode, resp.body(), endpoint)
        } else {
            decodeResponse(statusCode, resp.body(), endpoint)
        }
    } finally {
        Closeables.closeQuietly(resp.body())
    }

    /**
     * Handles a response whose status the endpoint reports as an error.
     *
     * Throws [ElasticsearchException] when the body decodes as the endpoint's error type. If that decoding
     * fails with [MissingRequiredPropertyException], it retries decoding the body as the regular response
     * type and returns the result. Throws [TransportException] when no error deserializer exists or both
     * decodings fail.
     */
    private fun <ResponseT : Any?> decodeErrorResponse(
        statusCode: Int,
        body: InputStream,
        endpoint: Endpoint<*, ResponseT, *>,
    ): ResponseT? {
        val errorDeserializer = endpoint.errorDeserializer(statusCode)
            ?: throw TransportException("Request failed with status code '$statusCode'", endpoint.id())
        // We may need to replay the body. Buffer it completely: closing the JSON parser may also close the
        // stream it reads from, so a mark/reset on the live stream cannot be relied on.
        val bytes = body.readAllBytes()
        return try {
            mapper.jsonProvider().createParser(java.io.ByteArrayInputStream(bytes)).use { parser ->
                val error = errorDeserializer.deserialize(parser, mapper)
                throw ElasticsearchException(endpoint.id(), error as ErrorResponse)
            }
        } catch (errorEx: MissingRequiredPropertyException) {
            // Could not decode exception, try the response type
            try {
                decodeResponse(statusCode, java.io.ByteArrayInputStream(bytes), endpoint)
            } catch (respEx: Exception) {
                // No better luck: report the original error-decoding failure, keeping the second one as suppressed.
                throw TransportException(
                    "Failed to decode error response with status code: '$statusCode'",
                    endpoint.id(),
                    errorEx,
                ).apply { addSuppressed(respEx) }
            }
        }
    }

    /**
     * Decodes a successful (or fallback) response [body] for [endpoint].
     *
     * Boolean endpoints derive their result from [statusCode] alone. JSON endpoints return null when they
     * declare no response deserializer. Any other endpoint type raises [TransportException].
     */
    private fun <ResponseT : Any?> decodeResponse(
        statusCode: Int,
        body: InputStream,
        endpoint: Endpoint<*, ResponseT, *>,
    ): ResponseT? = when (endpoint) {
        is BooleanEndpoint -> {
            @Suppress("UNCHECKED_CAST")
            BooleanResponse(endpoint.getResult(statusCode)) as ResponseT?
        }
        is JsonEndpoint -> endpoint.responseDeserializer()?.let { deserializer ->
            mapper.jsonProvider().createParser(body).use { parser -> deserializer.deserialize(parser, mapper) }
        }
        else -> throw TransportException("Unhandled endpoint type: ${endpoint.javaClass.name}", endpoint.id())
    }

    /**
     * Serializes a request body. The request must implement JsonpSerializable or NdJsonpSerializable;
     * the latter is written as newline-delimited JSON.
     */
    private fun <RequestT : Any?> asBodyPublisher(request: RequestT): HttpRequest.BodyPublisher {
        val baos = ByteArrayOutputStream()
        if (request is NdJsonpSerializable) {
            writeNdJson(request, baos)
        } else {
            writeJson(request, baos)
        }
        return BodyPublishers.ofByteArray(baos.toByteArray())
    }

    /** Writes [value] as a single JSON document to [out] and closes the generator even if serialization fails. */
    private fun writeJson(value: Any?, out: ByteArrayOutputStream) {
        mapper.jsonProvider().createGenerator(out).use { generator -> mapper.serialize(value, generator) }
    }

    /** Writes each item of [value] as one JSON line, expanding nested NdJsonpSerializable items in place. */
    private fun writeNdJson(value: NdJsonpSerializable, baos: ByteArrayOutputStream) {
        value._serializables().forEach { item ->
            if (item is NdJsonpSerializable && item !== value) { // do not recurse on the item itself
                writeNdJson(item, baos)
            } else {
                writeJson(item, baos)
                baos.write('\n'.code)
            }
        }
    }

    /** Immutable [TransportOptions] holding headers, query parameters and a warnings handler. */
    private class Options(
        val headers: ImmutableMultimap<String, String>,
        val queryParams: Map<String, String> = emptyMap(),
        val onWarnings: Function<List<String>, Boolean> = Function { false },
    ) : TransportOptions {
        override fun headers(): Collection<Map.Entry<String, String>> = headers.entries()

        override fun queryParameters(): Map<String, String> = queryParams

        override fun onWarnings(): Function<List<String>, Boolean> = onWarnings

        override fun toBuilder(): TransportOptions.Builder {
            val b = Builder()
            b.headers.putAll(headers)
            b.queryParams.putAll(queryParams)
            b.onWarnings = onWarnings
            return b
        }

        /** Mutable builder; [build] snapshots the current state into a new [Options]. */
        private class Builder : TransportOptions.Builder {
            val headers: Multimap<String, String> = HashMultimap.create()
            val queryParams: MutableMap<String, String> = mutableMapOf()
            var onWarnings: Function<List<String>, Boolean> = Function { false }

            override fun addHeader(name: String, value: String): TransportOptions.Builder {
                headers.put(name, value)
                return this
            }

            override fun setParameter(name: String, value: String): TransportOptions.Builder {
                queryParams[name] = value
                return this
            }

            override fun onWarnings(listener: Function<List<String>, Boolean>): TransportOptions.Builder {
                onWarnings = listener
                return this
            }

            override fun build(): TransportOptions {
                return Options(ImmutableMultimap.copyOf(headers), queryParams.toMap(), onWarnings)
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment