Created
January 4, 2019 19:00
-
-
Save ayende/563813f0bf9b7eb7dcc5ec2d2b83942f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate custom_error; | |
extern crate memmem; | |
extern crate openssl; | |
extern crate openssl_sys; | |
extern crate hex; | |
extern crate foreign_types_shared; | |
extern crate tokio; | |
extern crate tokio_openssl; | |
#[macro_use] extern crate lazy_static; | |
use std::io; | |
use std::io::BufReader; | |
use std::collections::HashMap; | |
use std::collections::HashSet; | |
use custom_error::custom_error; | |
use memmem::{Searcher, TwoWaySearcher}; | |
use foreign_types_shared::ForeignTypeRef; | |
use tokio::io::{write_all}; | |
use tokio::net::TcpListener; | |
use tokio::prelude::*; | |
use tokio_openssl::SslAcceptorExt; | |
use tokio::io::ReadHalf; | |
use tokio::io::lines; | |
use tokio::net::tcp::TcpStream; | |
custom_error! { | |
ConnectionError | |
AddrParseError{source: std::net::AddrParseError} = "Unable to parse address {source}", | |
Io{source: io::Error} = "unable to read from the network", | |
Utf8{source: std::str::Utf8Error} = "Invalid UTF8 character sequence", | |
Parse{origin: String} = "Unable to parse command: {origin}", | |
MessageTooBig = "Message length was over 8KB", | |
SslIssue{source : openssl::error::ErrorStack} = "OpenSSL error {source}", | |
Handshake{source: openssl::ssl::HandshakeError<std::net::TcpStream>} = "Handshake error {source}", | |
InvalidTimeFormat = "Unable to understand certificate time", | |
ClientCertExpired{date: String} = "The client certificate has already expired: {date}", | |
ClientCertNotYetValid{date: String} = "The client certificate is not yet valid: {date}" | |
} | |
impl ConnectionError { | |
fn parsing(origin: &str) -> ConnectionError { | |
ConnectionError::Parse{ origin: origin.to_string() } | |
} | |
} | |
/// A parsed command message. Every slice borrows from the raw message text,
/// so a `Cmd` cannot outlive the string it was parsed from.
struct Cmd<'a> {
    /// Words of the command line; `args[0]` is the command name that
    /// `dispatch_cmd` acts on.
    args: Vec<&'a str>,
    /// Header name -> header value, from the `name: value` lines that follow
    /// the command line.
    headers: HashMap<&'a str, &'a str>,
}
lazy_static! { | |
static ref msg_break : TwoWaySearcher<'static> = { | |
TwoWaySearcher::new(b"\r\n\r\n") | |
}; | |
} | |
struct Server { | |
tls_config: openssl::ssl::SslAcceptor, | |
allowed_certs_thumbprints: HashSet<String> | |
} | |
impl Server { | |
fn new(cert_path: &str, key_path: &str, allowed_certs_thumbprints: &[&str]) -> Result<Server, ConnectionError> { | |
let mut allowed = HashSet::new(); | |
for thumbprint in allowed_certs_thumbprints { | |
allowed.insert(thumbprint.to_lowercase()); | |
} | |
let mut sslb = openssl::ssl::SslAcceptor::mozilla_modern(openssl::ssl::SslMethod::tls())?; | |
sslb.set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)?; | |
sslb.set_certificate_chain_file(cert_path)?; | |
sslb.check_private_key()?; | |
// accept all certificates, we'll do our own validation on them | |
sslb.set_verify_callback(openssl::ssl::SslVerifyMode::PEER, |_, _| true); | |
let server = Server { | |
tls_config: sslb.build(), | |
allowed_certs_thumbprints: allowed | |
}; | |
Ok(server) | |
} | |
} | |
fn parse_cmd<'a>(cmd_str: &'a str) -> Result<Cmd, ConnectionError> { | |
let mut lines = cmd_str.lines(); | |
let cmd_line = match lines.next() { | |
None => { | |
return Err(ConnectionError::parsing(cmd_str)); | |
} | |
Some(v) => v, | |
}; | |
let mut cmd = Cmd { | |
args: cmd_line.split(' ').collect(), | |
headers: HashMap::new(), | |
}; | |
for line in lines { | |
let parts: Vec<&str> = line.splitn(2, ':').collect(); | |
if parts.len() != 2 { | |
return Err(ConnectionError::parsing(line)); | |
} | |
cmd.headers.insert(parts[0].trim(), parts[1].trim()); | |
} | |
Ok(cmd) | |
} | |
fn read_full_message<'a>(stream: &mut io::Read, buffer: &'a mut Vec<u8>) -> Result<&'a [u8], ConnectionError> { | |
let mut to_scan = 0; | |
let mut tmp_buf = [0; 256]; | |
loop { | |
match msg_break.search_in(&buffer[to_scan..]) { | |
None => to_scan = if buffer.len() > 3 { buffer.len() - 3} else { 0 }, | |
Some(msg_end) => return Ok(&buffer[0..(to_scan + msg_end + 4)]) | |
} | |
let read = stream.read(&mut tmp_buf)?; | |
if read + buffer.len() > 8192 { | |
return Err(ConnectionError::MessageTooBig) | |
} | |
buffer.extend_from_slice(&tmp_buf[0..read]); | |
} | |
} | |
fn dispatch_cmd<S>(stream: &mut S, cmd : Cmd) -> io::Result<()> | |
where S : io::Write + io::Read { | |
stream.write(&cmd.args[0].as_bytes())?; | |
stream.flush()?; | |
Ok(()) | |
} | |
fn authenticate_certificate(ssl: &openssl::ssl::SslRef, server: &Server) -> Result<Option<String>, ConnectionError> { | |
fn get_friendly_name(peer: &openssl::x509::X509) -> String { | |
peer.subject_name() // can't figure out how to get the real friendly name | |
.entries() | |
.last() | |
.map( |it| it.data() | |
.as_utf8() | |
.and_then(|s| Ok(s.to_string())) | |
.unwrap_or("".to_string()) | |
) | |
.unwrap_or("<Unknown>".to_string()) | |
} | |
extern "C" { | |
fn ASN1_TIME_diff( | |
pday: *mut std::os::raw::c_int, | |
psec: *mut std::os::raw::c_int, | |
from: *const openssl_sys::ASN1_TIME, | |
to: *const openssl_sys::ASN1_TIME) -> std::os::raw::c_int; | |
} | |
fn is_before(x: &openssl::asn1::Asn1TimeRef, y: &openssl::asn1::Asn1TimeRef) -> Result<bool, ConnectionError> { | |
unsafe { | |
let mut day : std::os::raw::c_int = 0; | |
let mut sec : std::os::raw::c_int = 0; | |
match ASN1_TIME_diff(&mut day, &mut sec, x.as_ptr(), y.as_ptr() ) { | |
0 => Err(ConnectionError::InvalidTimeFormat), | |
_ => Ok(day > 0 || sec > 0) | |
} | |
} | |
} | |
fn is_valid_time(peer: &openssl::x509::X509) -> Result<(), ConnectionError> { | |
let now = openssl::asn1::Asn1Time::days_from_now(0)?; | |
if is_before(&now, peer.not_before())? { | |
return Err(ConnectionError::ClientCertNotYetValid { date: peer.not_before().to_string() }); | |
} | |
if is_before(peer.not_after(), &now)? { | |
return Err(ConnectionError::ClientCertExpired { date: peer.not_after().to_string() } ); | |
} | |
Ok(()) | |
} | |
match ssl.peer_certificate() { | |
None => { | |
return Ok(Some("ERR No certificate was provided\r\n".to_string())); | |
} | |
Some(peer) => { | |
let thumbprint = hex::encode(peer.digest(openssl::hash::MessageDigest::sha1())?); | |
if server.allowed_certs_thumbprints.contains(&thumbprint) == false { | |
let msg = format!("ERR certificate ({}) thumbprint '{}' is unknown\r\n", | |
get_friendly_name(&peer), | |
thumbprint); | |
return Ok(Some(msg)); | |
} | |
if let Err(e) = is_valid_time(&peer) { | |
let msg = format!("ERR certificate ({}) thumbprint '{}' cannot be used: {}\r\n", | |
get_friendly_name(&peer), | |
thumbprint, | |
e); | |
return Ok(Some(msg)); | |
} | |
} | |
}; | |
return Ok(None); | |
} | |
fn main() -> Result<(), ConnectionError> { | |
let server = Server::new( | |
"C:\\Work\\temp\\example-com.cert.pem", | |
"C:\\Work\\temp\\example-com.key.pem", | |
// allowed thumprints | |
&["1776821db1002b0e2a9b4ee3d5ee14133d367009"] | |
)?; | |
let listener = TcpListener::bind(&"127.0.0.1:4888".parse::<std::net::SocketAddr>()?)?; | |
println!("Started"); | |
let accept_connections = listener.incoming() | |
.for_each(move |socket|{ | |
let write = write_all(socket, b"OK\r\n"); | |
tokio::spawn(write.map(|_|()).map_err(|_|())) | |
}); | |
let accept_connections = listener.incoming() | |
.map_err(|e| eprintln!("Failed to accept {}", e) ) | |
.for_each(move |socket| { | |
let handshake =server.tls_config.accept_async(socket); | |
let auth = handshake.map(|stream|{ | |
let (reader, writer) = stream.split(); | |
let msg = lines(BufReader::new(reader)) | |
.fold(String::new(), |mut buffered, line| { | |
match line.len() { | |
0 if buffered.len() == 0 => Ok(buffered), // should we error? | |
0 => { | |
println!("{}", buffered); | |
//let cmd = parse_cmd(&buffered); | |
//dispatch here | |
Ok(String::new()) | |
}, | |
line_len if buffered.len() + line_len > 8192 => { | |
Err(io::Error::new(io::ErrorKind::Other, ConnectionError::MessageTooBig)) | |
}, | |
_ => { | |
buffered.push_str(&line); | |
Ok(buffered) | |
} | |
} | |
}) | |
.map_err(|e| ()) | |
.map(|_| ()); | |
tokio::spawn(msg) | |
}); | |
let connection = auth.map_err(|e| ()).map(|_| ()); | |
tokio::spawn(connection) | |
}); | |
tokio::run(accept_connections); | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment