Created
June 1, 2026 01:18
-
-
Save asad-awadia/31233ddf5713b4b93d1277e46f79caad to your computer and use it in GitHub Desktop.
rocksdb-orm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import com.github.f4b6a3.ulid.UlidFactory | |
| import org.rocksdb.* | |
| import java.io.* | |
| import java.nio.file.Path | |
| import java.time.Instant | |
| import kotlin.experimental.inv | |
| import kotlin.reflect.KClass | |
| import kotlin.reflect.KFunction | |
| import kotlin.reflect.KProperty1 | |
| import kotlin.reflect.full.findAnnotation | |
| import kotlin.reflect.full.memberProperties | |
| import kotlin.reflect.full.primaryConstructor | |
| // ---------- core serializer ---------- | |
/**
 * Converts values of type [T] to and from a raw byte representation.
 *
 * Implementations in this file are used for both keys and values stored in RocksDB.
 */
interface BinarySerializer<T> {

    /** Encodes [v] into a freshly produced byte array. */
    fun serialize(v: T): ByteArray

    /** Rebuilds a value from bytes previously produced by [serialize]. */
    fun deserialize(b: ByteArray): T
}
/**
 * Codec for a single component of an object of type [T].
 *
 * A field knows how to pull its own value out of the owning object and append it to a
 * stream, and how to read that value back from a stream.
 */
interface Field<T> {

    /** Extracts this field's value from [v] and appends its encoding to [out]. */
    fun write(out: DataOutput, v: T)

    /** Consumes this field's encoding from [inp] and returns the decoded value. */
    fun read(inp: DataInput): Any
}
/**
 * A [BinarySerializer] assembled from an ordered list of [Field]s.
 *
 * Serialization writes every field in list order into one buffer; deserialization reads
 * the fields back in the same order and passes the decoded values positionally to [ctor].
 */
class TypedCodec<T : Any>(
    private val ctor: KFunction<T>,
    private val fields: List<Field<T>>
) : BinarySerializer<T> {

    /** Concatenates the encodings of all fields of [v], in declaration order. */
    override fun serialize(v: T): ByteArray {
        val buffer = ByteArrayOutputStream(64)
        val sink = DataOutputStream(buffer)
        for (field in fields) {
            field.write(sink, v)
        }
        return buffer.toByteArray()
    }

    /** Decodes every field from [b] in order and invokes [ctor] with the results. */
    override fun deserialize(b: ByteArray): T {
        val source = DataInputStream(ByteArrayInputStream(b))
        val decoded = ArrayList<Any>(fields.size)
        for (field in fields) {
            decoded.add(field.read(source))
        }
        val ctorArgs = decoded.toTypedArray()
        return ctor.call(*ctorArgs)
    }
}
/**
 * Fluent builder for a [TypedCodec] whose encoded bytes sort (bytewise, unsigned) in the
 * same order as the natural ascending order of the registered properties.
 *
 * Fields are encoded in the order the builder methods are called, and that order must match
 * the parameter order of [ctor], because decoded values are passed to it positionally.
 * Every registration method returns the result of adding the field to the internal list.
 */
class TypedBuilder<T : Any>(private val ctor: KFunction<T>) {
    private val fields = mutableListOf<Field<T>>()

    /** Wraps a writer/reader pair into a [Field] and appends it to the field list. */
    private fun register(
        writer: (DataOutput, T) -> Unit,
        reader: (DataInput) -> Any
    ): Boolean = fields.add(object : Field<T> {
        override fun write(out: DataOutput, v: T) = writer(out, v)
        override fun read(inp: DataInput): Any = reader(inp)
    })

    /**
     * Registers a String property, encoded as its UTF-8 bytes followed by a single 0 byte.
     *
     * @throws IllegalArgumentException at write time if the string contains a NUL character,
     * since NUL is reserved as the terminator.
     */
    fun string(prop: KProperty1<T, String>) = register(
        writer = { out, v ->
            val text = prop.get(v)
            require(!text.contains('\u0000')) { "key strings cannot contain NUL" }
            out.write(text.toByteArray(Charsets.UTF_8))
            out.writeByte(0)
        },
        reader = { inp ->
            val collected = ByteArrayOutputStream()
            // Read up to (and consume) the 0 terminator.
            while (true) {
                val next = inp.readByte().toInt()
                if (next == 0) break
                collected.write(next)
            }
            String(collected.toByteArray(), Charsets.UTF_8)
        }
    )

    /** Registers an Int property; the sign bit is flipped so negatives sort before positives. */
    fun int(prop: KProperty1<T, Int>) = register(
        writer = { out, v -> out.writeInt(prop.get(v) xor Int.MIN_VALUE) },
        reader = { inp -> inp.readInt() xor Int.MIN_VALUE }
    )

    /** Registers a Long property; the sign bit is flipped so negatives sort before positives. */
    fun long(prop: KProperty1<T, Long>) = register(
        writer = { out, v -> out.writeLong(prop.get(v) xor Long.MIN_VALUE) },
        reader = { inp -> inp.readLong() xor Long.MIN_VALUE }
    )

    /** Registers a Boolean property, encoded as one byte: 0 for false, 1 for true. */
    fun boolean(prop: KProperty1<T, Boolean>) = register(
        writer = { out, v -> out.writeByte(if (prop.get(v)) 1 else 0) },
        reader = { inp -> inp.readByte().toInt() != 0 }
    )

    /**
     * Registers a Float property using its raw IEEE-754 bits: non-negative values get the
     * sign bit set, negative values get every bit flipped.
     */
    fun float(prop: KProperty1<T, Float>) = register(
        writer = { out, v ->
            val bits = java.lang.Float.floatToRawIntBits(prop.get(v))
            out.writeInt(bits xor ((bits shr 31) and Int.MAX_VALUE) xor Int.MIN_VALUE)
        },
        reader = { inp ->
            val restored = inp.readInt() xor Int.MIN_VALUE
            java.lang.Float.intBitsToFloat(restored xor ((restored shr 31) and Int.MAX_VALUE))
        }
    )

    /** Registers a Double property; same bit transformation as [float], on 64 bits. */
    fun double(prop: KProperty1<T, Double>) = register(
        writer = { out, v ->
            val bits = java.lang.Double.doubleToRawLongBits(prop.get(v))
            out.writeLong(bits xor ((bits shr 63) and Long.MAX_VALUE) xor Long.MIN_VALUE)
        },
        reader = { inp ->
            val restored = inp.readLong() xor Long.MIN_VALUE
            java.lang.Double.longBitsToDouble(restored xor ((restored shr 63) and Long.MAX_VALUE))
        }
    )

    /**
     * Registers an Instant property as sign-flipped epoch seconds (8 bytes) followed by the
     * nanosecond-of-second (4 bytes).
     */
    fun instant(prop: KProperty1<T, Instant>) = register(
        writer = { out, v ->
            val moment = prop.get(v)
            out.writeLong(moment.epochSecond xor Long.MIN_VALUE)
            out.writeInt(moment.nano)
        },
        reader = { inp ->
            // Arguments are evaluated left to right: seconds are read before nanos.
            Instant.ofEpochSecond(inp.readLong() xor Long.MIN_VALUE, inp.readInt().toLong())
        }
    )

    /**
     * Creates the codec from the fields registered so far.
     *
     * The field list is copied so that further calls on this builder cannot alter a codec
     * that has already been built.
     */
    fun build() = TypedCodec(ctor, fields.toList())
}
/**
 * Convenience entry point for [TypedBuilder]: creates a builder for [ctor], lets [block]
 * register the fields, and returns the resulting codec.
 */
fun <T : Any> orderedCodec(ctor: KFunction<T>, block: TypedBuilder<T>.() -> Unit): TypedCodec<T> {
    val builder = TypedBuilder(ctor)
    builder.block()
    return builder.build()
}
| // ---------- auto codec ---------- | |
/** Reified shorthand for `autoCodec(T::class)`. */
inline fun <reified T : Any> autoCodec(): BinarySerializer<T> {
    return autoCodec(T::class)
}
/**
 * Derives an order-preserving codec for [kClass] by reflecting over its primary constructor.
 *
 * One field is produced per constructor parameter, in parameter order, using the member
 * property with the same name. Supported property types are String, Int, Long, Instant,
 * Boolean, Float and Double; any other type fails with "unsupported" the first time the
 * field is written or read. A property annotated with [Desc] is encoded with its bits
 * inverted so that it sorts in descending order under a bytewise comparator.
 *
 * @throws IllegalArgumentException if [kClass] has no primary constructor.
 * @throws IllegalStateException if a constructor parameter has no same-named property.
 */
fun <T : Any> autoCodec(kClass: KClass<T>): BinarySerializer<T> {
    val ctor = requireNotNull(kClass.primaryConstructor) {
        "${kClass.simpleName} has no primary constructor"
    }
    val fields = ctor.parameters.map { p ->
        val prop = kClass.memberProperties.firstOrNull { it.name == p.name }
            ?: error("constructor parameter '${p.name}' of ${kClass.simpleName} has no matching property")
        val desc = prop.findAnnotation<Desc>() != null
        val type = prop.returnType.classifier
        object : Field<T> {
            override fun write(out: DataOutput, v: T) {
                when (type) {
                    String::class -> {
                        val s = prop.get(v) as String
                        require(!s.contains('\u0000')) { "key strings cannot contain NUL" }
                        val b = s.toByteArray(Charsets.UTF_8)
                        for (bb in b) out.writeByte(if (desc) bb.inv().toInt() else bb.toInt())
                        // Terminator: 0x00 ascending, its inverse 0xFF descending.
                        out.writeByte(if (desc) 0xFF else 0)
                    }
                    Int::class -> {
                        // Flip the sign bit so negatives sort before positives.
                        var x = (prop.get(v) as Int) xor Int.MIN_VALUE
                        if (desc) x = x.inv()
                        out.writeInt(x)
                    }
                    Long::class -> {
                        var x = (prop.get(v) as Long) xor Long.MIN_VALUE
                        if (desc) x = x.inv()
                        out.writeLong(x)
                    }
                    Instant::class -> {
                        val i = prop.get(v) as Instant
                        var sec = i.epochSecond xor Long.MIN_VALUE
                        var nano = i.nano
                        if (desc) { sec = sec.inv(); nano = nano.inv() }
                        out.writeLong(sec); out.writeInt(nano)
                    }
                    Boolean::class -> {
                        val b = prop.get(v) as Boolean
                        // Ascending: false=0, true=1. Descending flips the byte. The reader
                        // below decodes "non-zero = true" and negates when desc, so the
                        // written byte must be 1 exactly when (desc xor b) holds.
                        out.writeByte(if (desc xor b) 1 else 0)
                    }
                    Float::class -> {
                        val bits = java.lang.Float.floatToRawIntBits(prop.get(v) as Float)
                        // Non-negative: set the sign bit. Negative: flip every bit.
                        var x = bits xor ((bits shr 31) and Int.MAX_VALUE) xor Int.MIN_VALUE
                        if (desc) x = x.inv()
                        out.writeInt(x)
                    }
                    Double::class -> {
                        val bits = java.lang.Double.doubleToRawLongBits(prop.get(v) as Double)
                        var x = bits xor ((bits shr 63) and Long.MAX_VALUE) xor Long.MIN_VALUE
                        if (desc) x = x.inv()
                        out.writeLong(x)
                    }
                    else -> error("unsupported $type")
                }
            }

            override fun read(inp: DataInput): Any = when (type) {
                String::class -> {
                    val buf = ByteArrayOutputStream()
                    while (true) {
                        var c = inp.readByte()
                        if (desc) c = c.inv()
                        // After undoing the inversion the terminator is always 0.
                        if (c == 0.toByte()) break
                        buf.write(c.toInt())
                    }
                    String(buf.toByteArray(), Charsets.UTF_8)
                }
                Int::class -> {
                    var x = inp.readInt()
                    if (desc) x = x.inv()
                    x xor Int.MIN_VALUE
                }
                Long::class -> {
                    var x = inp.readLong()
                    if (desc) x = x.inv()
                    x xor Long.MIN_VALUE
                }
                Instant::class -> {
                    var sec = inp.readLong()
                    var nano = inp.readInt()
                    if (desc) { sec = sec.inv(); nano = nano.inv() }
                    Instant.ofEpochSecond(sec xor Long.MIN_VALUE, nano.toLong())
                }
                Boolean::class -> {
                    val v = inp.readByte().toInt() != 0
                    if (desc) !v else v
                }
                Float::class -> {
                    var x = inp.readInt()
                    if (desc) x = x.inv()
                    val b2 = x xor Int.MIN_VALUE
                    java.lang.Float.intBitsToFloat(b2 xor ((b2 shr 31) and Int.MAX_VALUE))
                }
                Double::class -> {
                    var x = inp.readLong()
                    if (desc) x = x.inv()
                    val b2 = x xor Long.MIN_VALUE
                    java.lang.Double.longBitsToDouble(b2 xor ((b2 shr 63) and Long.MAX_VALUE))
                }
                else -> error("unsupported")
            }
        }
    }
    return TypedCodec(ctor, fields)
}
| // ---------- RocksDB with transactions on table ---------- | |
/**
 * Handle for a RocksDB [Transaction]. Instances are created inside this module only;
 * callers commit or roll back and then close the handle (it is [AutoCloseable]).
 */
class Txn internal constructor(internal val tx: Transaction) : AutoCloseable {

    /** Commits the wrapped transaction. */
    fun commit() {
        tx.commit()
    }

    /** Rolls back the wrapped transaction. */
    fun rollback() {
        tx.rollback()
    }

    /** Closes the wrapped transaction object. */
    override fun close() {
        tx.close()
    }
}
/**
 * Owns a RocksDB [TransactionDB] at [path] together with its column family handles.
 *
 * On construction the database is created if missing and every column family already on
 * disk is opened with the bytewise comparator. [openTree] exposes a column family as a
 * typed [TreeTable], creating it on first use.
 *
 * All native option objects created by this class are tracked and released in [close], or
 * immediately if opening the database fails.
 */
class RocksDbDatabase(path: Path) : Closeable {
    internal val db: TransactionDB
    private val handles = mutableMapOf<String, ColumnFamilyHandle>()

    // Native option objects that must stay alive while the database is open.
    private val nativeResources = mutableListOf<AutoCloseable>()

    init {
        RocksDB.loadLibrary()
        db = try {
            openDatabase(path.toString())
        } catch (e: Throwable) {
            // Do not leak the option objects when the open fails part-way.
            releaseNativeResources()
            throw e
        }
    }

    /** Remembers [resource] so that it is closed together with the database. */
    private fun <R : AutoCloseable> track(resource: R): R = resource.also { nativeResources += it }

    /** Closes tracked option objects, most recently created first. */
    private fun releaseNativeResources() {
        nativeResources.asReversed().forEach { it.close() }
        nativeResources.clear()
    }

    /** Opens the database at [pathStr] and records a handle for every column family. */
    private fun openDatabase(pathStr: String): TransactionDB {
        val dbOpts = track(DBOptions()).setCreateIfMissing(true).setCreateMissingColumnFamilies(true)
        val cfOpts = track(ColumnFamilyOptions()).setComparator(BuiltinComparator.BYTEWISE_COMPARATOR)
        val txDbOpts = track(TransactionDBOptions())
        // A failure to list column families is treated as "no families yet".
        val existing: List<ByteArray> = try {
            Options().use { RocksDB.listColumnFamilies(it, pathStr) }
        } catch (_: RocksDBException) {
            emptyList()
        }
        val descriptors =
            if (existing.isEmpty()) listOf(ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY, cfOpts))
            else existing.map { name -> ColumnFamilyDescriptor(name, cfOpts) }
        val hs = mutableListOf<ColumnFamilyHandle>()
        val opened = TransactionDB.open(dbOpts, txDbOpts, pathStr, descriptors, hs)
        // Handles are returned in the same order as the descriptors.
        descriptors.forEachIndexed { i, d -> handles[String(d.name)] = hs[i] }
        return opened
    }

    /**
     * Returns a [TreeTable] over the column family called [name], creating the column
     * family (bytewise comparator) if this instance has no handle for it yet.
     *
     * @param ks serializer for keys; its byte output determines the iteration order.
     * @param vs serializer for values.
     */
    fun <K : Any, V : Any> openTree(name: String, ks: BinarySerializer<K>, vs: BinarySerializer<V>): TreeTable<K, V> {
        val h = handles.getOrPut(name) {
            val cfo = track(ColumnFamilyOptions()).setComparator(BuiltinComparator.BYTEWISE_COMPARATOR)
            db.createColumnFamily(ColumnFamilyDescriptor(name.toByteArray(), cfo))
        }
        return TreeTable(db, h, ks, vs)
    }

    /** Closes the column family handles, then the database, then the option objects. */
    override fun close() {
        handles.values.forEach { it.close() }
        handles.clear()
        db.close()
        releaseNativeResources()
    }
}
/**
 * Typed, sorted key/value view over one RocksDB column family.
 *
 * Keys and values pass through [ks] and [vs]; ordering is the bytewise order of the
 * serialized keys. Point operations optionally take a [Txn]; the navigation and scan
 * operations always read through a plain iterator on the database.
 */
class TreeTable<K : Any, V : Any>(
    private val db: TransactionDB,
    private val h: ColumnFamilyHandle,
    private val ks: BinarySerializer<K>,
    private val vs: BinarySerializer<V>
) {
    private fun kb(k: K) = ks.serialize(k)

    private fun decode(iter: RocksIterator) = ks.deserialize(iter.key()) to vs.deserialize(iter.value())

    /** Decodes the entry the iterator is positioned on, or null if it is not on an entry. */
    private fun RocksIterator.currentOrNull(): Pair<K, V>? = if (isValid) decode(this) else null

    /**
     * Starts a new transaction. The caller must close the returned [Txn].
     *
     * The temporary WriteOptions is closed right after the call.
     * NOTE(review): this assumes RocksDB copies the write options into the transaction — confirm.
     */
    fun begin(): Txn = WriteOptions().use { Txn(db.beginTransaction(it)) }

    /**
     * Runs [block] in a new transaction: commits if it returns, rolls back and rethrows if
     * it (or the commit) throws, and always closes the transaction.
     */
    fun <R> transaction(block: (Txn) -> R): R {
        val tx = begin()
        try {
            val result = block(tx)
            tx.commit()
            return result
        } catch (e: Throwable) {
            // Keep the original failure primary even if the rollback itself fails.
            try {
                tx.rollback()
            } catch (rollbackFailure: Throwable) {
                e.addSuppressed(rollbackFailure)
            }
            throw e
        } finally {
            tx.close()
        }
    }

    /** Returns the value stored under [k], or null if absent; reads through [txn] when given. */
    fun get(k: K, txn: Txn? = null): V? {
        val key = kb(k)
        val raw = if (txn == null) db.get(h, key) else ReadOptions().use { txn.tx.get(it, h, key) }
        return raw?.let(vs::deserialize)
    }

    /** Like [get], but via the transaction's `getForUpdate` with the exclusive flag set. */
    fun getForUpdate(k: K, txn: Txn): V? {
        val raw = ReadOptions().use { txn.tx.getForUpdate(it, h, kb(k), true) }
        return raw?.let(vs::deserialize)
    }

    /** Stores [v] under [k], through [txn] when given, otherwise directly on the database. */
    fun put(k: K, v: V, txn: Txn? = null) {
        val key = kb(k)
        val value = vs.serialize(v)
        if (txn == null) db.put(h, key, value) else txn.tx.put(h, key, value)
    }

    /** Removes [k], through [txn] when given, otherwise directly on the database. */
    fun delete(k: K, txn: Txn? = null) {
        val key = kb(k)
        if (txn == null) db.delete(h, key) else txn.tx.delete(h, key)
    }

    /** True when [get] finds a value for [k]. */
    fun contains(k: K, txn: Txn? = null) = get(k, txn) != null

    /** Smallest entry, or null if the table is empty. */
    fun first(): Pair<K, V>? = db.newIterator(h).use { iter ->
        iter.seekToFirst()
        iter.currentOrNull()
    }

    /** Largest entry, or null if the table is empty. */
    fun last(): Pair<K, V>? = db.newIterator(h).use { iter ->
        iter.seekToLast()
        iter.currentOrNull()
    }

    /** Smallest entry whose key is >= [k], or null. */
    fun ceiling(k: K): Pair<K, V>? = db.newIterator(h).use { iter ->
        iter.seek(kb(k))
        iter.currentOrNull()
    }

    /** Smallest entry whose key is strictly > [k], or null. */
    fun higher(k: K): Pair<K, V>? = db.newIterator(h).use { iter ->
        val key = kb(k)
        iter.seek(key)
        if (iter.isValid && iter.key().contentEquals(key)) iter.next()
        iter.currentOrNull()
    }

    /** Largest entry whose key is <= [k], or null. */
    fun floor(k: K): Pair<K, V>? = db.newIterator(h).use { iter ->
        val key = kb(k)
        iter.seek(key)
        when {
            // Nothing >= k: every key is smaller, so the answer is the last entry.
            !iter.isValid -> iter.seekToLast()
            // Landed on a key > k: step back to the previous entry.
            !iter.key().contentEquals(key) -> iter.prev()
        }
        iter.currentOrNull()
    }

    /** Largest entry whose key is strictly < [k], or null. */
    fun lower(k: K): Pair<K, V>? = db.newIterator(h).use { iter ->
        iter.seek(kb(k))
        // seek() lands on the first key >= k, so the entry before it is the answer. When no
        // such key exists the iterator is invalid and prev() must not be called; every key is
        // then < k and the last entry is the answer.
        if (iter.isValid) iter.prev() else iter.seekToLast()
        iter.currentOrNull()
    }

    /**
     * Lazily iterates entries in ascending key order.
     *
     * The underlying iterator is opened when the sequence is first consumed and closed when
     * iteration runs to completion.
     * NOTE(review): a sequence that is abandoned part-way appears to leave the iterator
     * open — confirm callers always drain it.
     *
     * @param from lower bound, or null to start at the first entry.
     * @param fromInclusive whether an entry equal to [from] is included.
     * @param to upper bound, or null for no upper bound.
     * @param toInclusive whether an entry equal to [to] is included.
     * @param limit maximum number of entries to yield.
     */
    fun scan(
        from: K? = null,
        fromInclusive: Boolean = true,
        to: K? = null,
        toInclusive: Boolean = false,
        limit: Int = Int.MAX_VALUE
    ): Sequence<Pair<K, V>> = sequence {
        db.newIterator(h).use { iter ->
            val fromB = from?.let(::kb)
            val toB = to?.let(::kb)
            if (fromB == null) iter.seekToFirst() else iter.seek(fromB)
            if (fromB != null && !fromInclusive && iter.isValid && iter.key().contentEquals(fromB)) iter.next()
            var n = 0
            while (iter.isValid && n < limit) {
                val key = iter.key()
                if (toB != null) {
                    val cmp = key.compareTo(toB)
                    if (cmp > 0 || (cmp == 0 && !toInclusive)) break
                }
                yield(decode(iter))
                iter.next()
                n++
            }
        }
    }

    /** Lazily iterates, in key order, up to [limit] entries whose serialized key starts with [prefix]. */
    fun scanPrefix(prefix: ByteArray, limit: Int = Int.MAX_VALUE): Sequence<Pair<K, V>> = sequence {
        db.newIterator(h).use { iter ->
            iter.seek(prefix)
            var n = 0
            while (iter.isValid && n < limit && iter.key().startsWith(prefix)) {
                yield(decode(iter))
                iter.next()
                n++
            }
        }
    }

    /** Lexicographic comparison treating bytes as unsigned; a shorter prefix sorts first. */
    private fun ByteArray.compareTo(o: ByteArray): Int {
        val l = minOf(size, o.size)
        for (i in 0 until l) {
            val d = (this[i].toInt() and 0xff) - (o[i].toInt() and 0xff)
            if (d != 0) return d
        }
        return size - o.size
    }

    /** True when this array begins with all bytes of [p]. */
    private fun ByteArray.startsWith(p: ByteArray): Boolean {
        if (size < p.size) return false
        for (i in p.indices) if (this[i] != p[i]) return false
        return true
    }
}
| // ---------- example usage ---------- | |
/**
 * Example value type: a bug report.
 *
 * [id] defaults to the string form of a ULID created by a new [UlidFactory] instance.
 */
data class BugReport(
    val title: String,
    val priority: Int,
    val created: Instant,
    val id: String = UlidFactory.newInstance().create().toString()
)
/**
 * Example value type: a bug report without a title.
 *
 * [id] defaults to the string form of a ULID created by a new [UlidFactory] instance.
 */
data class BugReport2(
    val priority: Int,
    val created: Instant,
    val id: String = UlidFactory.newInstance().create().toString()
)
/**
 * Example key type: bug reports keyed by priority, then by creation time.
 *
 * [created] carries [Desc], which the reflective `autoCodec` reads to encode the field
 * with inverted bits (descending order).
 */
data class BugReportByPriorityIndex(
    val priority: Int,
    @Desc val created: Instant
)
/**
 * Serializer that prefixes every payload with a one-byte schema version.
 *
 * Writing always uses the codec registered for [current]. Reading picks the codec for the
 * stored version and, when that version differs from [current], passes the decoded value
 * through [migrate].
 *
 * The version is stored as a single signed byte, so [current] and the keys of [codecs]
 * must lie in -128..127.
 *
 * @throws IllegalArgumentException if [current] does not fit in one byte.
 */
class Versioned<V : Any>(
    private val current: Int,
    private val codecs: Map<Int, BinarySerializer<V>>,
    private val migrate: (Any) -> V
) : BinarySerializer<V> {

    init {
        // toByte() would silently truncate a larger value, producing a header that maps to
        // a different version on read.
        require(current in Byte.MIN_VALUE..Byte.MAX_VALUE) {
            "version $current does not fit in the one-byte header (-128..127)"
        }
    }

    /**
     * Encodes [v] with the current codec, prefixed by the version byte.
     *
     * @throws IllegalStateException if no codec is registered for [current].
     */
    override fun serialize(v: V): ByteArray {
        val codec = codecs[current] ?: error("no codec registered for current version $current")
        return byteArrayOf(current.toByte()) + codec.serialize(v)
    }

    /**
     * Decodes a payload produced by [serialize], migrating it when it was written by an
     * older version.
     *
     * @throws IllegalArgumentException if [b] is empty (no version byte).
     * @throws IllegalStateException if no codec is registered for the stored version.
     */
    override fun deserialize(b: ByteArray): V {
        require(b.isNotEmpty()) { "payload is empty; expected a leading version byte" }
        val ver = b[0].toInt()
        val codec = codecs[ver] ?: error("no codec registered for stored version $ver")
        val decoded = codec.deserialize(b.copyOfRange(1, b.size))
        return if (ver == current) decoded else migrate(decoded)
    }
}
/**
 * Marks a property whose encoding should sort in descending order.
 *
 * Retained at runtime and applicable to properties only, so it can be looked up
 * reflectively on a class's member properties.
 */
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
annotation class Desc
/**
 * Example: opens the database under ./data, stores two bug reports keyed by
 * (priority, created) and prints the first and last entries of the table.
 */
fun main() {
    // `use` closes the database even when one of the operations below throws.
    RocksDbDatabase(Path.of("./data")).use { db ->
        val bugPK = autoCodec<BugReportByPriorityIndex>()
        val v1 = autoCodec<BugReport>()
        val bugsDB: TreeTable<BugReportByPriorityIndex, BugReport> = db.openTree("bugs_by_priority", bugPK, v1)
        val key = BugReportByPriorityIndex(1, Instant.now())
        bugsDB.put(key, BugReport("jvm getting npe", 1, Instant.now()))
        bugsDB.put(BugReportByPriorityIndex(1, Instant.now()), BugReport("highest priorty", 1, Instant.now()))
        println(bugsDB.first())
        println(bugsDB.last())
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment