Created August 28, 2015 21:12
-
-
Save XMPPwocky/ef521cea192b98fa5210 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cmp::min; | |
use self::BitBufError::*; | |
/// A growable bit buffer with a single read/write cursor.
///
/// Bits are packed LSB-first: buffer bit `n` lives in bit `n % 8` of byte `n / 8`.
#[derive(Debug, Clone)]
pub struct BitBuf {
    /// Backing storage for the packed bits.
    contents: Vec<u8>,
    /// Cursor position, in bits from the start of the buffer.
    current_bit: u32,
    /// Number of valid bits; reads may not extend past this position.
    bit_count: u32,
}
/// Errors returned by `BitBuf` read and write operations.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum BitBufError {
    /// The requested bits extend past the end of the buffer.
    OutOfRoom,
    /// More than 32 bits were requested for a single value.
    BadNumberOfBits,
}

impl ::std::fmt::Display for BitBufError {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        let msg = match *self {
            BitBufError::OutOfRoom => "not enough room in the bit buffer",
            BitBufError::BadNumberOfBits => "bit count exceeds 32",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate `BitBufError` through `Box<dyn Error>` and similar.
impl ::std::error::Error for BitBufError {}
pub type BitBufResult<T> = Result<T, BitBufError>; | |
impl BitBuf { | |
pub fn new(contents: Vec<u8>, bit_count: u32) -> BitBuf { | |
assert!((bit_count / 8) as usize <= contents.len()); | |
BitBuf { | |
contents: contents, | |
current_bit: 0, | |
bit_count: bit_count | |
} | |
} | |
pub fn empty() -> BitBuf { | |
BitBuf { | |
contents: Vec::new(), | |
current_bit: 0, | |
bit_count: 0 | |
} | |
} | |
pub fn tell(&self) -> u32 { | |
self.current_bit | |
} | |
pub fn seek(&mut self, pos: u32) { | |
self.current_bit = pos; | |
} | |
pub fn read_bits_as_u32(&mut self, bit_count: u32) -> BitBufResult<u32> { | |
if bit_count > 32 { | |
return Err(BadNumberOfBits); | |
} | |
let startbit = self.current_bit; | |
let endbit = self.current_bit + bit_count; | |
if endbit > self.bit_count { | |
return Err(BitBufError::OutOfRoom); | |
} | |
let startbyte = startbit / 8; | |
let endbyte = endbit / 8 ; | |
debug_assert!(endbyte as usize <= self.contents.len()); | |
let mut val = 0; | |
let mut bits_read = 0; | |
let mut skipbits = startbit % 8; | |
for i in startbyte..endbyte + 1 { | |
let byte = self.contents[i as usize] as u32; | |
let numbitsskipped = min(skipbits, 8); | |
skipbits -= numbitsskipped; | |
let truncatebit = if i == endbyte { | |
endbit % 8 | |
} else { | |
8 | |
}; | |
let maskedbyte = byte & make_bitmask(numbitsskipped, truncatebit); | |
let shiftedbyte = maskedbyte >> numbitsskipped; | |
val = val | (shiftedbyte << bits_read); | |
bits_read += truncatebit - numbitsskipped; | |
} | |
self.current_bit += bits_read; | |
debug_assert_eq!(bits_read, bit_count); | |
debug_assert!(self.current_bit <= self.bit_count); | |
Ok(val) | |
} | |
pub fn write_u32_as_bits(&mut self, val: u32, bit_count: u32) -> BitBufResult<()> { | |
if bit_count > 32 { | |
return Err(BitBufError::BadNumberOfBits); | |
} | |
let startbit = self.current_bit; | |
let endbit = self.current_bit + bit_count; | |
let startbyte = startbit / 8; | |
let endbyte = endbit / 8 ; | |
let padbytes = endbyte as usize + 1 - self.contents.len(); | |
self.contents.extend(::std::iter::repeat(0) | |
.take(padbytes)); | |
self.bit_count += padbytes as u32 * 8; | |
let mut bits_written = 0; | |
let mut skipbits = startbit % 8; | |
for i in startbyte..endbyte + 1 { | |
let numbitsskipped = min(skipbits, 8); | |
skipbits -= numbitsskipped; | |
let truncatebit = if i == endbyte { | |
endbit % 8 | |
} else { | |
8 | |
}; | |
let maskedbyte = ((val >> bits_written) << numbitsskipped) & make_bitmask(numbitsskipped, truncatebit); | |
self.contents[i as usize] = maskedbyte as u8; | |
bits_written += truncatebit - numbitsskipped; | |
} | |
self.current_bit += bits_written; | |
debug_assert_eq!(bits_written, bit_count); | |
debug_assert!(self.current_bit <= self.bit_count); | |
Ok(()) | |
} | |
pub fn read_bits_as_i32(&mut self, bit_count: u32) -> BitBufResult<i32> { | |
let val = try!(self.read_bits_as_u32(bit_count)) as i32; | |
let max_neg = 1 << (bit_count - 1); | |
if val > max_neg { | |
// it's negative (two's complement); fix sign | |
Ok(val - (2 * max_neg)) | |
} else { | |
Ok(val) | |
} | |
} | |
pub fn write_i32_as_bits(&mut self, val: i32, bit_count: u32) -> BitBufResult<()> { | |
// FIXME: there should be a check for val being too big here | |
self.write_u32_as_bits(val as u32, bit_count) | |
} | |
pub fn bytes(&self) -> &[u8] { | |
&self.contents[0..(self.bit_count / 8) as usize] | |
} | |
} | |
/// Returns a mask with bits `startbit..endbit` set (bit 0 is the LSB, `endbit` is
/// exclusive) and all other bits clear. The result is 0 when `endbit <= startbit`.
///
/// Bounds of 32 or more are handled explicitly, so `make_bitmask(0, 32)` is
/// `0xFFFF_FFFF`. A plain `<<` by 32 or more would panic in debug builds and wrap
/// the shift amount in release builds.
fn make_bitmask(startbit: u32, endbit: u32) -> u32 {
    // All bits at positions >= startbit.
    let from_start = 0xFFFF_FFFFu32.checked_shl(startbit).unwrap_or(0);
    // All bits at positions < endbit.
    let below_end = !(0xFFFF_FFFFu32.checked_shl(endbit).unwrap_or(0));
    from_start & below_end
}
#[test]
fn smoke_write() {
    let mut buf = BitBuf::empty();
    // 89 in 19 bits, followed by 42 in 12 bits, packed LSB-first.
    for &(value, width) in [(89, 19), (42, 12)].iter() {
        buf.write_u32_as_bits(value, width).unwrap();
    }
    let expected = [0x59u8, 0x00, 0x50, 0x01];
    assert_eq!(buf.bytes(), expected);
}
#[test]
fn smoke_read() {
    // Same byte pattern that `smoke_write` produces: 89 in 19 bits, then 42 in 12 bits.
    let mut buf = BitBuf::new(vec![0x59, 0x00, 0x50, 0x01], 31);
    for &(width, expected) in [(19, 89), (12, 42)].iter() {
        assert_eq!(buf.read_bits_as_u32(width), Ok(expected));
    }
}
#[test]
fn roundtrip() {
    let mut buf = BitBuf::empty();
    // Four fields of widths 23 + 1 + 8 + 6 bits; the 8-bit one is signed.
    let leading_unsigned: [(u32, u32); 2] = [(532, 23), (1, 1)];
    for &(value, width) in leading_unsigned.iter() {
        buf.write_u32_as_bits(value, width).unwrap();
    }
    buf.write_i32_as_bits(-98, 8).unwrap();
    buf.write_u32_as_bits(9, 6).unwrap();

    // Rewind, then read every field back with the same widths and signedness.
    buf.seek(0);
    for &(value, width) in leading_unsigned.iter() {
        assert_eq!(buf.read_bits_as_u32(width), Ok(value));
    }
    assert_eq!(buf.read_bits_as_i32(8), Ok(-98));
    assert_eq!(buf.read_bits_as_u32(6), Ok(9));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.