Created
June 3, 2020 19:09
-
-
Save jerry73204/09a74b92635ff954fde5a47adddc780c to your computer and use it in GitHub Desktop.
Example: (de)serializing tch-rs types (Tensor, Kind, Device) with serde
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use failure::Fallible; | |
use serde::{ | |
de::Error as DeserializeError, ser::Error as SerializeError, Deserialize, Deserializer, | |
Serialize, Serializer, | |
}; | |
use std::fs; | |
use tch::{Device, Kind, Tensor}; | |
fn main() -> Fallible<()> { | |
let json = fs::read_to_string("example.json")?; | |
let example: Example = serde_json::from_str(&json)?; | |
let pretty_json = serde_json::to_string_pretty(&example)?; | |
println!("{:?}", example); | |
println!("{}", pretty_json); | |
Ok(()) | |
} | |
#[derive(Debug, Serialize, Deserialize)] | |
struct Example { | |
#[serde( | |
serialize_with = "serialize_tensor", | |
deserialize_with = "deserialize_tensor" | |
)] | |
pub tensor: Tensor, | |
#[serde( | |
serialize_with = "serialize_kind", | |
deserialize_with = "deserialize_kind" | |
)] | |
pub kind: Kind, | |
#[serde( | |
serialize_with = "serialize_device", | |
deserialize_with = "deserialize_device" | |
)] | |
pub device: Device, | |
} | |
#[derive(Debug, Serialize, Deserialize)] | |
struct TensorRepr { | |
pub requires_grad: bool, | |
#[serde( | |
serialize_with = "serialize_device", | |
deserialize_with = "deserialize_device" | |
)] | |
pub device: Device, | |
pub shape: Vec<i64>, | |
pub data: DataKind, | |
} | |
#[derive(Debug, Serialize, Deserialize)] | |
enum DataKind { | |
#[serde(rename = "uint8")] | |
Uint8(Vec<u8>), | |
#[serde(rename = "int8")] | |
Int8(Vec<i8>), | |
#[serde(rename = "int16")] | |
Int16(Vec<i16>), | |
#[serde(rename = "int")] | |
Int(Vec<i32>), | |
#[serde(rename = "int64")] | |
Int64(Vec<i64>), | |
// Half(Vec<f16>), | |
#[serde(rename = "float")] | |
Float(Vec<f32>), | |
#[serde(rename = "double")] | |
Double(Vec<f64>), | |
// ComplexHalf(Vec<>), | |
// ComplexFloat(Vec<>), | |
// ComplexDouble(Vec<>), | |
#[serde(rename = "bool")] | |
Bool(Vec<bool>), | |
} | |
fn serialize_tensor<S>(tensor: &Tensor, serializer: S) -> Result<S::Ok, S::Error> | |
where | |
S: Serializer, | |
{ | |
let device = tensor.device(); | |
let requires_grad = tensor.requires_grad(); | |
let shape = tensor.size(); | |
let kind = tensor.kind(); | |
let data = match kind { | |
Kind::Uint8 => DataKind::Uint8(Vec::<u8>::from(tensor)), | |
Kind::Int8 => DataKind::Int8(Vec::<i8>::from(tensor)), | |
Kind::Int => DataKind::Int(Vec::<i32>::from(tensor)), | |
Kind::Int64 => DataKind::Int64(Vec::<i64>::from(tensor)), | |
Kind::Float => DataKind::Float(Vec::<f32>::from(tensor)), | |
Kind::Double => DataKind::Double(Vec::<f64>::from(tensor)), | |
Kind::Bool => DataKind::Bool(Vec::<bool>::from(tensor)), | |
_ => { | |
return Err(S::Error::custom(format!( | |
"the kind {:?} is not supported yet", | |
kind | |
))) | |
} | |
}; | |
let repr = TensorRepr { | |
requires_grad, | |
device, | |
shape, | |
data, | |
}; | |
repr.serialize(serializer) | |
} | |
fn deserialize_tensor<'de, D>(deserializer: D) -> Result<Tensor, D::Error> | |
where | |
D: Deserializer<'de>, | |
{ | |
let TensorRepr { | |
requires_grad, | |
device, | |
shape, | |
data, | |
} = Deserialize::deserialize(deserializer)?; | |
let tensor = match data { | |
DataKind::Uint8(v) => Tensor::of_slice(&v), | |
DataKind::Int8(v) => Tensor::of_slice(&v), | |
DataKind::Int(v) => Tensor::of_slice(&v), | |
DataKind::Int64(v) => Tensor::of_slice(&v), | |
DataKind::Float(v) => Tensor::of_slice(&v), | |
DataKind::Double(v) => Tensor::of_slice(&v), | |
DataKind::Bool(v) => Tensor::of_slice(&v), | |
_ => return Err(D::Error::custom("unimplemented")), | |
}; | |
let tensor = tensor.view(shape.as_slice()); | |
let tensor = tensor.set_requires_grad(requires_grad); | |
let tensor = tensor.to_device(device); | |
Ok(tensor) | |
} | |
fn serialize_device<S>(device: &Device, serializer: S) -> Result<S::Ok, S::Error> | |
where | |
S: Serializer, | |
{ | |
let text = match device { | |
Device::Cpu => "cpu".into(), | |
Device::Cuda(n) => format!("cuda({})", n), | |
}; | |
serializer.serialize_str(&text) | |
} | |
fn deserialize_device<'de, D>(deserializer: D) -> Result<Device, D::Error> | |
where | |
D: Deserializer<'de>, | |
{ | |
let text = String::deserialize(deserializer)?; | |
let device = match text.as_str() { | |
"cpu" => Device::Cpu, | |
_ => { | |
let prefix = "cuda("; | |
let suffix = ")"; | |
if text.starts_with(prefix) && text.ends_with(suffix) { | |
let number: usize = text[(prefix.len())..(text.len() - suffix.len())] | |
.parse() | |
.map_err(|_err| D::Error::custom(format!("invalid device name {}", text)))?; | |
Device::Cuda(number) | |
} else { | |
return Err(D::Error::custom("")); | |
} | |
} | |
}; | |
Ok(device) | |
} | |
fn serialize_kind<S>(kind: &Kind, serializer: S) -> Result<S::Ok, S::Error> | |
where | |
S: Serializer, | |
{ | |
use Kind::*; | |
let text = match kind { | |
Uint8 => "uint8", | |
Int8 => "int8", | |
Int16 => "int16", | |
Int => "int", | |
Int64 => "int64", | |
Half => "half", | |
Float => "float", | |
Double => "double", | |
ComplexHalf => "complex_half", | |
ComplexFloat => "complex_float", | |
ComplexDouble => "complex_double", | |
Bool => "bool", | |
}; | |
text.serialize(serializer) | |
} | |
fn deserialize_kind<'de, D>(deserializer: D) -> Result<Kind, D::Error> | |
where | |
D: Deserializer<'de>, | |
{ | |
use Kind::*; | |
let text = String::deserialize(deserializer)?; | |
let kind = match text.as_str() { | |
"uint8" => Uint8, | |
"int8" => Int8, | |
"int16" => Int16, | |
"int" => Int, | |
"int64" => Int64, | |
"half" => Half, | |
"float" => Float, | |
"double" => Double, | |
"complex_half" => ComplexHalf, | |
"complex_float" => ComplexFloat, | |
"complex_double" => ComplexDouble, | |
"bool" => Bool, | |
_ => return Err(D::Error::custom(format!(r#"invalid kind "{}""#, text))), | |
}; | |
Ok(kind) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment