Skip to content

Instantly share code, notes, and snippets.

@ldesgoui
Last active July 23, 2021 09:10
Show Gist options
  • Save ldesgoui/181d1749f6fd42f2d70b32734e2bc4ee to your computer and use it in GitHub Desktop.
Save ldesgoui/181d1749f6fd42f2d70b32734e2bc4ee to your computer and use it in GitHub Desktop.
dynamic deser
#![allow(unused)]
use erased_serde as erased;
use once_cell::race::OnceBox;
use serde::de::{self, DeserializeSeed, Deserializer, MapAccess, Visitor};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
// Intentionally empty placeholder entry point: nothing here calls `init`; the deserialization
// logic is exercised by the unit tests in the `tests` module.
fn main() {}
/// Loaded configuration values, keyed by the `TypeId` recorded in the `Registration` that
/// produced each value.
type Store = HashMap<TypeId, Box<dyn Any + Send + Sync>>;
/// Describes one top-level configuration key and how to deserialize its value.
struct Registration {
    /// Top-level key in the configuration map (e.g. `"mysql"`).
    field: &'static str,
    /// Key under which the deserialized value is stored. It must be the `TypeId` of the type
    /// `deserialize_fn` boxes, because `__store_get::<T>` looks values up by `TypeId::of::<T>()`
    /// and then downcasts to `T`.
    type_id: TypeId,
    /// Type-erased deserializer for this key's value.
    deserialize_fn: DeserializeFn,
}
/// Type-erased deserialization entry point: reads one value from an `erased_serde`
/// deserializer and boxes it as `Any`, so values of unrelated types can share one `Store`.
type DeserializeFn =
    fn(&mut dyn erased::Deserializer) -> erased::Result<Box<dyn Any + Send + Sync>>;
/// `DeserializeSeed` that fills a `Store` from a configuration map, using the borrowed
/// registrations to decide which keys are deserialized and into which type.
struct StoreSeed<'a> {
    /// Values deserialized so far, keyed by each registration's `type_id`.
    store: Store,
    /// Registrations indexed by their `field` name.
    registrations: HashMap<&'static str, &'a Registration>,
}
impl<'a> StoreSeed<'a> {
fn new(registrations: &'a [Registration]) -> Self {
Self {
store: Store::with_capacity(registrations.len()),
registrations: registrations.iter().map(|r| (r.field, r)).collect(),
}
}
}
impl<'de> DeserializeSeed<'de> for StoreSeed<'_> {
    type Value = Store;

    /// Feeds `deserializer`'s map through the `Visitor` implemented on `&mut StoreSeed`, then
    /// returns the store that visit populated.
    ///
    /// # Errors
    /// Any error produced while visiting the map is returned unchanged.
    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        let mut seed = self;
        deserializer
            .deserialize_map(&mut seed)
            .map(move |()| seed.store)
    }
}
impl<'de> Visitor<'de> for &mut StoreSeed<'_> {
    type Value = ();

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("Configuration")
    }

    /// Walks every entry of the configuration map.
    ///
    /// An entry whose key matches a registration is deserialized with that registration's
    /// `deserialize_fn` and stored under its `type_id`. Every other entry's value is read and
    /// discarded.
    ///
    /// # Errors
    /// - Propagates key and value deserialization errors.
    /// - Returns `duplicate_field` when an entry resolves to a `type_id` that is already in the
    ///   store.
    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
    where
        A: MapAccess<'de>,
    {
        while let Some(key) = map.next_key::<String>()? {
            let registration = match self.registrations.get(key.as_str()) {
                Some(registration) => *registration,
                None => {
                    // `MapAccess` requires every key to be followed by a read of its value.
                    // Skipping `next_value` breaks streaming deserializers (e.g.
                    // `serde_json::from_str`), which would then parse the value as the next key.
                    map.next_value::<de::IgnoredAny>()?;
                    continue;
                }
            };
            let value = map.next_value_seed(CoolSeed(&registration.deserialize_fn))?;
            if self.store.insert(registration.type_id, value).is_some() {
                return Err(de::Error::duplicate_field(registration.field));
            }
        }
        Ok(())
    }
}
struct CoolSeed<T>(T);
impl<'de> DeserializeSeed<'de> for CoolSeed<&DeserializeFn> {
type Value = Box<dyn Any + Send + Sync>;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
let mut erased = <dyn erased::Deserializer>::erase(deserializer);
(self.0)(&mut erased).map_err(serde::de::Error::custom)
}
}
/// Process-wide configuration store: filled once by `init` and read by `__store_get`.
static STORE: OnceBox<Store> = OnceBox::new();
/// Returns the configuration value of type `T` that `init` loaded into the global store.
///
/// Support function for accessors generated per registered type; not meant to be called
/// directly.
///
/// # Panics
/// - If `init` has not completed yet.
/// - If the store has no entry for `T`, e.g. `T` was never registered or its key was absent
///   from the configuration.
/// - If the stored value is not a `T`, i.e. a registration's `deserialize_fn` produces a
///   different type than its `type_id` claims.
#[doc(hidden)]
pub fn __store_get<T: 'static>() -> &'static T {
    let store = STORE.get().expect("init not called soon enough");
    let entry = store.get(&TypeId::of::<T>()).unwrap_or_else(|| {
        panic!(
            "no configuration value loaded for `{}`",
            std::any::type_name::<T>()
        )
    });
    entry.downcast_ref::<T>().unwrap_or_else(|| {
        panic!(
            "configuration value stored for `{}` has a different type",
            std::any::type_name::<T>()
        )
    })
}
// Slice gathering every `Registration` declared with
// `#[linkme::distributed_slice(REGISTRATIONS)]` anywhere in the program (assembled by the
// `linkme` crate). `init` deserializes the configuration against these entries.
#[linkme::distributed_slice]
static REGISTRATIONS: [Registration] = [..];
/// Loads the configuration and publishes every registered value into the global store.
///
/// The configuration is currently a hard-coded JSON document. Each top-level key that matches a
/// `Registration` in `REGISTRATIONS` is deserialized into that registration's type; other keys
/// are ignored.
///
/// # Panics
/// - If the configuration fails to deserialize. The message includes the underlying error.
/// - If called more than once.
pub fn init() {
    // TODO: read config from outside xD
    let de = serde_json::json!({
        "mysql": {
            "host": "localhost:5432",
            "database": "test",
            "user": "root",
            "password": "toor",
        },
        "listen_port": 8080,
    });
    let store = StoreSeed::new(&REGISTRATIONS)
        .deserialize(de)
        .unwrap_or_else(|err| panic!("failed to deserialize configuration: {}", err));
    if STORE.set(Box::new(store)).is_err() {
        panic!("init called once too many times");
    }
}
/// Sketch of the `register!` macro this module is designed around (not defined in this file).
///
/// Invoking it for a configuration type should expand to a `REGISTRATIONS` entry plus a
/// `get()` accessor backed by `__store_get`:
///
/// ```ignore
/// #[derive(Debug, serde::Deserialize)]
/// struct Mysql {
///     host: String,
///     database: String,
///     user: String,
///     password: String,
/// }
///
/// register!("mysql", Mysql);
///
/// // vv Generated vv
///
/// #[linkme::distributed_slice(REGISTRATIONS)]
/// static MYSQL: Registration = Registration {
///     field: "mysql",
///     type_id: TypeId::of::<Mysql>(),
///     deserialize_fn: |d| Ok(Box::new(erased::deserialize::<Mysql>(d)?)),
/// };
///
/// impl Mysql {
///     fn get() -> &'static Self {
///         __store_get()
///     }
/// }
/// ```
fn __dummy() {}
#[cfg(test)]
mod tests {
    use super::*;
    use erased_serde as erased;
    use serde_json as json;

    #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
    struct MysqlConnection {
        host: String,
        database: String,
        user: String,
        password: String,
    }

    #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
    struct ListenPort(u16);

    /// Hand-written registrations for the two test configuration types.
    fn registrations() -> [Registration; 2] {
        let mysql = Registration {
            field: "mysql",
            type_id: TypeId::of::<MysqlConnection>(),
            deserialize_fn: |d| Ok(Box::new(erased::deserialize::<MysqlConnection>(d)?)),
        };
        let listen_port = Registration {
            field: "listen_port",
            type_id: TypeId::of::<ListenPort>(),
            deserialize_fn: |d| Ok(Box::new(erased::deserialize::<ListenPort>(d)?)),
        };
        [mysql, listen_port]
    }

    /// Fetches the value stored for `T` and downcasts it, panicking if either step fails.
    fn lookup<T: 'static>(store: &Store) -> &T {
        store
            .get(&TypeId::of::<T>())
            .and_then(|value| value.downcast_ref::<T>())
            .unwrap()
    }

    #[test]
    fn hello() {
        let config = json::json!({
            "mysql": {
                "host": "localhost:5432",
                "database": "test",
                "user": "root",
                "password": "toor",
            },
            "listen_port": 8080,
        });
        let registrations = registrations();
        let store = StoreSeed::new(&registrations).deserialize(config).unwrap();

        assert_eq!(store.len(), registrations.len());

        let expected_mysql = MysqlConnection {
            host: "localhost:5432".into(),
            database: "test".into(),
            user: "root".into(),
            password: "toor".into(),
        };
        assert_eq!(lookup::<MysqlConnection>(&store), &expected_mysql);
        assert_eq!(lookup::<ListenPort>(&store), &ListenPort(8080));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment