Created
June 26, 2019 11:47
-
-
Save andreastt/8707dbc69b4bad8392892c1c30b193b1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#![allow(dead_code)]
#![allow(unused_imports)]
use serde::de::{self, SeqAccess, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_repr::{Serialize_repr, Deserialize_repr};
use serde_json::{json, Map, Value};
use std::borrow::Cow;
use std::fmt;
use std::marker::PhantomData;
use crate::MarionetteError;
use crate::webdriver;
use crate::webdriver::Command::*;
#[derive(Clone, Debug, PartialEq, Serialize_repr, Deserialize_repr)] | |
#[repr(u8)] | |
enum MessageDirection { | |
Incoming = 0, | |
Outgoing = 1, | |
} | |
type MessageId = u32; | |
type Payload = Map<String, Value>; | |
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] | |
#[serde(untagged)] | |
enum Command { | |
WebDriver(webdriver::Command), | |
} | |
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] | |
#[serde(untagged)] | |
enum Params { | |
Locator(webdriver::Locator), | |
None, | |
} | |
// TODO(ato): Merge with MessageDirection? (Maybe not possible?) | |
#[derive(Debug, PartialEq)] | |
enum Message { | |
Incoming(Request), | |
Outgoing(Response), | |
} | |
// TODO(ato): use a single Command(Params) here? | |
#[derive(Debug, PartialEq)] | |
struct Request { | |
id: MessageId, | |
name: Command, | |
params: Params, | |
} | |
impl Serialize for Request { | |
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | |
where | |
S: Serializer, | |
{ | |
(MessageDirection::Incoming, self.id, &self.name, &self.params).serialize(serializer) | |
} | |
} | |
#[derive(Debug, PartialEq)] | |
enum Response { | |
Result { | |
id: MessageId, | |
result: Payload, | |
}, | |
Error { | |
id: MessageId, | |
error: MarionetteError, | |
stacktrace: Payload, | |
}, | |
} | |
impl Serialize for Response { | |
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | |
where | |
S: Serializer, | |
{ | |
match self { | |
Response::Result { id, result } => (MessageDirection::Outgoing, id, Value::Null, &result).serialize(serializer), | |
Response::Error { id, error, stacktrace } => (MessageDirection::Outgoing, id, &error, &stacktrace).serialize(serializer), | |
} | |
} | |
} | |
impl Serialize for Message { | |
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | |
where | |
S: Serializer, | |
{ | |
match self { | |
Message::Incoming(ref req) => req.serialize(serializer), | |
Message::Outgoing(ref resp) => resp.serialize(serializer), | |
} | |
} | |
} | |
struct MessageVisitor; | |
// TODO(ato): as above, separate out for each type: | |
impl<'de> Visitor<'de> for MessageVisitor { | |
type Value = Message; | |
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { | |
formatter.write_str("four-element array") | |
} | |
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> { | |
let direction = seq.next_element::<MessageDirection>()? | |
.ok_or_else(|| de::Error::invalid_length(0, &self))?; | |
let msg = match direction { | |
MessageDirection::Incoming => { | |
let id: MessageId = seq.next_element()? | |
.ok_or_else(|| de::Error::invalid_length(1, &self))?; | |
let name: Command = seq.next_element()? | |
.ok_or_else(|| de::Error::invalid_length(2, &self))?; | |
let params: Params = match name { | |
Command::WebDriver(ref cmd) => match cmd { | |
FindElement => Params::Locator(seq.next_element::<webdriver::Locator>()?.unwrap()), | |
GetTimeouts => Params::None, | |
}, | |
}; | |
Message::Incoming(Request { id, name, params }) | |
} | |
MessageDirection::Outgoing => { | |
let id: MessageId = seq.next_element()? | |
.ok_or_else(|| de::Error::invalid_length(1, &self))?; | |
let maybe_error: Option<MarionetteError> = seq.next_element()? | |
.ok_or_else(|| de::Error::invalid_length(2, &self))?; | |
let response = if let Some(error) = maybe_error { | |
let stacktrace: Payload = seq.next_element()? | |
.ok_or_else(|| de::Error::invalid_length(3, &self))?; | |
Response::Error { id, error, stacktrace } | |
} else { | |
let result: Payload = seq.next_element()? | |
.ok_or_else(|| de::Error::invalid_length(3, &self))?; | |
Response::Result { id, result } | |
}; | |
Message::Outgoing(response) | |
} | |
}; | |
dbg!(&msg); | |
Ok(msg) | |
} | |
} | |
impl<'de> Deserialize<'de> for Message { | |
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | |
where | |
D: Deserializer<'de>, | |
{ | |
deserializer.deserialize_seq(MessageVisitor {}) | |
} | |
} | |
fn main() { | |
let json = r#"[1, 42, "no such element", {"foo":"bar"}]"#; | |
let msg: Message = serde_json::from_str(json).unwrap(); | |
dbg!(msg); | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::check_serialize_deserialize;
    use crate::webdriver::Command::*;

    /// A request round-trips through its four-element array form.
    #[test]
    fn test_incoming() {
        let json = r#"[0,42,"WebDriver:FindElement",{"using":"css selector","value":"body"}]"#;
        let request = Request {
            id: 42,
            name: Command::WebDriver(FindElement),
            params: Params::Locator(webdriver::Locator {
                using: webdriver::Selector::CSS,
                value: "body".into(),
            }),
        };
        let msg = Message::Incoming(request);
        check_serialize_deserialize(&json, &msg);
    }

    /// An error response round-trips through its four-element array form.
    #[test]
    fn test_outgoing() {
        let json = r#"[1,42,"no such element",{}]"#;
        let msg = Message::Outgoing(Response::Error {
            id: 42,
            error: MarionetteError::NoSuchElement,
            stacktrace: Map::new(),
        });
        check_serialize_deserialize(&json, &msg);
    }

    /// Direction tags other than 0 and 1 are rejected.
    #[test]
    fn test_invalid_type() {
        assert!(serde_json::from_str::<Message>(r#"[2,42,"WebDriver:GetTimeouts",{}]"#).is_err());
        assert!(serde_json::from_str::<Message>(r#"[3,42,"no such element",{}]"#).is_err());
    }

    /// Truncated arrays are rejected.
    #[test]
    fn test_missing_fields() {
        // All fields are required.  Valid direction tags are used on
        // purpose so each failure is caused by the missing element and not
        // by an unknown direction.
        assert!(serde_json::from_str::<Message>(r#"[1,42,"no such element"]"#).is_err());
        assert!(serde_json::from_str::<Message>(r#"[1,42]"#).is_err());
        assert!(serde_json::from_str::<Message>(r#"[1]"#).is_err());
        assert!(serde_json::from_str::<Message>(r#"[0,42]"#).is_err());
        assert!(serde_json::from_str::<Message>(r#"[0]"#).is_err());
        assert!(serde_json::from_str::<Message>(r#"[]"#).is_err());
    }

    /// Command names that are not known are rejected.
    #[test]
    fn test_unknown_command() {
        assert!(serde_json::from_str::<Message>(r#"[0,42,"hooba",{}]"#).is_err());
    }

    /// Error names that are not known are rejected.
    #[test]
    fn test_unknown_error() {
        assert!(serde_json::from_str::<Message>(r#"[1,42,"flooba",{}]"#).is_err());
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment