Skip to content

Instantly share code, notes, and snippets.

@DmitryOlshansky
Created April 25, 2020 13:23
Show Gist options
  • Save DmitryOlshansky/7b1b5a449c559a5253ce4789bba32b86 to your computer and use it in GitHub Desktop.
Save DmitryOlshansky/7b1b5a449c559a5253ce4789bba32b86 to your computer and use it in GitHub Desktop.
A simple bulk uploader script with SMILE format support
package me.olshansky
import com.fasterxml.jackson.core.JsonFactory
import com.fasterxml.jackson.databind.DeserializationFeature
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.smile.SmileFactory
import com.fasterxml.jackson.module.kotlin.KotlinModule
import org.http4k.client.ApacheClient
import org.http4k.core.*
import org.http4k.core.Method.*
import java.io.ByteArrayOutputStream
import java.io.FileOutputStream
import java.lang.Exception
import java.net.URI
import java.nio.ByteBuffer
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import com.github.ajalt.clikt.core.CliktCommand
import com.github.ajalt.clikt.core.subcommands
import com.github.ajalt.clikt.parameters.arguments.argument
import com.github.ajalt.clikt.parameters.arguments.convert
import com.github.ajalt.clikt.parameters.arguments.multiple
import com.github.ajalt.clikt.parameters.options.default
import com.github.ajalt.clikt.parameters.options.flag
import com.github.ajalt.clikt.parameters.options.option
import com.github.ajalt.clikt.parameters.options.validate
import com.github.ajalt.clikt.parameters.types.file
import com.github.ajalt.clikt.parameters.types.int
import java.util.*
/**
 * Outcome of a single action inside a `_bulk` response.
 *
 * [result] is nullable with a default because Elasticsearch omits the `result` field for
 * items that failed (they carry `status` and `error` instead). With the Kotlin Jackson module
 * a required non-null constructor parameter that is missing from the payload makes
 * deserialization throw, which would hide the very response that reports the errors.
 */
data class BulkStatus(val result: String? = null)

/** One entry of the `items` array of a `_bulk` response; only `index` actions are modelled. */
data class BulkOp(val index: BulkStatus)

/**
 * Top-level shape of a `_bulk` response.
 *
 * @property took value of the response's `took` field
 * @property errors `true` when at least one item in [items] failed
 * @property items per-action outcomes, in request order
 */
data class BulkResults(val took: Int, val errors: Boolean, val items: List<BulkOp>)
/**
 * Root command of the CLI. It performs no work of its own; the subcommands
 * registered on it in `main` do the actual job.
 */
class ElasticTools : CliktCommand() {
    override fun run() {
        // Intentionally empty: the root command only dispatches to subcommands.
    }
}
/**
 * `upload` subcommand: drops and recreates the target index, then bulk-loads one input file into it.
 *
 * The input is read as a stream of documents, each terminated by a separator byte:
 * `\n` when the file extension is `json`, `0xff` otherwise (the file is then sent as SMILE).
 * Every document is prefixed with an `index` action line, documents are grouped into batches of
 * [size], and the batches are posted to `/_bulk` by [threads] worker threads.
 */
class UploadCommand : CliktCommand(name="upload") {
    val input by argument(help="input .json or .smile files").file(mustExist = true)//.multiple()
    val target by argument(help="uri of elasticsearch to use").convert { URI(it) }
    // validate {} ignores the lambda's value, so the condition has to go through require()
    val size by option(help="size of bulk request").int().default(10000)
        .validate { require(it > 0) { "size must be positive" } }
    val threads by option(help="number of concurrent threads").int().default(10)
        .validate { require(it in 1..10000) { "threads must be in 1..10000" } }

    /**
     * Deletes the target index; an index that does not exist yet is not an error.
     *
     * @throws IllegalArgumentException on any status other than 200 or 404
     */
    fun deleteIndex(client: HttpHandler) {
        val request = Request(DELETE, "$target")
        val response = client(request)
        require(response.status == Status.OK || response.status == Status.NOT_FOUND) { "Couldn't delete preexisting $target index" }
    }

    /**
     * Creates the target index.
     *
     * @throws IllegalArgumentException when the server does not answer 200
     */
    fun createIndex(client: HttpHandler) {
        val request = Request(PUT, "$target")
        val response = client(request)
        require(response.status == Status.OK) { "Couldn't create $target index" }
    }

    /**
     * Posts one ready-made bulk payload to `/_bulk` and echoes the response body when the
     * server reports per-item errors.
     *
     * @param client HTTP handler used to send the request
     * @param data complete bulk body (action line + document, repeated)
     * @param mapper mapper matching [contentType], used to decode the response
     * @param contentType value for the `Content-type` header
     * @throws IllegalArgumentException when the server does not answer 200
     */
    fun insertBulk(client: HttpHandler, data: ByteArray, mapper: ObjectMapper, contentType: String) {
        val body = Body(ByteBuffer.wrap(data))
        // resolve() reuses scheme and authority exactly as given, so a target without an
        // explicit port does not turn into "host:-1"
        val request = Request(POST, target.resolve("/_bulk").toString())
            .body(body)
            .header("Content-type", contentType)
        val response = client(request)
        require(response.status == Status.OK) { "Bulk indexing failed" }
        val statuses = mapper.readValue(response.body.stream, BulkResults::class.java)
        if (statuses.errors) echo(response.bodyString())
    }

    /**
     * Sets `index.refresh_interval` of the target index to [value] (for example `"-1"` or `"5s"`).
     *
     * @throws IllegalArgumentException when the server does not answer 200
     */
    fun setRefresh(client: HttpHandler, value: String) {
        val body = """
            {
                "index" : {
                    "refresh_interval" : "$value"
                }
            }
        """.trimIndent()
        val request = Request(PUT, "$target/_settings").body(body).header("Content-type", "application/json")
        require(client(request).status == Status.OK) { "Failed to set refresh_interval to $value on index" }
    }

    /**
     * Forces a refresh of the target index.
     *
     * @throws IllegalArgumentException when the server does not answer 200
     */
    fun refresh(client: HttpHandler) {
        val request = Request(POST, "$target/_refresh")
        require(client(request).status == Status.OK) { "Failed while refreshing index"}
    }

    /**
     * Recreates the index with refresh switched off, streams [input] to the workers in batches,
     * waits for every batch to be sent, then refreshes the index and sets the refresh interval
     * to `5s`. Exits with status 1 when the target URI is malformed or a bulk request failed.
     */
    override fun run() {
        val factory: JsonFactory
        val separator: Int
        val contentType: String
        if (input.extension == "json") {
            factory = JsonFactory.builder().build()
            separator = '\n'.toInt()
            contentType = "application/json"
        }
        else {
            factory = SmileFactory.builder().build()
            separator = 0xff
            contentType = "application/smile"
        }
        val mapper = ObjectMapper(factory)
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
            .registerModule(KotlinModule())

        // The path may be empty (no index given), so strip the slash instead of substring(1)
        val index = target.path.orEmpty().removePrefix("/")
        if (index.isEmpty() || index.contains('/')) {
            echo("Target must be of form http(s)://hostname:port/index-name")
            System.exit(1)
        }
        // Action line that precedes every document in a bulk body
        val header = ByteArrayOutputStream().apply {
            mapper.writeValue(this, mapOf("index" to mapOf("_index" to index)) )
            write(separator)
        }.toByteArray()

        val mainClient = ApacheClient()
        deleteIndex(mainClient)
        createIndex(mainClient)
        setRefresh(mainClient, "-1")
        val start = System.nanoTime()

        val executor = Executors.newFixedThreadPool(threads)
        val queue = ArrayBlockingQueue<ByteArray>(threads*2) // up to x2 batches in-flight
        val terminated = AtomicBoolean(false) // producer is done, workers exit once the queue is empty
        val failed = AtomicBoolean(false)     // a worker hit an error, everybody stops
        repeat(threads) {
            executor.submit {
                val client = ApacheClient()
                try {
                    while (!failed.get()) {
                        val item = queue.poll(10, TimeUnit.MILLISECONDS)
                        if (item != null)
                            insertBulk(client, item, mapper, contentType)
                        else if (terminated.get())
                            break // nothing queued and nothing more will come
                    }
                }
                catch (e: Exception) {
                    echo(e)
                    failed.set(true)
                    terminated.set(true)
                }
            }
        }

        // Hands a batch to the workers. Returns false when a worker has failed; waiting with a
        // timeout keeps the producer from blocking forever on a queue nobody drains anymore.
        fun enqueue(batch: ByteArray): Boolean {
            while (!failed.get()) {
                if (queue.offer(batch, 100, TimeUnit.MILLISECONDS)) return true
            }
            return false
        }

        val bulk = ByteArrayOutputStream()
        val mb = 1024*1024
        val buffer = ByteArray(mb) // must be larger than any single document
        try {
            input.inputStream().use { source ->
                val fileSize = input.length()
                var total = 0
                var totalBytes = 0L
                var leftover = 0 // bytes of an incomplete document carried over at the start of buffer
                reading@ while (true) {
                    // A full buffer without a separator would make read() return 0 and end the upload silently
                    check(leftover < buffer.size) { "Document larger than ${buffer.size} bytes in $input" }
                    val read = source.read(buffer, leftover, buffer.size - leftover)
                    if (read <= 0) break
                    totalBytes += read
                    echo("%.2f / %.2f Mb".format(totalBytes / mb.toDouble(), fileSize / mb.toDouble()))
                    val available = read + leftover
                    var processed = 0
                    for (i in 0 until available) {
                        if (buffer[i] == separator.toByte()) {
                            bulk.writeBytes(header)
                            bulk.write(buffer, processed, i - processed + 1) // include separator at i
                            total++
                            if (total % size == 0) {
                                if (!enqueue(bulk.toByteArray())) break@reading
                                bulk.reset()
                            }
                            processed = i + 1 // skip over separator
                        }
                    }
                    leftover = available - processed
                    // shift the incomplete tail to the front of the buffer
                    buffer.copyInto(buffer, 0, processed, available)
                }
                if (!failed.get()) {
                    // Last document of a file that does not end with a separator
                    val hasContent = (0 until leftover).any { buffer[it].toInt().toChar() !in " \t\r\n" }
                    if (hasContent) {
                        bulk.writeBytes(header)
                        bulk.write(buffer, 0, leftover)
                        bulk.write(separator)
                    }
                    if (bulk.size() > 0) enqueue(bulk.toByteArray())
                }
            }
        }
        finally {
            // Always release the workers, even when reading the input failed
            terminated.set(true)
            executor.shutdown()
        }
        // Wait for in-flight batches so refresh and timing cover the whole upload
        executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS)

        refresh(mainClient)
        setRefresh(mainClient, "5s")
        val end = System.nanoTime()
        val ms = Math.round((end - start)/1e6)
        echo("Total time: $ms ms")
        if (failed.get()) {
            echo("Upload aborted: a bulk request failed")
            System.exit(1)
        }
    }
}
/**
 * `json2smile` subcommand: converts line-delimited JSON files into a single SMILE file.
 *
 * Each non-blank input line is parsed as one JSON document, re-encoded with a fresh SMILE
 * generator and written to the destination followed by a `0xff` byte as record separator.
 */
class Json2Smile : CliktCommand(name="json2smile") {
    val from by argument(help="source .json files").file(mustExist = true).multiple()
    val to by argument(help="destination .smile file").file(mustExist = false)

    /** Converts every file in [from], in order, appending all records to [to]. */
    override fun run() {
        val smile = SmileFactory()
        val jsonMapper = ObjectMapper(JsonFactory())
        val smileMapper = ObjectMapper(smile)
        // Buffered: records are small and there is one write per input line
        FileOutputStream(to).buffered().use { out ->
            val record = ByteArrayOutputStream()
            for (f in from) {
                f.forEachLine { line ->
                    // A blank line is not a document; converting it would emit a bogus record
                    if (line.isNotBlank()) {
                        record.reset()
                        smile.createGenerator(record).use { gen ->
                            smileMapper.writeTree(gen, jsonMapper.readTree(line))
                        }
                        record.write(0xff)
                        record.writeTo(out) // avoids the extra copy made by toByteArray()
                    }
                }
            }
        }
    }
}
/**
 * Program entry point: builds the root command with the `json2smile` and `upload`
 * subcommands attached and lets it parse and dispatch [args].
 */
fun main(args: Array<String>) {
    val cli = ElasticTools().subcommands(Json2Smile(), UploadCommand())
    cli.main(args)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment