Created August 30, 2022 18:43
-
-
Save xacrimon/f429410cb724423201ec22a95b1db9df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
fs, | |
sync::{Arc, Mutex, MutexGuard}, | |
thread, | |
time::{Duration, Instant}, | |
}; | |
use anyhow::{anyhow, bail, Result}; | |
use axum::async_trait; | |
use rusqlite::{Connection, TransactionBehavior}; | |
use serde::{Deserialize, Serialize}; | |
use serde_json::json; | |
use tokio::task; | |
use tracing::{debug, info, instrument, warn}; | |
use super::{ | |
config::Config, | |
jobs::{Job, Schedulable, State}, | |
}; | |
/// How long `query_inner` keeps retrying an attempt that failed with `SQLITE_BUSY`.
const DB_TIMEOUT: Duration = Duration::from_millis(10_000);

/// File name of the SQLite database inside the instance's data directory.
const DB_FILE_NAME: &str = "persistent";

/// Capacity of rusqlite's prepared-statement cache on the shared connection.
const CACHED_QUERIES: usize = 256;

/// Pragmas applied, in this order, to the connection when the database is opened
/// (see `apply_pragmas`).
static DB_SETTINGS: &[(&str, &str)] = &[
    ("journal_mode", "delete"),
    ("synchronous", "full"),
    // A negative cache_size is a budget in KiB rather than a page count (~8 MiB here).
    ("cache_size", "-8192"),
    // Milliseconds SQLite itself waits on a locked database before reporting SQLITE_BUSY.
    ("busy_timeout", "100"),
    ("temp_store", "memory"),
];
pub struct Database { | |
conn: Mutex<Connection>, | |
} | |
impl Database { | |
pub fn new(config: &Config) -> Result<Arc<Self>> { | |
debug!("detected sqlite library version: {}", rusqlite::version()); | |
info!("connecting to database..."); | |
let path = config.instance.data_directory.join(DB_FILE_NAME); | |
let mut conn = Connection::open(path)?; | |
conn.set_prepared_statement_cache_capacity(CACHED_QUERIES); | |
apply_pragmas(&conn)?; | |
apply_migrations(&mut conn)?; | |
Ok(Arc::new(Self { | |
conn: Mutex::new(conn), | |
})) | |
} | |
pub fn acquire_conn(&self) -> MutexGuard<Connection> { | |
self.conn.lock().unwrap() | |
} | |
} | |
fn apply_pragmas(conn: &Connection) -> Result<()> { | |
for (pragma, value) in DB_SETTINGS { | |
conn.pragma_update(None, pragma, value)?; | |
} | |
Ok(()) | |
} | |
fn apply_migrations(conn: &mut Connection) -> Result<()> { | |
let tx = conn.transaction_with_behavior(TransactionBehavior::Exclusive)?; | |
let current_version: i32 = | |
tx.query_row("SELECT user_version FROM pragma_user_version", [], |row| { | |
row.get(0) | |
})?; | |
let migrations = load_migrations(current_version)?; | |
for (name, version, migration) in migrations { | |
if version > current_version { | |
info!("applying migration: {}...", name); | |
tx.execute_batch(&migration)?; | |
tx.pragma_update(None, "user_version", version)?; | |
} | |
} | |
tx.commit()?; | |
Ok(()) | |
} | |
fn load_migrations(above: i32) -> Result<Vec<(String, i32, String)>> { | |
let mut migrations = Vec::new(); | |
for entry in fs::read_dir("./migrations")? { | |
let entry = entry?; | |
let (name, version) = extract_key_from_entry(&entry)?; | |
let migration = fs::read_to_string(entry.path())?; | |
if version > above { | |
migrations.push((name, version, migration)); | |
} | |
} | |
migrations.sort_by_key(|(_, num, _)| *num); | |
Ok(migrations) | |
} | |
fn extract_key_from_entry(entry: &fs::DirEntry) -> Result<(String, i32)> { | |
let raw_name = entry.file_name(); | |
let name = raw_name | |
.to_str() | |
.ok_or_else(|| anyhow!("invalid file name"))?; | |
let num = name | |
.split('-') | |
.next() | |
.ok_or_else(|| anyhow!("missing number in migration name"))? | |
.parse()?; | |
Ok((name.to_owned(), num)) | |
} | |
trait QueryMode { | |
type Handle<'conn>; | |
fn handle(connection: &mut rusqlite::Connection) -> rusqlite::Result<Self::Handle<'_>>; | |
} | |
struct NoTx; | |
impl QueryMode for NoTx { | |
type Handle<'conn> = &'conn rusqlite::Connection; | |
fn handle(connection: &mut rusqlite::Connection) -> rusqlite::Result<Self::Handle<'_>> { | |
Ok(connection) | |
} | |
} | |
struct Tx; | |
impl QueryMode for Tx { | |
type Handle<'conn> = rusqlite::Transaction<'conn>; | |
fn handle(connection: &mut rusqlite::Connection) -> rusqlite::Result<Self::Handle<'_>> { | |
connection.transaction() | |
} | |
} | |
#[instrument(skip(db, query), err)] | |
fn query_inner<M, F, T>(db: &Database, mut query: F) -> Result<T> | |
where | |
M: QueryMode, | |
F: FnMut(M::Handle<'_>) -> Result<T>, | |
{ | |
task::block_in_place(|| { | |
let start = Instant::now(); | |
let end = start + DB_TIMEOUT; | |
let conn = &mut db.acquire_conn(); | |
let ret = loop { | |
let handle = M::handle(conn)?; | |
match query(handle) { | |
Ok(item) => break Ok(item), | |
Err(ref err) if let Some(sq_err) = err.downcast_ref::<rusqlite::Error>() => { | |
if let Some(code) = sq_err.sqlite_error_code() { | |
if code == rusqlite::ErrorCode::DatabaseBusy { | |
warn!("database is busy, retrying"); | |
if Instant::now() > end { | |
bail!("database busy, timed out"); | |
} else { | |
thread::yield_now(); | |
continue; | |
} | |
} | |
} | |
} | |
Err(err) => break Err(err), | |
} | |
}; | |
debug!("transaction took {:?}", start.elapsed()); | |
ret | |
}) | |
} | |
pub fn query<F, T>(db: &Database, query: F) -> Result<T> | |
where | |
F: FnMut(&rusqlite::Connection) -> Result<T>, | |
{ | |
query_inner::<NoTx, _, _>(db, query) | |
} | |
pub fn query_tx<F, T>(db: &Database, query: F) -> Result<T> | |
where | |
F: FnMut(rusqlite::Transaction) -> Result<T>, | |
{ | |
query_inner::<Tx, _, _>(db, query) | |
} | |
#[derive(Default, Serialize, Deserialize)] | |
pub struct DatabaseAnalyzeJob; | |
#[async_trait] | |
#[typetag::serde] | |
impl Job for DatabaseAnalyzeJob { | |
async fn run(&self, state: &State) -> Result<serde_json::Value> { | |
query(&state.db, |conn| { | |
conn.execute_batch( | |
r#" | |
PRAGMA analysis_limit=400; | |
ANALYZE; | |
"#, | |
)?; | |
Ok(()) | |
})?; | |
Ok(json!({})) | |
} | |
} | |
impl Schedulable for DatabaseAnalyzeJob { | |
const INTERVAL: Duration = Duration::from_secs(60 * 60 * 24); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.