Skip to content

Instantly share code, notes, and snippets.

@jerry73204
Created June 3, 2020 19:09
Show Gist options
  • Save jerry73204/09a74b92635ff954fde5a47adddc780c to your computer and use it in GitHub Desktop.
Save jerry73204/09a74b92635ff954fde5a47adddc780c to your computer and use it in GitHub Desktop.
Example: serializing and deserializing tch-rs types with serde
use failure::Fallible;
use serde::{
de::Error as DeserializeError, ser::Error as SerializeError, Deserialize, Deserializer,
Serialize, Serializer,
};
use std::fs;
use tch::{Device, Kind, Tensor};
/// Loads `example.json` into an [`Example`], then prints it twice: once with
/// its `Debug` representation and once re-serialized as pretty-printed JSON.
fn main() -> Fallible<()> {
    let parsed: Example = serde_json::from_str(&fs::read_to_string("example.json")?)?;
    let rendered = serde_json::to_string_pretty(&parsed)?;

    println!("{:?}", parsed);
    println!("{}", rendered);
    Ok(())
}
/// Top-level document loaded by `main`: a tensor plus a standalone element kind
/// and device, each routed through the custom `serialize_*` / `deserialize_*`
/// helpers in this file.
#[derive(Debug, Serialize, Deserialize)]
struct Example {
    /// Tensor stored in its `TensorRepr` form.
    #[serde(serialize_with = "serialize_tensor", deserialize_with = "deserialize_tensor")]
    pub tensor: Tensor,
    /// Element kind stored as a lowercase name such as `"float"`.
    #[serde(serialize_with = "serialize_kind", deserialize_with = "deserialize_kind")]
    pub kind: Kind,
    /// Device stored as `"cpu"` or `"cuda(N)"`.
    #[serde(serialize_with = "serialize_device", deserialize_with = "deserialize_device")]
    pub device: Device,
}
/// Serde-facing mirror of a `Tensor`: its metadata plus a flat buffer of
/// element values that is reshaped to `shape` on the way back in.
#[derive(Debug, Serialize, Deserialize)]
struct TensorRepr {
    /// Gradient-tracking flag of the tensor.
    pub requires_grad: bool,
    /// Device the tensor lives on, stored as `"cpu"` or `"cuda(N)"`.
    #[serde(serialize_with = "serialize_device", deserialize_with = "deserialize_device")]
    pub device: Device,
    /// Dimension sizes, as reported by `Tensor::size`.
    pub shape: Vec<i64>,
    /// Element values, tagged with their element type.
    pub data: DataKind,
}
/// Flat element buffer of a serialized tensor, externally tagged by its
/// lowercase variant name: `"uint8"`, `"int8"`, `"int16"`, `"int"`, `"int64"`,
/// `"float"`, `"double"` or `"bool"`.
///
/// Half-precision and complex element kinds have no variant yet.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum DataKind {
    Uint8(Vec<u8>),
    Int8(Vec<i8>),
    Int16(Vec<i16>),
    Int(Vec<i32>),
    Int64(Vec<i64>),
    Float(Vec<f32>),
    Double(Vec<f64>),
    Bool(Vec<bool>),
}
/// Serializes a [`Tensor`] as a [`TensorRepr`]: its `requires_grad` flag,
/// device, shape, and its element values tagged by element kind.
///
/// # Errors
/// Fails with `the kind <Kind> is not supported yet` for element kinds that
/// have no [`DataKind`] variant (e.g. half-precision and complex kinds).
fn serialize_tensor<S>(tensor: &Tensor, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let kind = tensor.kind();
    let data = match kind {
        Kind::Uint8 => DataKind::Uint8(Vec::<u8>::from(tensor)),
        Kind::Int8 => DataKind::Int8(Vec::<i8>::from(tensor)),
        // `DataKind` has an `Int16` variant, so int16 tensors must not be
        // rejected as unsupported.
        Kind::Int16 => DataKind::Int16(Vec::<i16>::from(tensor)),
        Kind::Int => DataKind::Int(Vec::<i32>::from(tensor)),
        Kind::Int64 => DataKind::Int64(Vec::<i64>::from(tensor)),
        Kind::Float => DataKind::Float(Vec::<f32>::from(tensor)),
        Kind::Double => DataKind::Double(Vec::<f64>::from(tensor)),
        Kind::Bool => DataKind::Bool(Vec::<bool>::from(tensor)),
        _ => {
            return Err(S::Error::custom(format!(
                "the kind {:?} is not supported yet",
                kind
            )))
        }
    };

    let repr = TensorRepr {
        requires_grad: tensor.requires_grad(),
        device: tensor.device(),
        shape: tensor.size(),
        data,
    };
    repr.serialize(serializer)
}
/// Deserializes a [`Tensor`] from its [`TensorRepr`] form.
///
/// The flat `data` buffer is turned into a 1-D tensor and viewed as `shape`.
/// The tensor is then moved to `device`, and `requires_grad` is applied last.
///
/// # Errors
/// Returns a deserialization error when:
/// - `requires_grad` is `true` for a non-floating-point element kind;
/// - the number of elements in `data` cannot be viewed as `shape`.
fn deserialize_tensor<'de, D>(deserializer: D) -> Result<Tensor, D::Error>
where
    D: Deserializer<'de>,
{
    let TensorRepr {
        requires_grad,
        device,
        shape,
        data,
    } = Deserialize::deserialize(deserializer)?;

    // Only floating-point tensors can track gradients. Reject the flag up front
    // with a proper error instead of letting libtorch fail on it.
    let is_floating = matches!(data, DataKind::Float(_) | DataKind::Double(_));
    if requires_grad && !is_floating {
        return Err(D::Error::custom(
            "requires_grad is only supported for floating point tensors",
        ));
    }

    // Exhaustive on purpose: adding a `DataKind` variant should be a compile error here.
    let flat = match data {
        DataKind::Uint8(v) => Tensor::of_slice(&v),
        DataKind::Int8(v) => Tensor::of_slice(&v),
        DataKind::Int16(v) => Tensor::of_slice(&v),
        DataKind::Int(v) => Tensor::of_slice(&v),
        DataKind::Int64(v) => Tensor::of_slice(&v),
        DataKind::Float(v) => Tensor::of_slice(&v),
        DataKind::Double(v) => Tensor::of_slice(&v),
        DataKind::Bool(v) => Tensor::of_slice(&v),
    };

    // `view` panics when `shape` does not match the element count. Input data
    // is untrusted, so use the fallible variant and surface a proper error.
    let shaped = flat
        .f_view(shape.as_slice())
        .map_err(D::Error::custom)?;

    // Move to the target device *before* enabling gradients, so the result is a
    // leaf tensor on `device`. In the other order, a non-CPU `device` would yield
    // the non-leaf output of a device copy.
    let tensor = shaped.to_device(device).set_requires_grad(requires_grad);
    Ok(tensor)
}
/// Serializes a [`Device`] as a plain string: `"cpu"` for the CPU and
/// `"cuda(N)"` for CUDA device index `N`.
fn serialize_device<S>(device: &Device, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    match device {
        Device::Cpu => serializer.serialize_str("cpu"),
        Device::Cuda(index) => serializer.serialize_str(&format!("cuda({})", index)),
    }
}
/// Deserializes a [`Device`] from the string form written by `serialize_device`:
/// `"cpu"`, or `"cuda(N)"` where `N` parses as a `usize`.
///
/// # Errors
/// Returns `invalid device name <text>` for any other string. This includes a
/// `cuda(...)` wrapper whose contents are not a valid index.
fn deserialize_device<'de, D>(deserializer: D) -> Result<Device, D::Error>
where
    D: Deserializer<'de>,
{
    const CUDA_PREFIX: &str = "cuda(";
    const CUDA_SUFFIX: &str = ")";

    let text = String::deserialize(deserializer)?;
    // One message for every rejection path; previously the non-`cuda(...)`
    // path returned an empty error string.
    let invalid = || D::Error::custom(format!("invalid device name {}", text));

    if text == "cpu" {
        return Ok(Device::Cpu);
    }
    if !(text.starts_with(CUDA_PREFIX) && text.ends_with(CUDA_SUFFIX)) {
        return Err(invalid());
    }

    // Both cut points lie on char boundaries because they come from ASCII
    // prefix/suffix matches. The prefix ends in '(', so it cannot overlap the
    // ')' suffix, and the range is never inverted.
    let index: usize = text[CUDA_PREFIX.len()..text.len() - CUDA_SUFFIX.len()]
        .parse()
        .map_err(|_| invalid())?;
    Ok(Device::Cuda(index))
}
fn serialize_kind<S>(kind: &Kind, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use Kind::*;
let text = match kind {
Uint8 => "uint8",
Int8 => "int8",
Int16 => "int16",
Int => "int",
Int64 => "int64",
Half => "half",
Float => "float",
Double => "double",
ComplexHalf => "complex_half",
ComplexFloat => "complex_float",
ComplexDouble => "complex_double",
Bool => "bool",
};
text.serialize(serializer)
}
fn deserialize_kind<'de, D>(deserializer: D) -> Result<Kind, D::Error>
where
D: Deserializer<'de>,
{
use Kind::*;
let text = String::deserialize(deserializer)?;
let kind = match text.as_str() {
"uint8" => Uint8,
"int8" => Int8,
"int16" => Int16,
"int" => Int,
"int64" => Int64,
"half" => Half,
"float" => Float,
"double" => Double,
"complex_half" => ComplexHalf,
"complex_float" => ComplexFloat,
"complex_double" => ComplexDouble,
"bool" => Bool,
_ => return Err(D::Error::custom(format!(r#"invalid kind "{}""#, text))),
};
Ok(kind)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment