Skip to content

Instantly share code, notes, and snippets.

@xkikeg
Created December 17, 2024 15:15
Show Gist options
  • Save xkikeg/66ed868d5137913dac0d5ecf9401d65d to your computer and use it in GitHub Desktop.
Save xkikeg/66ed868d5137913dac0d5ecf9401d65d to your computer and use it in GitHub Desktop.
Initial impl for the price_db
use chrono::{Days, NaiveDate};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use snippets::price_db::{BTreeCachedPriceDB, CachedPriceDB, FakeCommodityStore, PriceDBHashMap};
fn price_db_bench(c: &mut Criterion) {
let mut group = c.benchmark_group("price-db");
for num_dense in [5usize, 10, 20].iter() {
group.bench_with_input(
BenchmarkId::new("naive", num_dense),
num_dense,
|b, num_dense| {
let mut price_db = PriceDBHashMap::default();
static SCALE_SPARSE: usize = 10;
let commodity_store = FakeCommodityStore::new(*num_dense, SCALE_SPARSE);
static YEARS: u64 = 10;
let base = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap();
price_db.fill(&commodity_store, base, YEARS);
let mut i = commodity_store
.all
.iter()
.flat_map(|x| std::iter::repeat_n(*x, commodity_store.all.len()))
.cycle();
let mut j = commodity_store.all.iter().cycle();
let mut d = (0..365 * (YEARS + 1)).cycle();
b.iter(|| {
let days = d.next().unwrap();
let date = base.checked_add_days(Days::new(days)).unwrap();
black_box(price_db.compute_price(i.next().unwrap(), *j.next().unwrap(), date));
})
},
);
group.bench_with_input(
BenchmarkId::new("cached", num_dense),
num_dense,
|b, num_dense| {
let mut price_db = PriceDBHashMap::default();
static SCALE_SPARSE: usize = 10;
let commodity_store = FakeCommodityStore::new(*num_dense, SCALE_SPARSE);
static YEARS: u64 = 10;
let base = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap();
price_db.fill(&commodity_store, base, YEARS);
let mut price_db = CachedPriceDB::new(price_db);
let i = commodity_store.all.iter().next().unwrap();
let j = commodity_store.all.iter().skip(1).next().unwrap();
for d in 0..365 * (YEARS + 1) {
let days = d;
let date = base.checked_add_days(Days::new(days)).unwrap();
black_box(price_db.compute_price(*i, *j, date));
}
let mut i = commodity_store
.all
.iter()
.flat_map(|x| std::iter::repeat_n(*x, commodity_store.all.len()))
.cycle();
let mut j = commodity_store.all.iter().cycle();
let mut d = (0..365 * (YEARS + 1)).cycle();
b.iter(|| {
let days = d.next().unwrap();
let date = base.checked_add_days(Days::new(days)).unwrap();
black_box(price_db.compute_price(i.next().unwrap(), *j.next().unwrap(), date));
})
},
);
group.bench_with_input(
BenchmarkId::new("btree-cached", num_dense),
num_dense,
|b, num_dense| {
let mut price_db = PriceDBHashMap::default();
static SCALE_SPARSE: usize = 10;
let commodity_store = FakeCommodityStore::new(*num_dense, SCALE_SPARSE);
static YEARS: u64 = 10;
let base = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap();
price_db.fill(&commodity_store, base, YEARS);
let mut price_db = BTreeCachedPriceDB::new(price_db);
let i = commodity_store.all.iter().next().unwrap();
let j = commodity_store.all.iter().skip(1).next().unwrap();
for d in 0..365 * (YEARS + 1) {
let days = d;
let date = base.checked_add_days(Days::new(days)).unwrap();
black_box(price_db.compute_price(*i, *j, date));
}
let mut i = commodity_store
.all
.iter()
.flat_map(|x| std::iter::repeat_n(*x, commodity_store.all.len()))
.cycle();
let mut j = commodity_store.all.iter().cycle();
let mut d = (0..365 * (YEARS + 1)).cycle();
b.iter(|| {
let days = d.next().unwrap();
let date = base.checked_add_days(Days::new(days)).unwrap();
black_box(price_db.compute_price(i.next().unwrap(), *j.next().unwrap(), date));
})
},
);
}
group.finish();
}
// Criterion entry points: `benches` is the group containing `price_db_bench`, and
// `criterion_main!` generates the benchmark binary's `main` that runs that group.
criterion_group!(benches, price_db_bench);
criterion_main!(benches);
/// Identifier of a commodity; a plain alias for `u64`.
pub type Commodity = u64;
use std::collections::{hash_map, BTreeMap, BTreeSet, BinaryHeap, HashMap};
use chrono::{Days, NaiveDate, TimeDelta};
use rust_decimal::Decimal;
// struct PriceDBBtree {
// records: BTreeMap<(u64, u64, NaiveDate), Decimal>,
// }
/// Price database backed by nested hash maps of per-pair price histories.
#[derive(Debug, Default)]
pub struct PriceDBHashMap {
// from commodity -> to commodity -> list of (date, price) in insertion order.
// The outer key is the commodity the price is quoted in; the inner key is the
// commodity being priced.
// e.g. USD AAPL 2024-01-01 100 means 1 AAPL == 100 USD at 2024-01-01.
records: HashMap<u64, HashMap<u64, Vec<(NaiveDate, Decimal)>>>,
}
impl PriceDBHashMap {
    /// Appends a price record: on `date`, one unit of `price_of` costs `rate` units of
    /// `price_by`.
    ///
    /// Records are appended without sorting, while lookups binary-search each history by
    /// date, so records for a given pair must be inserted in non-decreasing date order.
    fn insert_last_price(
        &mut self,
        price_by: Commodity,
        price_of: Commodity,
        date: NaiveDate,
        rate: Decimal,
    ) {
        self.records
            .entry(price_by)
            .or_default()
            .entry(price_of)
            .or_default()
            .push((date, rate));
    }

    /// Fills the database with synthetic daily prices for `365 * years` days from `base`.
    ///
    /// Every unordered pair of dense commodities gets a rate of 1 in both directions.
    /// Every sparse commodity is priced at 100 units of its dense commodity, with the
    /// reciprocal rate stored for the opposite direction.
    ///
    /// # Panics
    ///
    /// Panics if adding a day offset to `base` overflows the supported date range.
    pub fn fill(&mut self, commodity_store: &FakeCommodityStore, base: NaiveDate, years: u64) {
        for day in 0..365 * years {
            let date = base.checked_add_days(Days::new(day)).unwrap();
            // Dense pairs: pair each commodity with every commodity that follows it.
            let mut remaining = commodity_store.dense.iter();
            while let Some(&first) = remaining.next() {
                for &second in remaining.clone() {
                    let rate = Decimal::ONE;
                    self.insert_last_price(first, second, date, rate);
                    self.insert_last_price(second, first, date, Decimal::ONE / rate);
                }
            }
            // Sparse links: each sparse commodity is connected to exactly one dense one.
            for (&sparse, &dense) in &commodity_store.sparse {
                let rate = Decimal::from(100);
                self.insert_last_price(dense, sparse, date, rate);
                self.insert_last_price(sparse, dense, date, Decimal::ONE / rate);
            }
        }
    }

    /// Returns the price of one unit of `price_of` expressed in `price_with` as of `date`.
    ///
    /// Only records dated on or before `date` are considered, and rates are multiplied
    /// along a chain of records when no direct record exists. Returns `None` when no such
    /// chain connects the two commodities. A commodity priced in itself is always 1.
    pub fn compute_price(
        &self,
        price_of: u64,
        price_with: u64,
        date: NaiveDate,
    ) -> Option<Decimal> {
        if price_of == price_with {
            return Some(Decimal::ONE);
        }
        self.compute_price_table(price_with, date)
            .get(&price_of)
            .map(|&(_, _, rate)| rate)
    }

    /// Computes, for every commodity reachable from `price_with`, the tuple
    /// `(hops, staleness, rate)` of the best conversion chain as of `date`.
    ///
    /// Chains are ranked by fewest hops first, then by smallest staleness (the largest
    /// gap between `date` and a record used along the chain), then by smallest rate.
    fn compute_price_table(
        &self,
        price_with: u64,
        date: NaiveDate,
    ) -> HashMap<Commodity, (usize, TimeDelta, Decimal)> {
        // minimize the distance, and then minimize the staleness.
        // `BinaryHeap` is a max-heap, so entries are wrapped in `Reverse` to pop the
        // closest / freshest candidate first instead of the worst one.
        let mut queue: BinaryHeap<std::cmp::Reverse<(usize, TimeDelta, u64, Decimal)>> =
            BinaryHeap::new();
        let mut distances: HashMap<Commodity, (usize, TimeDelta, Decimal)> = HashMap::new();
        queue.push(std::cmp::Reverse((
            0,
            TimeDelta::zero(),
            price_with,
            Decimal::ONE,
        )));
        while let Some(std::cmp::Reverse(curr)) = queue.pop() {
            let (curr_dist, staleness, prev, prev_rate) = curr;
            log::debug!("curr: {:?}", curr);
            if let Some(prev_dist) = distances.get(&prev) {
                if (prev_dist.0, prev_dist.1) < (curr_dist, staleness) {
                    // no need to update, as it's worse than discovered path.
                    log::debug!(
                        "no need to update, prev_dist {:?} is smaller than curr_dist {:?}",
                        prev_dist,
                        curr
                    );
                    continue;
                }
            }
            let Some(neighbors) = self.records.get(&prev) else {
                continue;
            };
            for (j, rates) in neighbors {
                // Index of the first record strictly after `date`; the one before it is
                // the latest record usable for this query.
                let bound = rates.partition_point(|(record_date, _)| record_date <= &date);
                log::debug!("found next commodity {} bound {}", j, bound);
                if bound == 0 {
                    continue;
                }
                let (record_date, rate) = rates[bound - 1];
                let staleness = std::cmp::max(staleness, date - record_date);
                let rate = prev_rate * rate;
                let next_dist = (curr_dist + 1, staleness, rate);
                let improved = match distances.entry(*j) {
                    hash_map::Entry::Occupied(mut e) => {
                        if *e.get() <= next_dist {
                            false
                        } else {
                            e.insert(next_dist);
                            true
                        }
                    }
                    hash_map::Entry::Vacant(e) => {
                        e.insert(next_dist);
                        true
                    }
                };
                if improved {
                    queue.push(std::cmp::Reverse((curr_dist + 1, staleness, *j, rate)));
                }
            }
        }
        distances
    }
}
/// Price database that memoizes computed price tables in a `HashMap`.
///
/// A price table depends on both the commodity prices are quoted in and the query date,
/// so the cache is keyed by the `(price_with, date)` pair. Entries are never evicted.
#[derive(Debug)]
pub struct CachedPriceDB {
    inner: PriceDBHashMap,
    cache: HashMap<(Commodity, NaiveDate), HashMap<Commodity, (usize, TimeDelta, Decimal)>>,
}

impl CachedPriceDB {
    /// Wraps `inner` with an initially empty cache.
    pub fn new(inner: PriceDBHashMap) -> Self {
        Self {
            inner,
            cache: HashMap::new(),
        }
    }

    /// Returns the price of one unit of `price_of` expressed in `price_with` as of `date`,
    /// or `None` when the two commodities are not connected by any usable records.
    ///
    /// The price table for `(price_with, date)` is computed on first use and reused by
    /// later queries with the same pair. A commodity priced in itself is always 1.
    pub fn compute_price(
        &mut self,
        price_of: u64,
        price_with: u64,
        date: NaiveDate,
    ) -> Option<Decimal> {
        if price_of == price_with {
            return Some(Decimal::ONE);
        }
        self.cache
            .entry((price_with, date))
            .or_insert_with(|| self.inner.compute_price_table(price_with, date))
            .get(&price_of)
            .map(|&(_, _, rate)| rate)
    }
}
/// Price database that memoizes computed price tables in a `BTreeMap`.
///
/// A price table depends on both the query date and the commodity prices are quoted in,
/// so the cache is keyed by the `(date, price_with)` pair. Entries are never evicted.
#[derive(Debug)]
pub struct BTreeCachedPriceDB {
    inner: PriceDBHashMap,
    cache: BTreeMap<(NaiveDate, Commodity), HashMap<Commodity, (usize, TimeDelta, Decimal)>>,
}

impl BTreeCachedPriceDB {
    /// Wraps `inner` with an initially empty cache.
    pub fn new(inner: PriceDBHashMap) -> Self {
        Self {
            inner,
            cache: BTreeMap::new(),
        }
    }

    /// Returns the price of one unit of `price_of` expressed in `price_with` as of `date`,
    /// or `None` when the two commodities are not connected by any usable records.
    ///
    /// The price table for `(date, price_with)` is computed on first use and reused by
    /// later queries with the same pair. A commodity priced in itself is always 1.
    pub fn compute_price(
        &mut self,
        price_of: u64,
        price_with: u64,
        date: NaiveDate,
    ) -> Option<Decimal> {
        if price_of == price_with {
            return Some(Decimal::ONE);
        }
        self.cache
            .entry((date, price_with))
            .or_insert_with(|| self.inner.compute_price_table(price_with, date))
            .get(&price_of)
            .map(|&(_, _, rate)| rate)
    }
}
/// Randomly generated set of commodity ids used to build synthetic price data.
pub struct FakeCommodityStore {
// every generated commodity id, dense and sparse alike.
pub all: BTreeSet<u64>,
// dense has generally 1:1 mapping each other.
dense: BTreeSet<u64>,
// sparse commodity has only one link to the dense commodity.
// key: sparse commodity id, value: the dense commodity it is linked to.
sparse: BTreeMap<u64, u64>,
}
impl FakeCommodityStore {
fn new_commodity(&mut self) -> u64 {
let mut nc: u64;
loop {
nc = rand::random();
if self.all.insert(nc) {
break;
}
}
nc
}
pub fn new(dense_size: usize, sparse_scale: usize) -> Self {
let mut ret = Self {
all: Default::default(),
dense: Default::default(),
sparse: Default::default(),
};
for _ in 0..dense_size {
let dense = ret.new_commodity();
ret.dense.insert(dense);
for _ in 0..sparse_scale {
let sparse = ret.new_commodity();
ret.sparse.insert(sparse, dense);
}
}
ret
}
}
#[cfg(test)]
mod tests {
    use super::*;

    use rust_decimal_macros::dec;

    // Installs the test logger once, before any test runs.
    #[ctor::ctor]
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// A directly recorded pair is priced from its own record in both directions, and a
    /// query dated before the first record yields no price.
    #[test]
    fn price_db_computes_direct_price() {
        let mut db = PriceDBHashMap::default();
        let chf: Commodity = 10;
        let eur: Commodity = 15;

        // 1 EUR = 0.8 CHF, plus the reciprocal for the opposite direction.
        db.insert_last_price(
            chf,
            eur,
            NaiveDate::from_ymd_opt(2024, 10, 1).unwrap(),
            dec!(0.8),
        );
        db.insert_last_price(
            eur,
            chf,
            NaiveDate::from_ymd_opt(2024, 10, 1).unwrap(),
            Decimal::ONE / dec!(0.8),
        );

        let got = db.compute_price(eur, chf, NaiveDate::from_ymd_opt(2024, 10, 4).unwrap());
        assert_eq!(got, Some(dec!(0.8)));

        let got = db.compute_price(chf, eur, NaiveDate::from_ymd_opt(2024, 10, 4).unwrap());
        assert_eq!(got, Some(dec!(1.25)));

        let got = db.compute_price(chf, eur, NaiveDate::from_ymd_opt(2024, 9, 1).unwrap());
        assert_eq!(got, None);
    }

    /// A pair without a direct record is priced by multiplying rates along a chain.
    #[test]
    fn price_db_computes_indirect_price() {
        let mut db = PriceDBHashMap::default();
        let chf: Commodity = 10;
        let eur: Commodity = 15;
        let usd: Commodity = 1001;
        let jpy: Commodity = 1002456;
        let date = NaiveDate::from_ymd_opt(2024, 10, 1).unwrap();

        db.insert_last_price(chf, eur, date, dec!(0.8));
        db.insert_last_price(eur, chf, date, Decimal::ONE / dec!(0.8));

        db.insert_last_price(eur, usd, date, dec!(0.8));
        db.insert_last_price(usd, eur, date, Decimal::ONE / dec!(0.8));

        db.insert_last_price(jpy, usd, date, dec!(100));
        // Reciprocal of the record above: 1 JPY = 1/100 USD.
        db.insert_last_price(usd, jpy, date, Decimal::ONE / dec!(100));

        // 1 EUR = 0.8 CHF
        // 1 USD = 0.8 EUR
        // 1 USD = 100 JPY
        // 1 CHF == 5/4 EUR == (5/4)*(5/4) USD == 156.25 JPY
        let got = db.compute_price(chf, jpy, NaiveDate::from_ymd_opt(2024, 10, 4).unwrap());
        assert_eq!(got, Some(dec!(156.25)));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment