Created
December 17, 2024 15:15
-
-
Save xkikeg/66ed868d5137913dac0d5ecf9401d65d to your computer and use it in GitHub Desktop.
Initial impl for the price_db
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use chrono::{Days, NaiveDate}; | |
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; | |
use snippets::price_db::{BTreeCachedPriceDB, CachedPriceDB, FakeCommodityStore, PriceDBHashMap}; | |
fn price_db_bench(c: &mut Criterion) { | |
let mut group = c.benchmark_group("price-db"); | |
for num_dense in [5usize, 10, 20].iter() { | |
group.bench_with_input( | |
BenchmarkId::new("naive", num_dense), | |
num_dense, | |
|b, num_dense| { | |
let mut price_db = PriceDBHashMap::default(); | |
static SCALE_SPARSE: usize = 10; | |
let commodity_store = FakeCommodityStore::new(*num_dense, SCALE_SPARSE); | |
static YEARS: u64 = 10; | |
let base = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); | |
price_db.fill(&commodity_store, base, YEARS); | |
let mut i = commodity_store | |
.all | |
.iter() | |
.flat_map(|x| std::iter::repeat_n(*x, commodity_store.all.len())) | |
.cycle(); | |
let mut j = commodity_store.all.iter().cycle(); | |
let mut d = (0..365 * (YEARS + 1)).cycle(); | |
b.iter(|| { | |
let days = d.next().unwrap(); | |
let date = base.checked_add_days(Days::new(days)).unwrap(); | |
black_box(price_db.compute_price(i.next().unwrap(), *j.next().unwrap(), date)); | |
}) | |
}, | |
); | |
group.bench_with_input( | |
BenchmarkId::new("cached", num_dense), | |
num_dense, | |
|b, num_dense| { | |
let mut price_db = PriceDBHashMap::default(); | |
static SCALE_SPARSE: usize = 10; | |
let commodity_store = FakeCommodityStore::new(*num_dense, SCALE_SPARSE); | |
static YEARS: u64 = 10; | |
let base = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); | |
price_db.fill(&commodity_store, base, YEARS); | |
let mut price_db = CachedPriceDB::new(price_db); | |
let i = commodity_store.all.iter().next().unwrap(); | |
let j = commodity_store.all.iter().skip(1).next().unwrap(); | |
for d in 0..365 * (YEARS + 1) { | |
let days = d; | |
let date = base.checked_add_days(Days::new(days)).unwrap(); | |
black_box(price_db.compute_price(*i, *j, date)); | |
} | |
let mut i = commodity_store | |
.all | |
.iter() | |
.flat_map(|x| std::iter::repeat_n(*x, commodity_store.all.len())) | |
.cycle(); | |
let mut j = commodity_store.all.iter().cycle(); | |
let mut d = (0..365 * (YEARS + 1)).cycle(); | |
b.iter(|| { | |
let days = d.next().unwrap(); | |
let date = base.checked_add_days(Days::new(days)).unwrap(); | |
black_box(price_db.compute_price(i.next().unwrap(), *j.next().unwrap(), date)); | |
}) | |
}, | |
); | |
group.bench_with_input( | |
BenchmarkId::new("btree-cached", num_dense), | |
num_dense, | |
|b, num_dense| { | |
let mut price_db = PriceDBHashMap::default(); | |
static SCALE_SPARSE: usize = 10; | |
let commodity_store = FakeCommodityStore::new(*num_dense, SCALE_SPARSE); | |
static YEARS: u64 = 10; | |
let base = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); | |
price_db.fill(&commodity_store, base, YEARS); | |
let mut price_db = BTreeCachedPriceDB::new(price_db); | |
let i = commodity_store.all.iter().next().unwrap(); | |
let j = commodity_store.all.iter().skip(1).next().unwrap(); | |
for d in 0..365 * (YEARS + 1) { | |
let days = d; | |
let date = base.checked_add_days(Days::new(days)).unwrap(); | |
black_box(price_db.compute_price(*i, *j, date)); | |
} | |
let mut i = commodity_store | |
.all | |
.iter() | |
.flat_map(|x| std::iter::repeat_n(*x, commodity_store.all.len())) | |
.cycle(); | |
let mut j = commodity_store.all.iter().cycle(); | |
let mut d = (0..365 * (YEARS + 1)).cycle(); | |
b.iter(|| { | |
let days = d.next().unwrap(); | |
let date = base.checked_add_days(Days::new(days)).unwrap(); | |
black_box(price_db.compute_price(i.next().unwrap(), *j.next().unwrap(), date)); | |
}) | |
}, | |
); | |
} | |
group.finish(); | |
} | |
criterion_group!(benches, price_db_bench); | |
criterion_main!(benches); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Numeric identifier of a commodity; used as the key type of the price tables below.
pub type Commodity = u64; | |
use std::collections::{hash_map, BTreeMap, BTreeSet, BinaryHeap, HashMap}; | |
use chrono::{Days, NaiveDate, TimeDelta}; | |
use rust_decimal::Decimal; | |
// struct PriceDBBtree { | |
// records: BTreeMap<(u64, u64, NaiveDate), Decimal>, | |
// } | |
#[derive(Debug, Default)] | |
pub struct PriceDBHashMap { | |
// from comodity -> to commodity -> date -> price. | |
// e.g. USD AAPL 2024-01-01 100 means 1 AAPL == 100 USD at 2024-01-01. | |
records: HashMap<u64, HashMap<u64, Vec<(NaiveDate, Decimal)>>>, | |
} | |
impl PriceDBHashMap { | |
fn insert_last_price( | |
&mut self, | |
price_by: Commodity, | |
price_of: Commodity, | |
date: NaiveDate, | |
rate: Decimal, | |
) { | |
self.records | |
.entry(price_by) | |
.or_default() | |
.entry(price_of) | |
.or_default() | |
.push((date, rate)); | |
} | |
pub fn fill(&mut self, commodity_store: &FakeCommodityStore, base: NaiveDate, years: u64) { | |
for i in 0..365 * years { | |
let date = base.checked_add_days(Days::new(i)).unwrap(); | |
// dense pair | |
let mut i = commodity_store.dense.iter(); | |
while let Some(iv) = i.next() { | |
let mut j = i.clone(); | |
while let Some(jv) = j.next() { | |
let i = *iv; | |
let j = *jv; | |
// some other value? | |
let rate = Decimal::from(1); | |
self.insert_last_price(i, j, date, rate); | |
self.insert_last_price(j, i, date, Decimal::from(1) / rate); | |
} | |
} | |
for (i, j) in &commodity_store.sparse { | |
let rate = Decimal::from(100); | |
self.insert_last_price(*j, *i, date, rate); | |
self.insert_last_price(*i, *j, date, Decimal::ONE / rate); | |
} | |
} | |
} | |
pub fn compute_price( | |
&self, | |
price_of: u64, | |
price_with: u64, | |
date: NaiveDate, | |
) -> Option<Decimal> { | |
if price_of == price_with { | |
return Some(Decimal::ONE); | |
} | |
match self.compute_price_table(price_with, date).get(&price_of) { | |
None => None, | |
Some((_, _, rate)) => Some(*rate), | |
} | |
} | |
fn compute_price_table( | |
&self, | |
price_with: u64, | |
date: NaiveDate, | |
) -> HashMap<Commodity, (usize, TimeDelta, Decimal)> { | |
// minimize the distance, and then minimize the staleness. | |
let mut queue: BinaryHeap<(usize, TimeDelta, u64, Decimal)> = BinaryHeap::new(); | |
let mut distances: HashMap<Commodity, (usize, TimeDelta, Decimal)> = HashMap::new(); | |
let mut prevs: HashMap<u64, u64> = HashMap::new(); | |
queue.push((0, TimeDelta::zero(), price_with, Decimal::ONE)); | |
while let Some(curr) = queue.pop() { | |
let (curr_dist, staleness, prev, prev_rate) = curr; | |
log::debug!("curr: {:?}", curr); | |
if let Some(prev_dist) = distances.get(&prev) { | |
if (prev_dist.0, prev_dist.1) < (curr_dist, staleness) { | |
// no need to update, as it's worse than discovered path. | |
log::debug!( | |
"no need to update, prev_dist {:?} is smaller than curr_dist {:?}", | |
prev_dist, | |
curr | |
); | |
continue; | |
} | |
} | |
for (j, rates) in match self.records.get(&prev) { | |
None => continue, | |
Some(x) => x, | |
} { | |
let bound = rates.partition_point(|(record_date, _)| record_date <= &date); | |
log::debug!("found next commodity {} bound {}", j, bound); | |
if bound == 0 { | |
continue; | |
} | |
let (record_date, rate) = rates[bound - 1]; | |
let staleness = std::cmp::max(staleness, date - record_date); | |
let rate = prev_rate * rate; | |
let next_dist = (curr_dist + 1, staleness, rate); | |
let next = (curr_dist + 1, staleness, *j, rate); | |
let updated = match distances.entry(*j) { | |
hash_map::Entry::Occupied(mut e) => { | |
if e.get() <= &next_dist { | |
false | |
} else { | |
e.insert(next_dist); | |
prevs.insert(*j, prev); | |
true | |
} | |
} | |
hash_map::Entry::Vacant(e) => { | |
e.insert(next_dist); | |
prevs.insert(*j, prev); | |
true | |
} | |
}; | |
if !updated { | |
continue; | |
} | |
queue.push(next); | |
} | |
} | |
distances | |
} | |
} | |
pub struct CachedPriceDB { | |
inner: PriceDBHashMap, | |
cache: HashMap<NaiveDate, HashMap<Commodity, (usize, TimeDelta, Decimal)>>, | |
} | |
impl CachedPriceDB { | |
pub fn new(inner: PriceDBHashMap) -> Self { | |
Self { | |
inner, | |
cache: HashMap::new(), | |
} | |
} | |
pub fn compute_price( | |
&mut self, | |
price_of: u64, | |
price_with: u64, | |
date: NaiveDate, | |
) -> Option<Decimal> { | |
if price_of == price_with { | |
return Some(Decimal::ONE); | |
} | |
self.cache | |
.entry(date) | |
.or_insert_with(|| self.inner.compute_price_table(price_with, date)) | |
.get(&price_of) | |
.map(|(_, _, rate)| *rate) | |
} | |
} | |
pub struct BTreeCachedPriceDB { | |
inner: PriceDBHashMap, | |
cache: BTreeMap<NaiveDate, HashMap<Commodity, (usize, TimeDelta, Decimal)>>, | |
} | |
impl BTreeCachedPriceDB { | |
pub fn new(inner: PriceDBHashMap) -> Self { | |
Self { | |
inner, | |
cache: BTreeMap::new(), | |
} | |
} | |
pub fn compute_price( | |
&mut self, | |
price_of: u64, | |
price_with: u64, | |
date: NaiveDate, | |
) -> Option<Decimal> { | |
if price_of == price_with { | |
return Some(Decimal::ONE); | |
} | |
self.cache | |
.entry(date) | |
.or_insert_with(|| self.inner.compute_price_table(price_with, date)) | |
.get(&price_of) | |
.map(|(_, _, rate)| *rate) | |
} | |
} | |
pub struct FakeCommodityStore { | |
pub all: BTreeSet<u64>, | |
// dense has generally 1:1 mapping each other. | |
dense: BTreeSet<u64>, | |
// sparse commodity has only one link to the dense commodity. | |
sparse: BTreeMap<u64, u64>, | |
} | |
impl FakeCommodityStore { | |
fn new_commodity(&mut self) -> u64 { | |
let mut nc: u64; | |
loop { | |
nc = rand::random(); | |
if self.all.insert(nc) { | |
break; | |
} | |
} | |
nc | |
} | |
pub fn new(dense_size: usize, sparse_scale: usize) -> Self { | |
let mut ret = Self { | |
all: Default::default(), | |
dense: Default::default(), | |
sparse: Default::default(), | |
}; | |
for _ in 0..dense_size { | |
let dense = ret.new_commodity(); | |
ret.dense.insert(dense); | |
for _ in 0..sparse_scale { | |
let sparse = ret.new_commodity(); | |
ret.sparse.insert(sparse, dense); | |
} | |
} | |
ret | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    // Installs the test logger once, before any test runs.
    #[ctor::ctor]
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Shorthand for building a date that is known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn price_db_computes_direct_price() {
        let mut db = PriceDBHashMap::default();
        let chf: Commodity = 10;
        let eur: Commodity = 15;
        // 1 EUR = 0.8 CHF, and the inverse 1 CHF = 1.25 EUR.
        db.insert_last_price(chf, eur, ymd(2024, 10, 1), dec!(0.8));
        db.insert_last_price(eur, chf, ymd(2024, 10, 1), Decimal::ONE / dec!(0.8));

        let got = db.compute_price(eur, chf, ymd(2024, 10, 4));
        assert_eq!(got, Some(dec!(0.8)));

        let got = db.compute_price(chf, eur, ymd(2024, 10, 4));
        assert_eq!(got, Some(dec!(1.25)));

        // No record exists on or before the queried date.
        let got = db.compute_price(chf, eur, ymd(2024, 9, 1));
        assert_eq!(got, None);
    }

    #[test]
    fn price_db_computes_indirect_price() {
        let mut db = PriceDBHashMap::default();
        let chf: Commodity = 10;
        let eur: Commodity = 15;
        let usd: Commodity = 1001;
        let jpy: Commodity = 1002456;
        let date = ymd(2024, 10, 1);

        db.insert_last_price(chf, eur, date, dec!(0.8));
        db.insert_last_price(eur, chf, date, Decimal::ONE / dec!(0.8));
        db.insert_last_price(eur, usd, date, dec!(0.8));
        db.insert_last_price(usd, eur, date, Decimal::ONE / dec!(0.8));
        db.insert_last_price(jpy, usd, date, dec!(100));
        // Inverse of the record above: 1 JPY = 1/100 USD.
        db.insert_last_price(usd, jpy, date, Decimal::ONE / dec!(100));

        // 1 EUR = 0.8 CHF
        // 1 USD = 0.8 EUR
        // 1 USD = 100 JPY
        // 1 CHF == 5/4 EUR == (5/4)*(5/4) USD == 156.25 JPY
        let got = db.compute_price(chf, jpy, ymd(2024, 10, 4));
        assert_eq!(got, Some(dec!(156.25)));
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment