Skip to content

Instantly share code, notes, and snippets.

@andoriyu
Created June 19, 2017 20:45
Show Gist options
  • Save andoriyu/b57f329a7339c67d10e225c933c5f9cb to your computer and use it in GitHub Desktop.
Save andoriyu/b57f329a7339c67d10e225c933c5f9cb to your computer and use it in GitHub Desktop.
use errors::{AWResult, AWErrorKind};
use llsd::frames::{Frame, FrameKind};
use llsd::session::Sendable;
use llsd::session::server::Session;
use sodiumoxide::crypto::box_::{SecretKey, PublicKey};
use std::sync::{Arc, RwLock};
use system::{Handler, ServiceHub};
use system::authenticator::Authenticator;
use system::sessionstore::SessionStore;
use typemap::TypeMap;
/// Server-side state needed to answer handshake and message frames.
///
/// Generic over the session store `S`, the authenticator `A` and the
/// application handler `H`.
pub struct AngelSystem<S, A, H>
    where S: SessionStore,
          A: Authenticator,
          H: Handler
{
    /// Store of sessions, looked up by the `id` carried in each frame.
    sessions: S,
    /// Decides whether the key produced while validating an Initiate frame is accepted.
    authenticator: A,
    /// Public key supplied at construction time.
    public_key: PublicKey,
    /// Secret key handed to the session when it builds the Welcome frame.
    secret_key: SecretKey,
    /// Service hub passed to the handler on every message.
    services: ServiceHub,
    /// Application handler, shared between clones of the system.
    handler: Arc<H>,
}
impl<S: SessionStore, A: Authenticator, H: Handler> Clone for AngelSystem<S, A, H> {
fn clone(&self) -> AngelSystem<S, A, H> {
AngelSystem {
sessions: self.sessions.clone(),
authenticator: self.authenticator.clone(),
public_key: self.public_key,
secret_key: self.secret_key.clone(),
services: self.services.clone(),
handler: self.handler.clone(),
}
}
}
impl<S: SessionStore, A: Authenticator, H: Handler> AngelSystem<S, A, H> {
pub fn new(store: S,
authenticator: A,
pk: PublicKey,
sk: SecretKey,
handler: H)
-> AngelSystem<S, A, H> {
AngelSystem {
sessions: store,
authenticator: authenticator,
public_key: pk,
secret_key: sk,
services: Arc::new(RwLock::new(TypeMap::custom())),
handler: Arc::new(handler),
}
}
pub fn process(&self, req: Frame) -> AWResult<Frame> {
match req.kind {
FrameKind::Hello => self.process_hello(&req),
FrameKind::Initiate => self.process_initiate(&req),
FrameKind::Message => self.process_message(&req),
_ => unimplemented!(),
}
}
/// Handles a Hello frame: registers a new session and answers with Welcome.
///
/// # Errors
///
/// * `IncorrectState` - a session for this `frame.id` already exists.
/// * `ServerFault` - the session could not be inserted, or its write lock
///   could not be acquired.
/// * `HandshakeFailed` - the freshly inserted session could not be found
///   again, or building the Welcome frame failed.
fn process_hello(&self, frame: &Frame) -> AWResult<Frame> {
    // A Hello is only valid for a peer we have not seen yet.
    if self.sessions.find_by_pk(&frame.id).is_some() {
        fail!(AWErrorKind::IncorrectState);
    }

    // Register the session; give up right away if the store refuses it.
    let fresh = Session::new(frame.id);
    if self.sessions.insert(fresh).is_none() {
        fail!(AWErrorKind::ServerFault);
    }

    // Fetch the stored session back so the handshake runs on the shared copy.
    let session_lock = match self.sessions.find_by_pk(&frame.id) {
        Some(lock) => lock,
        None => fail!(AWErrorKind::HandshakeFailed(None)),
    };
    let mut session = match session_lock.write() {
        Ok(guard) => guard,
        Err(_) => fail!(AWErrorKind::ServerFault),
    };

    match session.make_welcome(frame, &self.secret_key) {
        Ok(welcome) => Ok(welcome),
        Err(e) => fail!(AWErrorKind::HandshakeFailed(Some(e))),
    }
}
/// Handles an Initiate frame: validates it, authenticates the client key
/// and answers with Ready.
///
/// # Errors
///
/// * `IncorrectState` - no session exists for this `frame.id`.
/// * `ServerFault` - the session's write lock could not be acquired.
/// * `HandshakeFailed` - the frame failed validation, the authenticator
///   rejected the key, or building the Ready frame failed.
fn process_initiate(&self, frame: &Frame) -> AWResult<Frame> {
    let session_lock = match self.sessions.find_by_pk(&frame.id) {
        Some(lock) => lock,
        None => fail!(AWErrorKind::IncorrectState),
    };

    let mut session = match session_lock.write() {
        Ok(guard) => guard,
        // Failed to acquire the write lock for the session.
        Err(_) => fail!(AWErrorKind::ServerFault),
    };

    let key = match session.validate_initiate(frame) {
        Some(key) => key,
        None => fail!(AWErrorKind::HandshakeFailed(None)),
    };

    if !self.authenticator.is_valid(&key) {
        fail!(AWErrorKind::HandshakeFailed(None));
    }

    match session.make_ready(frame, &key) {
        Ok(ready) => Ok(ready),
        Err(err) => fail!(AWErrorKind::HandshakeFailed(Some(err))),
    }
}
/// Handles a Message frame: decrypts it, runs the handler and wraps the
/// handler's output in a reply frame.
///
/// # Errors
///
/// * `IncorrectState` - no session exists for this `frame.id`.
/// * `ServerFault` - the session's read lock could not be acquired.
/// * `CannotDecrypt` - the session could not read the message.
/// * `BadFrame` - the reply frame could not be built.
/// * Any error returned by the handler is propagated.
fn process_message(&self, frame: &Frame) -> AWResult<Frame> {
    let session_lock = match self.sessions.find_by_pk(&frame.id) {
        None => fail!(AWErrorKind::IncorrectState),
        Some(session_lock) => session_lock,
    };

    // Keep the read guard inside this scope so it is released before the
    // handler runs; the handler receives the lock itself.
    let req = {
        let session = match session_lock.read() {
            Err(_) => fail!(AWErrorKind::ServerFault),
            Ok(session) => session,
        };
        match session.read_msg(frame) {
            None => fail!(AWErrorKind::CannotDecrypt),
            Some(req) => req.to_vec(),
        }
    };

    // `req` is already an owned buffer and is not needed afterwards, so it
    // is moved into the handler instead of being copied a second time.
    let res = self.handler
        .handle(self.services.clone(), session_lock.clone(), req)?;

    let session = match session_lock.read() {
        Err(_) => fail!(AWErrorKind::ServerFault),
        Ok(session) => session,
    };
    session
        .make_message(&res)
        .map_err(|_| AWErrorKind::BadFrame.into())
}
}
/*
* System On Tokio
*/
#[cfg(feature = "system-on-tokio")]
pub mod tokio {
use super::{AngelSystem, Handler, SessionStore, Authenticator, AWErrorKind};
use frames::Frame;
use futures::{future, Future, BoxFuture};
use std::io;
use std::sync::Arc;
use tokio_service::Service;
/// `tokio_service::Service` wrapper that runs an [`AngelSystem`] directly
/// inside `call`.
pub struct InlineService<S, A, H>
    where S: SessionStore,
          A: Authenticator,
          H: Handler
{
    /// The shared system every request is forwarded to.
    system: Arc<AngelSystem<S, A, H>>,
}
impl<S: SessionStore, A: Authenticator, H: Handler> Clone for InlineService<S, A, H> {
fn clone(&self) -> InlineService<S, A, H> {
InlineService::new(self.system.clone())
}
}
impl<S: SessionStore, A: Authenticator, H: Handler> InlineService<S, A, H> {
    /// Wraps a shared system in a service.
    ///
    /// The `Arc` is taken by value and stored as is; the caller decides
    /// whether to clone it beforehand.
    pub fn new(system: Arc<AngelSystem<S, A, H>>) -> InlineService<S, A, H> {
        // `system` is already owned here, so cloning it again would only
        // bump and immediately drop a reference count.
        InlineService { system: system }
    }
}
impl<S: SessionStore, A: Authenticator, H: Handler> Service for InlineService<S, A, H> {
type Request = Frame;
type Response = Frame;
type Error = io::Error;
type Future = BoxFuture<Self::Response, Self::Error>;
fn call(&self, req: Self::Request) -> Self::Future {
match self.system.process(req) {
Ok(res) => future::ok(res).boxed(),
Err(err) => {
match *err {
AWErrorKind::ServerFault => {
future::err(io::Error::new(io::ErrorKind::Other, err)).boxed()
}
_ => unimplemented!(),
}
}
}
}
}
}
use byteorder::{BigEndian, ByteOrder};
use bytes::{BytesMut, BufMut};
use frames::Frame;
use llsd::errors::LlsdErrorKind;
use llsd::session::client::Session as ClientSession;
use std::io;
use std::result::Result;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_io::codec::{Encoder, Decoder, Framed};
use tokio_proto::pipeline::{ServerProto, ClientProto};
use tokio_service::Service;
/// Codec that reads and writes `Frame`s prefixed with a 4-byte big-endian
/// length.
pub struct FrameCodec;
impl Decoder for FrameCodec {
    type Item = Frame;
    type Error = io::Error;

    /// Tries to pull one length-prefixed frame out of `buf`.
    ///
    /// Returns `Ok(None)` while the 4-byte length prefix or the payload it
    /// announces has not fully arrived; in that case `buf` is left
    /// untouched. Once a whole frame is buffered its bytes are removed
    /// from `buf` and parsed.
    ///
    /// # Errors
    ///
    /// Returns an error when the announced length does not fit in `usize`
    /// or when the buffered payload does not parse as a frame.
    fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<Frame>> {
        // The length prefix itself must be available first.
        if buf.len() < 4 {
            return Ok(None);
        }

        // Total size of prefix + payload; checked so that a huge prefix
        // cannot wrap around on targets where `usize` is 32 bits wide.
        // NOTE(review): there is no upper bound on the announced length;
        // consider rejecting prefixes above a protocol maximum.
        let payload_len = BigEndian::read_u32(&buf[0..4]) as usize;
        let frame_len = match payload_len.checked_add(4) {
            Some(len) => len,
            None => {
                return Err(io::Error::new(io::ErrorKind::InvalidData,
                                          "frame length prefix overflows usize"));
            }
        };
        if buf.len() < frame_len {
            return Ok(None);
        }

        // We have a whole frame. Consume those bytes from the buffer.
        let data = buf.split_to(frame_len);
        match Frame::from_slice(&data[4..]) {
            Ok(frame) => Ok(Some(frame)),
            Err(e) => {
                if *e == LlsdErrorKind::IncompleteFrame {
                    // The bytes are already consumed, so answering "need
                    // more data" would silently drop this frame. A payload
                    // that is complete by its own prefix yet too short to
                    // parse is malformed.
                    Err(io::Error::new(io::ErrorKind::InvalidData,
                                       "length-prefixed payload is not a complete frame"))
                } else {
                    Err(e.into())
                }
            }
        }
    }
}
impl Encoder for FrameCodec {
    type Item = Frame;
    type Error = io::Error;

    /// Appends `msg` to `buf` as a 4-byte big-endian length followed by
    /// the packed frame.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::InvalidInput` when the frame length cannot
    /// be represented in the 32-bit prefix.
    fn encode(&mut self, msg: Frame, buf: &mut BytesMut) -> io::Result<()> {
        let len = msg.length();
        // A plain `as u32` would silently truncate and corrupt the stream.
        if len > u32::max_value() as usize {
            return Err(io::Error::new(io::ErrorKind::InvalidInput,
                                      "frame is too large for a 32-bit length prefix"));
        }

        // Make room for the prefix and the frame in one go.
        buf.reserve(len.saturating_add(4));
        buf.put_u32::<BigEndian>(len as u32);
        msg.pack_to_buf(buf);
        Ok(())
    }
}
/// Pipelined protocol marker: both its server and client side exchange
/// `Frame`s over a transport framed with `FrameCodec`.
pub struct WhisperPipelinedProtocol;
impl<T> ServerProto<T> for WhisperPipelinedProtocol
    where T: AsyncRead + AsyncWrite + 'static
{
    type Request = Frame;
    type Response = Frame;
    type Transport = Framed<T, FrameCodec>;
    type BindTransport = io::Result<Self::Transport>;

    /// Wraps the raw I/O object in the frame codec; this never fails.
    fn bind_transport(&self, io: T) -> Self::BindTransport {
        let transport = io.framed(FrameCodec);
        Ok(transport)
    }
}
impl<T> ClientProto<T> for WhisperPipelinedProtocol
    where T: AsyncRead + AsyncWrite + 'static
{
    type Request = Frame;
    type Response = Frame;
    type Transport = Framed<T, FrameCodec>;
    type BindTransport = io::Result<Self::Transport>;

    /// Wraps the raw I/O object in the frame codec; this never fails.
    fn bind_transport(&self, io: T) -> Self::BindTransport {
        let transport = io.framed(FrameCodec);
        Ok(transport)
    }
}
/// Extension point for running the client side of the handshake over a
/// service.
pub trait ClientHandshakeHelper {
    /// Performs the handshake for `session`, reporting failures as
    /// `io::Error`.
    fn authenticate(&self, session: &mut ClientSession) -> io::Result<()>;
}
// THIS, I want this implemented for L173 in angel_system.rs
impl<T> ClientHandshakeHelper for T
where T: Service
{
fn authenticate(&self, session: &mut ClientSession) -> Result<(), io::Error> {
let hello_frame = session.make_hello();
let hello_req = self.call(hello_frame);
unimplemented!()
}
}
#[cfg(test)]
mod test {
use super::*;
use frames::FrameKind;
use sodiumoxide::crypto::box_::{gen_keypair, gen_nonce};
/// Builds a Hello frame with a fresh public key, a fresh nonce and a
/// three-byte zero payload.
fn make_frame() -> Frame {
    let (id, _secret) = gen_keypair();
    let payload = vec![0, 0, 0];
    let nonce = gen_nonce();
    Frame {
        id,
        nonce,
        kind: FrameKind::Hello,
        payload,
    }
}
// Walks the decoder through an empty buffer, a lone length prefix, a
// truncated frame, a complete frame, and a complete frame followed by the
// prefix of the next one. The buffer is reused across steps, so the order
// of the statements matters.
#[test]
fn test_decode() {
let mut buf = BytesMut::with_capacity(70);
let frame = make_frame();
let mut codec = FrameCodec {};
// Empty buffer: not even the length prefix is there, nothing is consumed.
let result = codec.decode(&mut buf);
assert_eq!(0, buf.len());
assert!(result.is_ok());
assert!(result.unwrap().is_none());
buf.put_u32::<BigEndian>(frame.length() as u32);
// Only the length prefix is buffered: the decoder must wait and keep it.
let result = codec.decode(&mut buf);
assert_eq!(4, buf.len());
assert!(result.is_ok());
assert!(result.unwrap().is_none());
frame.pack_to_buf(&mut buf);
// A copy holding only the first 30 bytes: still incomplete, left untouched.
let mut buf_partial = BytesMut::from(&buf[0..30]);
let result = codec.decode(&mut buf_partial);
assert_eq!(30, buf_partial.len());
assert!(result.is_ok());
assert!(result.unwrap().is_none());
// Prefix and payload are both present: one frame comes out, buffer drains.
let result = codec.decode(&mut buf);
assert_eq!(0, buf.len());
assert!(result.is_ok());
assert!(result.unwrap().is_some());
buf.put_u32::<BigEndian>(frame.length() as u32);
frame.pack_to_buf(&mut buf);
buf.put_u32::<BigEndian>(frame.length() as u32);
// A full frame followed by the prefix of the next one: the first frame is
// returned and the trailing 4 prefix bytes stay in the buffer.
let result = codec.decode(&mut buf);
assert_eq!(4, buf.len());
assert!(result.is_ok());
assert!(result.unwrap().is_some());
}
// Encoding must succeed and start with the frame length as a big-endian
// 32-bit prefix.
#[test]
fn test_encode() {
    let frame = make_frame();
    let mut out = BytesMut::new();
    let mut codec = FrameCodec;

    let outcome = codec.encode(frame.clone(), &mut out);
    assert!(outcome.is_ok());

    let prefix = BigEndian::read_u32(&out[0..4]) as usize;
    assert_eq!(frame.length(), prefix);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment