Skip to content

Instantly share code, notes, and snippets.

@rust-play
Created March 6, 2023 13:09
Show Gist options
  • Save rust-play/2cf33c714e72e3377845c7bf2fdc8619 to your computer and use it in GitHub Desktop.
Save rust-play/2cf33c714e72e3377845c7bf2fdc8619 to your computer and use it in GitHub Desktop.
Code shared from the Rust Playground
use std::path::Path;
use core::marker::PhantomData;
use core::borrow::Borrow;
use chrono::{DateTime, Utc, Days};
use rusqlite::{named_params, Connection, Statement};
use serde::{Serialize, Deserialize};
use thiserror::Error;
/// Result alias used by every fallible operation in this module.
///
/// The error type defaults to [`Error`] but can be overridden.
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Errors produced by [`Store`] and [`Collection`] operations.
#[derive(Debug, Error)]
pub enum Error {
    /// A call into SQLite (open, prepare, execute, step, column access) failed.
    #[error("SQLite error: {0}")]
    SQLite(#[from] rusqlite::Error),

    /// A column value could not be converted to the requested Rust type.
    #[error("SQL type conversion error: {0}")]
    FromSql(#[from] rusqlite::types::FromSqlError),

    /// The requested collection (table) name was rejected by validation.
    #[error("invalid name `{0}`")]
    InvalidName(String),

    /// JSON serialization of a key/value, or deserialization of a stored value, failed.
    #[error("encoding error: {0}")]
    Encoding(#[from] serde_json::Error),

    /// No unexpired entry exists for the requested key.
    #[error("value not found for key")]
    KeyNotFound,
}
/// A key-value store backed by one SQLite database connection.
///
/// Use [`Store::collection`] to obtain a typed [`Collection`]; each collection
/// is stored in its own table of this database.
pub struct Store {
    /// The single connection shared (by borrow) by every collection's
    /// prepared statements.
    conn: Connection,
}
impl Store {
    /// Opens a store backed by the SQLite database at `path`
    /// (via `rusqlite::Connection::open`).
    ///
    /// # Errors
    ///
    /// Returns [`Error::SQLite`] if the connection cannot be opened.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let conn = Connection::open(path)?;
        Ok(Store { conn })
    }

    /// Returns whether `name` is acceptable as a collection (table) name.
    ///
    /// Table names cannot be bound as SQL parameters, so `name` is spliced
    /// into the statement text. Only non-empty ASCII `[A-Za-z0-9_]` names are
    /// allowed, which keeps the quoted identifier injection-free. Names that
    /// start with `sqlite_` (any case) are rejected too, because SQLite
    /// reserves that prefix and would refuse to create such a table.
    fn is_valid_name(name: &str) -> bool {
        let charset_ok =
            !name.is_empty() && name.chars().all(|ch| ch.is_ascii_alphanumeric() || ch == '_');
        let reserved = name
            .as_bytes()
            .get(..7)
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(b"sqlite_"));
        charset_ok && !reserved
    }

    /// Returns a typed collection stored in table `name`, creating the table
    /// if it does not already exist.
    ///
    /// Keys and values are stored as BLOBs. The optional expiry is stored as a
    /// Julian day number (`REAL`); `NULL` means the entry never expires. All
    /// statements the collection needs are prepared here, up front.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidName`] if `name` is empty, contains characters other
    ///   than ASCII letters, digits and `_`, or starts with `sqlite_`.
    /// - [`Error::SQLite`] if creating the table or preparing a statement fails.
    pub fn collection<K, V>(&self, name: &str) -> Result<Collection<'_, K, V>> {
        if !Self::is_valid_name(name) {
            return Err(Error::InvalidName(name.to_owned()));
        }

        // `name` was validated above, so interpolating it as "{name}" is safe.
        self.conn.execute(
            &format!(
                r#"
                CREATE TABLE IF NOT EXISTS "{name}"(
                    "key" BLOB NOT NULL PRIMARY KEY,
                    "value" BLOB NOT NULL,
                    "expiry" REAL NULL -- NULL means never expires
                )"#
            ),
            [],
        )?;

        // A row is visible only while the current time is strictly before its expiry.
        let get_stmt = self.conn.prepare(&format!(
            r#"
            SELECT "value" AS "value"
            FROM "{name}"
            WHERE
                "key" = :key
                AND
                ("expiry" IS NULL OR JULIANDAY() < "expiry")
            "#
        ))?;

        let set_stmt = self.conn.prepare(&format!(
            r#"
            INSERT OR REPLACE INTO "{name}"("key", "value", "expiry")
            VALUES (:key, :value, JULIANDAY(:expiry))
            "#
        ))?;

        let del_stmt = self.conn.prepare(&format!(
            r#"
            DELETE FROM "{name}"
            WHERE "key" = :key
            "#
        ))?;

        // `<=` is the exact complement of the `<` visibility test above, so every
        // row the lookup already treats as expired is eligible for cleanup
        // (previously a row expiring exactly "now" was hidden but never deleted).
        let cln_stmt = self.conn.prepare(&format!(
            r#"
            DELETE FROM "{name}"
            WHERE
                "expiry" IS NOT NULL
                AND
                "expiry" <= JULIANDAY()
            "#
        ))?;

        Ok(Collection {
            get_stmt,
            set_stmt,
            del_stmt,
            cln_stmt,
            buf: Vec::new(),
            marker: PhantomData,
        })
    }
}
/// A typed view of one table inside a [`Store`], mapping `K` keys to `V` values.
///
/// Keys and values are JSON-encoded before being written. The prepared
/// statements borrow the store's connection, hence the `'store` lifetime.
pub struct Collection<'store, K, V> {
    /// Selects the value for `:key`, skipping expired rows.
    get_stmt: Statement<'store>,
    /// Inserts or replaces the row for `:key` with `:value` and `:expiry`.
    set_stmt: Statement<'store>,
    /// Deletes the row for `:key`.
    del_stmt: Statement<'store>,
    /// Deletes expired rows.
    cln_stmt: Statement<'store>,
    /// Scratch buffer for serialized keys/values; it is cleared (not
    /// reallocated) before each use so its capacity is reused across calls.
    buf: Vec<u8>,
    /// Records the `K`/`V` type parameters without storing any values of them.
    marker: PhantomData<fn() -> (K, V)>,
}
impl<'store, K, V> Collection<'store, K, V> {
pub fn get<Q>(&mut self, key: &Q) -> Result<V>
where
K: Eq + Borrow<Q>,
V: for<'de> Deserialize<'de>,
Q: ?Sized + Eq + Serialize,
{
self.buf.clear();
serde_json::to_writer(&mut self.buf, key.borrow())?;
let mut rows = self.get_stmt.query(named_params!{":key": &self.buf[..]})?;
let maybe_row = rows.next()?;
let row = maybe_row.ok_or(Error::KeyNotFound)?;
let bytes = row.get_ref("value")?.as_bytes()?;
let value: V = serde_json::from_slice(bytes)?;
Ok(value)
}
pub fn get_opt<Q>(&mut self, key: &Q) -> Result<Option<V>>
where
K: Eq + Borrow<Q>,
V: for<'de> Deserialize<'de>,
Q: ?Sized + Eq + Serialize,
{
self.buf.clear();
serde_json::to_writer(&mut self.buf, key.borrow())?;
let mut rows = self.get_stmt.query(named_params!{":key": &self.buf[..]})?;
let row = match rows.next()? {
Some(r) => r,
None => return Ok(None),
};
let bytes = row.get_ref("value")?.as_bytes()?;
let value: V = serde_json::from_slice(bytes)?;
Ok(Some(value))
}
pub fn set<Q, U, E>(&mut self, key: &Q, value: &U, expiry: E) -> Result<()>
where
K: Eq + Borrow<Q>,
V: Borrow<U>,
Q: ?Sized + Eq + Serialize,
U: ?Sized + Serialize,
E: Into<Option<DateTime<Utc>>>,
{
self.buf.clear();
serde_json::to_writer(&mut self.buf, key.borrow())?;
let key_len = self.buf.len();
serde_json::to_writer(&mut self.buf, value.borrow())?;
let (key_bytes, value_bytes) = self.buf.split_at(key_len);
self.set_stmt.execute(named_params!{
":key": key_bytes,
":value": value_bytes,
":expiry": expiry.into(),
})?;
Ok(())
}
pub fn remove<Q>(&mut self, key: &Q) -> Result<bool>
where
K: Eq + Borrow<Q>,
Q: ?Sized + Eq + Serialize,
{
self.buf.clear();
serde_json::to_writer(&mut self.buf, key.borrow())?;
let num_rows = self.del_stmt.execute(named_params!{":key": &self.buf[..]})?;
Ok(num_rows > 0)
}
/// remove expired entries
pub fn cleanup(&mut self) -> Result<()> {
self.cln_stmt.execute([])?;
Ok(())
}
}
/// Sample enum stored inside [`User`] by the demo in `main`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum Color {
    Red,
    Green,
    Blue,
}
/// Sample record type stored as the value of the demo collection in `main`.
///
/// Field order is kept as-is because it determines the order of fields in the
/// serialized JSON.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct User {
    handle: String,
    age: u32,
    favorite_color: Color,
}
/// Demo: stores two users (one with a one-day expiry, one without), reads one
/// back, probes a missing key, and exercises `remove`.
fn main() -> Result<()> {
    // Use the platform temp directory instead of a hard-coded `/tmp`, which
    // does not exist on every OS (e.g. Windows).
    let store = Store::open(std::env::temp_dir().join("store.db"))?;
    let mut coll: Collection<String, User> = store.collection("people")?;

    coll.set(
        "foo",
        &User {
            handle: String::from("John Foo"),
            age: 63,
            favorite_color: Color::Red,
        },
        Utc::now().checked_add_days(Days::new(1)), // expiry in 1 day
    )?;
    coll.set(
        "bar",
        &User {
            handle: String::from("Alice Bar"),
            age: 27,
            favorite_color: Color::Blue,
        },
        None, // does not expire
    )?;

    let u1 = coll.get("foo")?;
    dbg!(u1);
    let u2 = coll.get_opt("nonexistent")?;
    dbg!(u2);

    // `remove` reports whether a row was actually deleted.
    assert!(coll.remove("bar")?);
    assert!(!coll.remove("something")?);
    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment