Created
January 26, 2021 21:30
-
-
Save benaubin/698f7a59fd2f8a8db72167ed13f4cf6f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{borrow::Cow, sync::{Arc, Weak}}; | |
use std::io::{Read, Write, IoSlice, IoSliceMut}; | |
/// Size in bytes (8 KiB) of the fixed frame buffer.
///
/// NOTE(review): nothing references this constant yet; the reader spells the same size as
/// `8 * 1024` inline, so the two must be kept in sync by hand.
const FRAME_BUFFER_MAX_LEN: usize = 8_192;
pub struct FrameChannel { | |
queue: crossbeam::queue::SegQueue<Vec<u8>>, | |
waker: mio::Waker | |
} | |
impl FrameChannel { | |
pub fn new<T>(stream: T, waker: mio::Waker) -> (Arc<FrameChannel>, FramedIO<T>) { | |
let channel = Arc::new(FrameChannel { | |
queue: crossbeam::queue::SegQueue::new(), | |
waker | |
}); | |
( | |
channel, | |
FramedIO { | |
stream, | |
channel: Arc::downgrade(channel), | |
out_in_progress_frame: None, | |
out_bytes_written: 0, | |
read_state: FrameReadState::HeaderRead { header_buf: [0; 4], header_bytes_read: 0 } | |
} | |
) | |
} | |
pub fn send(&self, frame: Vec<u8>) { | |
self.queue.push(frame); | |
self.waker.wake().expect("failed to wake writer"); | |
} | |
} | |
impl Drop for FrameChannel { | |
fn drop(&mut self) { | |
self.waker.wake().expect("failed to wake writer") | |
} | |
} | |
struct FramedIO<T> { | |
stream: T, | |
channel: Weak<FrameChannel>, | |
out_bytes_written: usize, | |
out_in_progress_frame: Option<Vec<u8>>, | |
read_state: FrameReadState | |
} | |
enum FrameReadState { | |
HeaderRead { | |
header_buf: [u8; 4], | |
header_bytes_read: usize | |
}, | |
BufferedFrameRead { | |
buf: [u8; 8 * 1024], | |
frame_start: usize, | |
frame_bytes_read: usize, | |
frame_len: usize, | |
}, | |
LargeFrameRead { | |
frame: Vec<u8>, | |
frame_bytes_read: usize | |
} | |
} | |
impl<T: Read> FramedIO<T> { | |
/// Attempt to read the next frame from the stream. Returns Ok(None) if no more bytes are available | |
fn read(&mut self) -> std::io::Result<Option<Cow<[u8]>>> { | |
loop { | |
match self.read_state { | |
FrameReadState::HeaderRead { header_buf, ref mut header_bytes_read } => { | |
let mut overflow_buf = [0; 8 * 1024]; | |
while 4 > header_bytes_read { | |
let bytes_read = self.stream.read_vectored(&mut [ | |
IoSliceMut::new(&mut header_buf[header_bytes_read..]), | |
IoSliceMut::new(&mut overflow_buf) | |
])?; | |
if bytes_read == 0 { return Ok(None) } | |
header_bytes_read += bytes_read; // no bytes available right now | |
} | |
self.read_state = FrameReadState::BufferedFrameRead { | |
buf: overflow_buf, | |
frame_start: 0, | |
frame_bytes_read: header_bytes_read - 4, | |
frame_len: u32::from_le_bytes(header_buf) | |
}; | |
} | |
FrameReadState::BufferedFrameRead { | |
ref mut buf, | |
frame_bytes_read, | |
frame_len, | |
.. | |
} if frame_len > 8 * 1024 => { | |
// the maximum buffer size is too small to read this frame into. allocate a vector for the frame and copy the bytes into it | |
let mut frame = vec! [ 0; frame_len ]; | |
frame[..frame_bytes_read] = buf[..frame_bytes_read]; | |
self.read_state = FrameReadState::LargeFrameRead { | |
frame, | |
frame_bytes_read | |
}; | |
} | |
FrameReadState::BufferedFrameRead { | |
ref mut buf, | |
ref mut frame_start, | |
frame_bytes_read, | |
frame_len | |
} if frame_start + frame_len > buf.len() => { | |
// the current buffer is too small to read the frame into. | |
// create a new buffer so that we don't have to heap allocate a vector to hold this frame | |
let next_buf = [0; 8 * 1024]; | |
next_buf[..frame_bytes_read] = buf[(*frame_start)..(*frame_start + frame_bytes_read)]; | |
buf = next_buf; | |
frame_start = 0; | |
} | |
FrameReadState::BufferedFrameRead { | |
ref mut buf, | |
ref mut frame_start, | |
ref mut frame_bytes_read, | |
ref mut frame_len | |
} => { | |
// we only care about at the buffer after the start position of the frame | |
let buf = &mut buf[*frame_start..]; | |
// read from the stream until we've read the frame (note that we may already have read the next frame) | |
while frame_len > frame_bytes_read { | |
let bytes_read = self.stream.read(buf[frame_bytes_read..])?; | |
if bytes_read == 0 { return Ok(None) } // no bytes to read at the moment | |
frame_bytes_read += bytes_read; | |
} | |
let (frame_buf, overflow_buf) = (&mut buf[..*frame_bytes_read]).split_at_mut(frame_len); | |
self.read_state = if overflow_buf.len() > 4 { | |
// we've read the current frame and the header for the next frame | |
frame_start += frame_len + 4; | |
frame_bytes_read -= frame_len + 4; | |
frame_len = u32::from_le_bytes(overflow_buf[..4]); | |
} else { | |
// we've read the current frame, but not the complete header for the next frame | |
let mut header_buf = [0; 4]; | |
// calculate the number of header bytes read | |
let header_bytes_read = *frame_bytes_read - frame_len; | |
// copy the header bytes from the overflow buffer | |
header_buf[..header_bytes_read] = overflow_buf[..header_bytes_read]; | |
FrameReadState::HeaderRead { | |
header_buf, | |
header_bytes_read | |
}; | |
}; | |
return Ok(Some(Cow::Borrowed(frame_buf))); | |
} | |
FrameReadState::LargeFrameRead { mut frame, ref mut frame_bytes_read } => { | |
let mut header_buf = [0; 4]; | |
let mut overflow_buf = [0; 8 * 1024]; | |
while frame.len() > frame_bytes_read { | |
let unread_buf = &mut frame[*frame_bytes_read..]; | |
let read_bytes = self.stream.read_vectored(&mut [ | |
IoSliceMut::new(unread_buf), | |
IoSliceMut::new(header_buf), | |
IoSliceMut::new(overflow_buf) | |
])?; | |
if read_bytes == 0 { return Ok(None) } // no bytes to read at the moment | |
frame_bytes_read += read_bytes; | |
} | |
let overflow_bytes_read = frame_bytes_read - frame.len(); | |
self.read_state = if overflow_bytes_read >= 4 { | |
FrameReadState::BufferedFrameRead { | |
buf: overflow_buf, | |
frame_start: 0, | |
frame_bytes_read: overflow_bytes_read - 4, | |
frame_len: u32::from_le_bytes(header_buf) | |
} | |
} else { | |
FrameReadState::HeaderRead { | |
header_buf, | |
header_bytes_read: overflow_bytes_read | |
} | |
} | |
} | |
} | |
} | |
} | |
} | |
/// Why [`FramedIO::write`] stopped writing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteStatus {
    /// The stream is out of write capacity.
    WantsWrite,
    /// Out of frames to write.
    WantsFrames,
    /// The frames channel has disconnected, and no more frames will be sent.
    Disconnected,
}
impl<T: Write> FramedIO<T> { | |
/// Attempt to write the next frame into the stream. Returns Ok(None) if no frames available | |
fn write(&mut self) -> std::io::Result<WriteStatus> { | |
loop { | |
match self.out_in_progress_frame { | |
Some(ref frame) => { | |
let bufs = | |
&[ | |
IoSlice::new(&(frame.len() as u32).to_le_bytes()[self.out_bytes_written.max(4)..]), | |
IoSlice::new(frame[(self.out_bytes_written).min(4) - 4..]) | |
]; | |
match self.stream.write_vectored(bufs)? { | |
0 => return Ok(WriteStatus::WantsWrite), // can't write anymore | |
written => self.out_bytes_written += written | |
}; | |
if self.out_bytes_written == frame.len() + 4 { | |
self.out_in_progress_frame = None; // finished writing | |
self.out_bytes_written -= frame.len() + 4; | |
} | |
} | |
None => { | |
match self.channel.upgrade().map(|src| src.queue.pop() ) { | |
Some(Some(frame)) => { | |
self.out_in_progress_frame = Some(frame); | |
}, | |
Some(None) => return Ok(WriteStatus::WantsFrames), // no more frames to write | |
None => return Ok(WriteStatus::Disconnected) | |
} | |
} | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment