Skip to content

Instantly share code, notes, and snippets.

@asad-awadia
Created June 1, 2026 01:18
Show Gist options
  • Select an option

  • Save asad-awadia/31233ddf5713b4b93d1277e46f79caad to your computer and use it in GitHub Desktop.

Select an option

Save asad-awadia/31233ddf5713b4b93d1277e46f79caad to your computer and use it in GitHub Desktop.
rocksdb-orm
import com.github.f4b6a3.ulid.UlidFactory
import org.rocksdb.*
import java.io.*
import java.nio.file.Path
import java.time.Instant
import kotlin.experimental.inv
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KProperty1
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.primaryConstructor
// ---------- core serializer ----------
/**
 * Two-way conversion between values of type [T] and raw bytes.
 *
 * Implementations in this file serve as both RocksDB key and value codecs, so
 * [deserialize] is expected to invert [serialize].
 */
interface BinarySerializer<T> {
    /** Encodes [v] into bytes. */
    fun serialize(v: T): ByteArray

    /** Decodes bytes previously produced by [serialize]. */
    fun deserialize(b: ByteArray): T
}
/**
 * Codec for one component of a value of type [T].
 *
 * [write] pulls the component out of a whole `T` and emits its bytes. [read] consumes
 * those same bytes and returns just the component. `TypedCodec` collects the [read]
 * results in field order and passes them positionally to the constructor.
 */
interface Field<T> {
    /** Emits the bytes for this component of [v] into [out]. */
    fun write(out: DataOutput, v: T)

    /** Consumes this component's bytes from [inp] and returns the decoded component. */
    fun read(inp: DataInput): Any
}
/**
 * [BinarySerializer] built from an ordered list of [fields].
 *
 * [serialize] lets each field write its part of the value, in list order. [deserialize]
 * reads the fields back in the same order and invokes [ctor] with the decoded
 * components as positional arguments. No length prefix is written and no check for
 * trailing bytes is made; the field codecs alone decide how many bytes are consumed.
 */
class TypedCodec<T : Any>(
    private val ctor: KFunction<T>,
    private val fields: List<Field<T>>
) : BinarySerializer<T> {
    override fun serialize(v: T): ByteArray {
        val sink = ByteArrayOutputStream(64)
        val writer = DataOutputStream(sink)
        for (field in fields) {
            field.write(writer, v)
        }
        return sink.toByteArray()
    }

    override fun deserialize(b: ByteArray): T {
        val reader = DataInputStream(ByteArrayInputStream(b))
        // Array's initializer runs for indices 0..size-1 in order, so fields are read sequentially.
        val ctorArgs = Array<Any>(fields.size) { index -> fields[index].read(reader) }
        return ctor.call(*ctorArgs)
    }
}
/**
 * DSL builder for a hand-ordered [TypedCodec] over [ctor].
 *
 * Each call appends one field codec. The order of calls must match the parameter order
 * of [ctor], because [TypedCodec] passes decoded values to it positionally. Every
 * encoding here sorts the same way as the value under unsigned bytewise comparison.
 * Strings are self-terminating and every other field is fixed-width, so a concatenation
 * of fields sorts like the tuple of values (ascending only).
 *
 * Each field method returns the result of appending to the builder's internal list.
 */
class TypedBuilder<T : Any>(private val ctor: KFunction<T>) {
    private val fields = mutableListOf<Field<T>>()

    /** Wraps a writer/reader pair as a [Field] and appends it. */
    private fun register(writer: (DataOutput, T) -> Unit, reader: (DataInput) -> Any): Boolean =
        fields.add(object : Field<T> {
            override fun write(out: DataOutput, v: T) = writer(out, v)
            override fun read(inp: DataInput): Any = reader(inp)
        })

    /** Makes IEEE-754 float bits sort like the float: flip all bits of negatives, only the sign bit of positives. */
    private fun sortKeyOf(bits: Int): Int = bits xor ((bits shr 31) and Int.MAX_VALUE) xor Int.MIN_VALUE

    /** Inverse of [sortKeyOf] for `Int`. */
    private fun bitsFromSortKey(key: Int): Int {
        val b = key xor Int.MIN_VALUE
        return b xor ((b shr 31) and Int.MAX_VALUE)
    }

    /** Double-precision counterpart of [sortKeyOf]. */
    private fun sortKeyOf(bits: Long): Long = bits xor ((bits shr 63) and Long.MAX_VALUE) xor Long.MIN_VALUE

    /** Inverse of [sortKeyOf] for `Long`. */
    private fun bitsFromSortKey(key: Long): Long {
        val b = key xor Long.MIN_VALUE
        return b xor ((b shr 63) and Long.MAX_VALUE)
    }

    /** UTF-8 bytes followed by a 0x00 terminator; strings containing NUL are rejected. */
    fun string(prop: KProperty1<T, String>): Boolean = register(
        { out, v ->
            val s = prop.get(v)
            require(!s.contains('\u0000')) { "key strings cannot contain NUL" }
            out.write(s.toByteArray(Charsets.UTF_8))
            out.writeByte(0)
        },
        { inp ->
            val buf = ByteArrayOutputStream()
            while (true) {
                val c = inp.readByte().toInt()
                if (c == 0) break
                buf.write(c)
            }
            String(buf.toByteArray(), Charsets.UTF_8)
        }
    )

    /** Big-endian int with the sign bit flipped, so negatives sort before positives. */
    fun int(prop: KProperty1<T, Int>): Boolean = register(
        { out, v -> out.writeInt(prop.get(v) xor Int.MIN_VALUE) },
        { inp -> inp.readInt() xor Int.MIN_VALUE }
    )

    /** Big-endian long with the sign bit flipped, so negatives sort before positives. */
    fun long(prop: KProperty1<T, Long>): Boolean = register(
        { out, v -> out.writeLong(prop.get(v) xor Long.MIN_VALUE) },
        { inp -> inp.readLong() xor Long.MIN_VALUE }
    )

    /** One byte: 0 for false, 1 for true. */
    fun boolean(prop: KProperty1<T, Boolean>): Boolean = register(
        { out, v -> out.writeByte(if (prop.get(v)) 1 else 0) },
        { inp -> inp.readByte().toInt() != 0 }
    )

    /** Raw float bits remapped by [sortKeyOf] so unsigned byte order follows numeric order. */
    fun float(prop: KProperty1<T, Float>): Boolean = register(
        { out, v -> out.writeInt(sortKeyOf(java.lang.Float.floatToRawIntBits(prop.get(v)))) },
        { inp -> java.lang.Float.intBitsToFloat(bitsFromSortKey(inp.readInt())) }
    )

    /** Raw double bits remapped by [sortKeyOf] so unsigned byte order follows numeric order. */
    fun double(prop: KProperty1<T, Double>): Boolean = register(
        { out, v -> out.writeLong(sortKeyOf(java.lang.Double.doubleToRawLongBits(prop.get(v)))) },
        { inp -> java.lang.Double.longBitsToDouble(bitsFromSortKey(inp.readLong())) }
    )

    /** Sign-flipped epoch seconds (8 bytes) followed by the nano-of-second (4 bytes, never negative). */
    fun instant(prop: KProperty1<T, Instant>): Boolean = register(
        { out, v ->
            val i = prop.get(v)
            out.writeLong(i.epochSecond xor Long.MIN_VALUE)
            out.writeInt(i.nano)
        },
        { inp ->
            val seconds = inp.readLong() xor Long.MIN_VALUE
            Instant.ofEpochSecond(seconds, inp.readInt().toLong())
        }
    )

    /**
     * Returns a codec over the fields registered so far.
     *
     * The field list is copied, so later calls on this builder do not change an
     * already-built codec.
     */
    fun build() = TypedCodec(ctor, fields.toList())
}
/**
 * Creates a [TypedCodec] for [ctor] by running [block] against a fresh [TypedBuilder].
 *
 * Declare fields in [block] in the same order as [ctor]'s parameters.
 */
fun <T : Any> orderedCodec(ctor: KFunction<T>, block: TypedBuilder<T>.() -> Unit): TypedCodec<T> {
    val builder = TypedBuilder(ctor)
    builder.block()
    return builder.build()
}
// ---------- auto codec ----------
/** Reified shorthand for `autoCodec(T::class)`. */
inline fun <reified T : Any> autoCodec(): BinarySerializer<T> = autoCodec(T::class)

/**
 * Derives an order-preserving [BinarySerializer] for [kClass] from its primary constructor.
 *
 * Each constructor parameter must be backed by a member property with the same name and
 * one of these types: String, Int, Long, Instant, Boolean, Float or Double. Fields are
 * encoded in constructor order. Each encoding sorts like the value under unsigned bytewise
 * comparison, and a property annotated with [Desc] has its order reversed. This makes the
 * output usable as a RocksDB key under the bytewise comparator.
 *
 * Unsupported property types are only reported when a value is first written or read.
 *
 * @throws IllegalArgumentException if [kClass] has no primary constructor, or if a
 *   constructor parameter has no matching member property.
 */
fun <T : Any> autoCodec(kClass: KClass<T>): BinarySerializer<T> {
    val ctor = requireNotNull(kClass.primaryConstructor) {
        "${kClass.qualifiedName} has no primary constructor"
    }
    val fields = ctor.parameters.map { param ->
        val prop = requireNotNull(kClass.memberProperties.firstOrNull { it.name == param.name }) {
            "constructor parameter '${param.name}' of ${kClass.qualifiedName} has no matching property"
        }
        sortableField(prop, desc = prop.findAnnotation<Desc>() != null)
    }
    return TypedCodec(ctor, fields)
}

/** Builds the order-preserving [Field] for one property; [desc] reverses its sort order. */
private fun <T : Any> sortableField(prop: KProperty1<T, *>, desc: Boolean): Field<T> {
    val type = prop.returnType.classifier
    return object : Field<T> {
        override fun write(out: DataOutput, v: T) {
            when (type) {
                String::class -> {
                    val s = prop.get(v) as String
                    require(!s.contains('\u0000')) { "key strings cannot contain NUL" }
                    for (b in s.toByteArray(Charsets.UTF_8)) {
                        out.writeByte(if (desc) b.inv().toInt() else b.toInt())
                    }
                    // Terminator: 0x00 ascending, 0xFF (the inverse of 0x00) descending.
                    out.writeByte(if (desc) 0xFF else 0)
                }
                Int::class -> {
                    val x = (prop.get(v) as Int) xor Int.MIN_VALUE
                    out.writeInt(if (desc) x.inv() else x)
                }
                Long::class -> {
                    val x = (prop.get(v) as Long) xor Long.MIN_VALUE
                    out.writeLong(if (desc) x.inv() else x)
                }
                Instant::class -> {
                    val i = prop.get(v) as Instant
                    val sec = i.epochSecond xor Long.MIN_VALUE
                    out.writeLong(if (desc) sec.inv() else sec)
                    // nano is never negative, so its sign bit needs no flipping.
                    out.writeInt(if (desc) i.nano.inv() else i.nano)
                }
                Boolean::class -> {
                    // Ascending: false -> 0, true -> 1 (matches the read below); swapped when descending.
                    out.writeByte(if ((prop.get(v) as Boolean) xor desc) 1 else 0)
                }
                Float::class -> {
                    val bits = java.lang.Float.floatToRawIntBits(prop.get(v) as Float)
                    // Negative floats: flip all bits; positive floats: flip only the sign bit.
                    val x = bits xor ((bits shr 31) and Int.MAX_VALUE) xor Int.MIN_VALUE
                    out.writeInt(if (desc) x.inv() else x)
                }
                Double::class -> {
                    val bits = java.lang.Double.doubleToRawLongBits(prop.get(v) as Double)
                    val x = bits xor ((bits shr 63) and Long.MAX_VALUE) xor Long.MIN_VALUE
                    out.writeLong(if (desc) x.inv() else x)
                }
                else -> error("unsupported $type")
            }
        }

        override fun read(inp: DataInput): Any = when (type) {
            String::class -> {
                val buf = ByteArrayOutputStream()
                while (true) {
                    val raw = inp.readByte()
                    val c = if (desc) raw.inv() else raw
                    if (c == 0.toByte()) break
                    buf.write(c.toInt())
                }
                String(buf.toByteArray(), Charsets.UTF_8)
            }
            Int::class -> {
                val x = inp.readInt()
                (if (desc) x.inv() else x) xor Int.MIN_VALUE
            }
            Long::class -> {
                val x = inp.readLong()
                (if (desc) x.inv() else x) xor Long.MIN_VALUE
            }
            Instant::class -> {
                val sec = inp.readLong()
                val nano = inp.readInt()
                Instant.ofEpochSecond(
                    (if (desc) sec.inv() else sec) xor Long.MIN_VALUE,
                    (if (desc) nano.inv() else nano).toLong()
                )
            }
            Boolean::class -> (inp.readByte().toInt() != 0) xor desc
            Float::class -> {
                val x = inp.readInt()
                val b2 = (if (desc) x.inv() else x) xor Int.MIN_VALUE
                java.lang.Float.intBitsToFloat(b2 xor ((b2 shr 31) and Int.MAX_VALUE))
            }
            Double::class -> {
                val x = inp.readLong()
                val b2 = (if (desc) x.inv() else x) xor Long.MIN_VALUE
                java.lang.Double.longBitsToDouble(b2 xor ((b2 shr 63) and Long.MAX_VALUE))
            }
            else -> error("unsupported")
        }
    }
}
// ---------- RocksDB with transactions on table ----------
/**
 * Thin handle around a RocksDB [Transaction], created by `TreeTable.begin`.
 *
 * [close] releases the wrapped transaction object.
 */
class Txn internal constructor(internal val tx: Transaction) : AutoCloseable {
    /** Commits the wrapped transaction. */
    fun commit() {
        tx.commit()
    }

    /** Rolls back the wrapped transaction. */
    fun rollback() {
        tx.rollback()
    }

    override fun close() {
        tx.close()
    }
}
/**
 * A RocksDB [TransactionDB] stored at `path`, whose column families back [TreeTable]s.
 *
 * Every column family already on disk is opened eagerly. New ones are created on demand
 * by [openTree]. All column families use the bytewise comparator.
 *
 * RocksJava option objects own native memory. They are kept for the lifetime of the
 * database and released in [close], or immediately if opening the database fails.
 */
class RocksDbDatabase(path: Path) : Closeable {
    internal val db: TransactionDB
    private val handles = mutableMapOf<String, ColumnFamilyHandle>()
    private val dbOpts: DBOptions
    private val txnDbOpts: TransactionDBOptions
    private val cfOpts: ColumnFamilyOptions

    init {
        // Load the native library before any RocksJava object is constructed.
        RocksDB.loadLibrary()
        dbOpts = DBOptions().setCreateIfMissing(true).setCreateMissingColumnFamilies(true)
        txnDbOpts = TransactionDBOptions()
        cfOpts = ColumnFamilyOptions().setComparator(BuiltinComparator.BYTEWISE_COMPARATOR)
        db = try {
            openWithExistingFamilies(path.toString())
        } catch (e: Throwable) {
            closeOptions()
            throw e
        }
    }

    /** Opens the DB with every column family found on disk (or just `default`) and records their handles. */
    private fun openWithExistingFamilies(pathStr: String): TransactionDB {
        val existing: List<ByteArray> = try {
            Options().use { opts -> RocksDB.listColumnFamilies(opts, pathStr) }
        } catch (ignored: RocksDBException) {
            // Typically raised when no database exists at the path yet: start with only `default`.
            emptyList()
        }
        val descriptors = existing
            .ifEmpty { listOf(RocksDB.DEFAULT_COLUMN_FAMILY) }
            .map { name -> ColumnFamilyDescriptor(name, cfOpts) }
        val hs = mutableListOf<ColumnFamilyHandle>()
        val opened = TransactionDB.open(dbOpts, txnDbOpts, pathStr, descriptors, hs)
        descriptors.forEachIndexed { i, d -> handles[String(d.name)] = hs[i] }
        return opened
    }

    /**
     * Returns a [TreeTable] over the column family [name], creating the family if needed.
     *
     * Column family handles are cached per name and closed by [close].
     */
    fun <K : Any, V : Any> openTree(name: String, ks: BinarySerializer<K>, vs: BinarySerializer<V>): TreeTable<K, V> {
        val h = handles.getOrPut(name) {
            // Reuse the shared options instead of allocating (and leaking) a new native object per call.
            db.createColumnFamily(ColumnFamilyDescriptor(name.toByteArray(), cfOpts))
        }
        return TreeTable(db, h, ks, vs)
    }

    /** Closes all column family handles, then the database, then the native option objects. */
    override fun close() {
        handles.values.forEach { it.close() }
        db.close()
        closeOptions()
    }

    /** Frees the native option objects this instance owns. */
    private fun closeOptions() {
        cfOpts.close()
        txnDbOpts.close()
        dbOpts.close()
    }
}
/**
 * A typed view over one RocksDB column family. Keys are encoded by [ks] and values by [vs].
 *
 * The range queries ([first], [last], [ceiling], [higher], [floor], [lower], [scan],
 * [scanPrefix]) walk the column family in encoded-key order. [scan]'s upper bound uses an
 * unsigned bytewise comparison, so [ks] should be an order-preserving codec, and the column
 * family should use the bytewise comparator. Point operations accept an optional [Txn].
 * The range queries always read the database directly, not through a transaction.
 */
class TreeTable<K : Any, V : Any>(
    private val db: TransactionDB,
    private val h: ColumnFamilyHandle,
    private val ks: BinarySerializer<K>,
    private val vs: BinarySerializer<V>
) {
    private fun kb(k: K) = ks.serialize(k)

    private fun decode(iter: RocksIterator) = ks.deserialize(iter.key()) to vs.deserialize(iter.value())

    /** Decodes the entry under [iter], or returns null when it is not positioned on one. */
    private fun decodeIfValid(iter: RocksIterator): Pair<K, V>? = if (iter.isValid) decode(iter) else null

    /**
     * Starts a transaction. The caller must commit or roll back, then close it.
     *
     * The native transaction keeps its own copy of the write options, so the temporary
     * [WriteOptions] is freed right away rather than leaked on every call.
     */
    fun begin(): Txn = WriteOptions().use { wo -> Txn(db.beginTransaction(wo)) }

    /**
     * Runs [block] in a new transaction. Commits if [block] returns, rolls back if it throws,
     * and always closes the transaction.
     *
     * If the rollback itself fails, that failure is attached as suppressed to the original exception.
     */
    fun <R> transaction(block: (Txn) -> R): R = begin().use { tx ->
        try {
            block(tx).also { tx.commit() }
        } catch (e: Throwable) {
            try {
                tx.rollback()
            } catch (rollbackFailure: Throwable) {
                e.addSuppressed(rollbackFailure)
            }
            throw e
        }
    }

    /** Returns the value for [k], read inside [txn] when given, or null if absent. */
    fun get(k: K, txn: Txn? = null): V? {
        val key = kb(k)
        val bytes = if (txn == null) {
            db.get(h, key)
        } else {
            ReadOptions().use { ro -> txn.tx.get(ro, h, key) }
        }
        return bytes?.let(vs::deserialize)
    }

    /** Reads [k] inside [txn] and takes an exclusive lock on it; returns null if absent. */
    fun getForUpdate(k: K, txn: Txn): V? {
        val bytes = ReadOptions().use { ro -> txn.tx.getForUpdate(ro, h, kb(k), true) }
        return bytes?.let(vs::deserialize)
    }

    /** Writes [v] under [k], inside [txn] when given. */
    fun put(k: K, v: V, txn: Txn? = null) {
        val key = kb(k)
        val value = vs.serialize(v)
        if (txn == null) db.put(h, key, value) else txn.tx.put(h, key, value)
    }

    /** Removes [k], inside [txn] when given. */
    fun delete(k: K, txn: Txn? = null) {
        val key = kb(k)
        if (txn == null) db.delete(h, key) else txn.tx.delete(h, key)
    }

    /** True if [get] finds a value for [k]. */
    fun contains(k: K, txn: Txn? = null): Boolean = get(k, txn) != null

    /** Entry with the smallest key, or null if the table is empty. */
    fun first(): Pair<K, V>? = db.newIterator(h).use { iter ->
        iter.seekToFirst()
        decodeIfValid(iter)
    }

    /** Entry with the largest key, or null if the table is empty. */
    fun last(): Pair<K, V>? = db.newIterator(h).use { iter ->
        iter.seekToLast()
        decodeIfValid(iter)
    }

    /** Entry with the smallest key >= [k], or null. */
    fun ceiling(k: K): Pair<K, V>? = db.newIterator(h).use { iter ->
        iter.seek(kb(k))
        decodeIfValid(iter)
    }

    /** Entry with the smallest key > [k], or null. */
    fun higher(k: K): Pair<K, V>? = db.newIterator(h).use { iter ->
        val key = kb(k)
        iter.seek(key)
        // seek lands on the first key >= k; step past it when it is k itself.
        if (iter.isValid && iter.key().contentEquals(key)) iter.next()
        decodeIfValid(iter)
    }

    /** Entry with the largest key <= [k], or null. */
    fun floor(k: K): Pair<K, V>? = db.newIterator(h).use { iter ->
        val key = kb(k)
        iter.seek(key)
        when {
            // No key >= k: every key is below k, so the answer is the last one.
            !iter.isValid -> iter.seekToLast()
            // Landed on the first key > k: the answer is the one before it.
            !iter.key().contentEquals(key) -> iter.prev()
        }
        decodeIfValid(iter)
    }

    /** Entry with the largest key < [k], or null. */
    fun lower(k: K): Pair<K, V>? = db.newIterator(h).use { iter ->
        iter.seek(kb(k))
        // seek lands on the first key >= k, and the answer is the entry before it. When no such
        // key exists the iterator is invalid (prev() must not be called on it), and every key
        // is below k, so the answer is the last entry.
        if (iter.isValid) iter.prev() else iter.seekToLast()
        decodeIfValid(iter)
    }

    /**
     * Lazily yields up to [limit] entries in key order between [from] and [to].
     *
     * A null bound is open. [fromInclusive] and [toInclusive] control whether entries equal
     * to the bounds are included.
     *
     * The native iterator is closed once the sequence is exhausted, [to] is passed, or
     * [limit] entries have been yielded. NOTE: if the caller stops consuming the sequence
     * early (e.g. `take`, `first`, `break`), the iterator is never closed and its native
     * memory leaks. Prefer [limit] for early termination.
     */
    fun scan(
        from: K? = null,
        fromInclusive: Boolean = true,
        to: K? = null,
        toInclusive: Boolean = false,
        limit: Int = Int.MAX_VALUE
    ): Sequence<Pair<K, V>> = sequence {
        db.newIterator(h).use { iter ->
            val fromB = from?.let(::kb)
            val toB = to?.let(::kb)
            if (fromB == null) iter.seekToFirst() else iter.seek(fromB)
            if (fromB != null && !fromInclusive && iter.isValid && iter.key().contentEquals(fromB)) iter.next()
            var n = 0
            while (iter.isValid && n < limit) {
                val key = iter.key()
                if (toB != null) {
                    val cmp = key.compareTo(toB)
                    if (cmp > 0 || (cmp == 0 && !toInclusive)) break
                }
                yield(decode(iter))
                iter.next()
                n++
            }
        }
    }

    /**
     * Lazily yields up to [limit] entries whose encoded key starts with [prefix].
     *
     * Like [scan], the native iterator leaks if the sequence is abandoned before it ends.
     */
    fun scanPrefix(prefix: ByteArray, limit: Int = Int.MAX_VALUE): Sequence<Pair<K, V>> = sequence {
        db.newIterator(h).use { iter ->
            iter.seek(prefix)
            var n = 0
            while (iter.isValid && n < limit && iter.key().startsWith(prefix)) {
                yield(decode(iter))
                iter.next()
                n++
            }
        }
    }

    /** Unsigned lexicographic comparison; a proper prefix sorts first. */
    private fun ByteArray.compareTo(o: ByteArray): Int {
        val l = minOf(size, o.size)
        for (i in 0 until l) {
            val d = (this[i].toInt() and 0xff) - (o[i].toInt() and 0xff)
            if (d != 0) return d
        }
        return size - o.size
    }

    private fun ByteArray.startsWith(p: ByteArray): Boolean {
        if (size < p.size) return false
        for (i in p.indices) if (this[i] != p[i]) return false
        return true
    }
}
// ---------- example usage ----------
/** Example stored value: a bug report whose [id] defaults to a newly generated ULID string. */
data class BugReport(
    val title: String,
    val priority: Int,
    val created: Instant,
    val id: String = UlidFactory.newInstance().create().toString()
)

/** Variant of [BugReport] without a `title` field (not used by the example `main`). */
data class BugReport2(
    val priority: Int,
    val created: Instant,
    val id: String = UlidFactory.newInstance().create().toString()
)

/**
 * Example index key. When encoded with `autoCodec`, it orders by [priority] ascending and
 * then by [created] descending, because of [Desc].
 */
data class BugReportByPriorityIndex(
    val priority: Int,
    @Desc val created: Instant
)
/**
 * [BinarySerializer] that prefixes every payload with a one-byte schema version.
 *
 * [serialize] always uses the codec registered for [current]. [deserialize] reads the
 * leading byte as an unsigned version and decodes the rest with that version's codec.
 * Payloads from any other version are then passed through [migrate].
 *
 * @param current version written by [serialize]; must be in 0..255 to fit in one byte.
 * @param codecs codec for every version that may appear in stored data, including [current].
 * @param migrate converts a value decoded by a non-current codec into a current [V].
 * @throws IllegalArgumentException from the constructor if [current] does not fit in a byte.
 */
class Versioned<V : Any>(
    private val current: Int,
    private val codecs: Map<Int, BinarySerializer<V>>,
    private val migrate: (Any) -> V
) : BinarySerializer<V> {
    init {
        require(current in 0..255) { "version $current does not fit in one unsigned byte" }
    }

    /** Looks up the codec for [version], failing with a descriptive message when none is registered. */
    private fun codecFor(version: Int): BinarySerializer<V> =
        requireNotNull(codecs[version]) { "no codec registered for version $version" }

    override fun serialize(v: V): ByteArray {
        val body = codecFor(current).serialize(v)
        return byteArrayOf(current.toByte()) + body
    }

    /** @throws IllegalArgumentException on an empty payload or an unknown version. */
    override fun deserialize(b: ByteArray): V {
        require(b.isNotEmpty()) { "cannot decode an empty payload: missing version byte" }
        // Mask to read the byte as unsigned, so versions 128..255 compare equal to `current`.
        val ver = b[0].toInt() and 0xFF
        val decoded = codecFor(ver).deserialize(b.copyOfRange(1, b.size))
        return if (ver == current) decoded else migrate(decoded)
    }
}
/**
 * Marks a key property whose sort order `autoCodec` should reverse (largest value first).
 * It is retained at runtime so that it can be found via reflection.
 */
@Target(AnnotationTarget.PROPERTY) @Retention(AnnotationRetention.RUNTIME) annotation class Desc
/**
 * Example: stores two bug reports under a (priority, created-descending) index in the
 * `bugs_by_priority` column family, then prints the first and last entries.
 */
fun main() {
    // `use` closes the database and its native handles even if an operation throws.
    RocksDbDatabase(Path.of("./data")).use { db ->
        val bugPK = autoCodec<BugReportByPriorityIndex>()
        val v1 = autoCodec<BugReport>()
        val bugsDB: TreeTable<BugReportByPriorityIndex, BugReport> = db.openTree("bugs_by_priority", bugPK, v1)
        val key = BugReportByPriorityIndex(1, Instant.now())
        bugsDB.put(key, BugReport("jvm getting npe", 1, Instant.now()))
        bugsDB.put(BugReportByPriorityIndex(1, Instant.now()), BugReport("highest priorty", 1, Instant.now()))
        println(bugsDB.first())
        println(bugsDB.last())
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment