Skip to content

Instantly share code, notes, and snippets.

@caibear
Last active September 17, 2024 17:40
Show Gist options
  • Save caibear/68a9faf3f9a4ce94f321c7e06d2ea0ca to your computer and use it in GitHub Desktop.
bitcode packet packer
[package]
name = "bitcode_packet_packer"
version = "0.1.0"
edition = "2021"
[dependencies]
bitcode = "0.6.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
lz4_flex = { version = "0.11.2", default-features = false }
use std::time::Instant;
use rand::prelude::*;
use rand::distributions::{Standard};
/// Example message type whose variants span a wide range of encoded sizes,
/// used as the payload for all packing strategies below.
#[derive(bitcode::Encode, bitcode::Decode)]
enum Message {
    /// Tiny variant: a single bool.
    A(bool),
    /// Small variant: a single byte.
    B(u8),
    /// Larger, variable-size variant: a name string plus two coordinates.
    C {
        name: String,
        x: i16,
        y: i16,
    }
}
/// Random `Message` generator with a skewed variant mix:
/// ~50% `A`, then `gen_bool(0.4)` of the remainder is `B` (~20%), rest `C` (~30%).
impl Distribution<Message> for Standard {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Message {
        if rng.gen_bool(0.5) {
            Message::A(rng.gen_bool(0.1))
        } else if rng.gen_bool(0.4) {
            Message::B(rng.gen_range(9..15))
        } else {
            Message::C {
                name: if rng.gen_bool(0.0001) {
                    // Throw a curveball of an incompressible string larger than a single packet.
                    let n = rng.gen_range(1300..2000);
                    (0..n).map(|_| rng.gen_range(b'0'..=b'9') as char).collect()
                } else {
                    // Common case: a short, highly compressible name from a small fixed set.
                    ["cow", "sheep", "zombie", "skeleton", "spider", "creeper", "parrot", "bee"].choose(rng).unwrap().to_string()
                },
                x: rng.gen_range(-100..100),
                y: rng.gen_range(0..15),
            }
        }
    }
}
/// A single outgoing packet: an opaque byte payload.
struct Packet(Vec<u8>);
impl Packet {
    /// Maximum payload size before a packet must be split into extension
    /// packets (counted as `len.saturating_sub(1) / MAX_SIZE` in the stats).
    const MAX_SIZE: usize = 1200;
}
/// Runs every packing strategy against the same seeded message stream and
/// prints per-strategy packet sizes and totals for comparison.
fn main() {
    println!("\n\npack_naive\n");
    run_packer(pack_naive);
    println!("\n\npack_appended\n");
    run_packer(pack_appended);
    println!("\n\npack_exponential_search\n");
    run_packer(pack_exponential_search);
    println!("\n\npack_interpolation_search\n");
    run_packer(pack_interpolation_search);
    // Fixed: this header was "\n\pack_multiple_exponential\n" — `\p` is an
    // invalid escape sequence in Rust string literals and fails to compile.
    println!("\n\npack_multiple_exponential\n");
    run_packer(pack_multiple_exponential);
}
/// Feeds a packing function 10 batches of randomly generated messages (from a
/// fixed seed, so every strategy sees identical input) and prints the packet
/// lengths per batch plus aggregate byte/packet/fragment counts and timing.
fn run_packer(mut packer: impl FnMut(&[Message]) -> Vec<Packet>) {
    // Deterministic RNG: all strategies are benchmarked on the same stream.
    let mut rng = rand_chacha::ChaCha20Rng::from_seed(Default::default());
    let mut total_bytes = 0usize;
    let mut packet_count = 0usize;
    let mut packet_extensions = 0usize;
    let start = Instant::now();
    // Run 10 time steps, each with a random batch size.
    for _ in 0..10 {
        let batch_size = rng.gen_range(200..20000);
        let batch: Vec<Message> = (0..batch_size).map(|_| rng.gen()).collect();
        let produced = packer(&batch);
        let packet_lens: Vec<_> = produced.iter().map(|packet| packet.0.len()).collect();
        packet_count += produced.len();
        for &len in &packet_lens {
            total_bytes += len;
            // A packet over MAX_SIZE needs (len-1)/MAX_SIZE extra fragments.
            packet_extensions += len.saturating_sub(1) / Packet::MAX_SIZE;
        }
        println!("{packet_lens:?}");
    }
    let elapsed = start.elapsed();
    println!("\n{total_bytes} bytes total, {packet_count} packets, {packet_extensions} extension packets, took {elapsed:?}");
}
/// Bitcode-encodes `t`, then LZ4-compresses the result (length-prefixed).
///
/// Flip `COMPRESS` to benchmark raw encoding instead; per the original note,
/// compression makes pack_interpolation_search take 33% fewer packets without
/// reducing speed at all.
fn encode_compressed(t: &(impl bitcode::Encode + ?Sized)) -> Vec<u8> {
    const COMPRESS: bool = true;
    let encoded = bitcode::encode(t);
    if !COMPRESS {
        return encoded;
    }
    lz4_flex::compress_prepend_size(&encoded)
}
/// Baseline strategy: encode the entire batch as one (possibly oversized)
/// packet, relying on extension packets for anything beyond MAX_SIZE.
fn pack_naive(messages: &[Message]) -> Vec<Packet> {
    let whole_batch = Packet(encode_compressed(messages));
    vec![whole_batch]
}
/// Packs by encoding each message individually and appending it to the current
/// packet, flushing whenever the next message would overflow `Packet::MAX_SIZE`.
///
/// Note: individually encoded messages are not compressed, and a single
/// message larger than MAX_SIZE still becomes one oversized packet.
fn pack_appended(messages: &[Message]) -> Vec<Packet> {
    let mut bytes = vec![];
    let mut packets = vec![];
    for m in messages {
        // Don't use encode_compressed since compression doesn't improve tiny messages.
        let encoded = bitcode::encode(m);
        // Flush the current packet when this message would overflow it.
        // Fixed: skip the flush while the buffer is empty — previously a single
        // message larger than MAX_SIZE arriving on an empty buffer pushed an
        // empty Packet before it.
        if !bytes.is_empty() && bytes.len() + encoded.len() > Packet::MAX_SIZE {
            packets.push(Packet(std::mem::take(&mut bytes)));
        }
        bytes.extend_from_slice(&encoded);
    }
    // Emit the final partially filled packet, if any.
    if !bytes.is_empty() {
        packets.push(Packet(bytes));
    }
    packets
}
/// Packs messages by exponential search: try chunks of 1, 2, 4, ... messages,
/// re-encoding the growing prefix each time, until the compressed chunk no
/// longer fits in `Packet::MAX_SIZE`, then emit the last chunk that did fit
/// (or the oversized chunk itself when even a single message is too big).
fn pack_exponential_search(mut messages: &[Message]) -> Vec<Packet> {
    let mut packets = vec![];
    // Number of messages to try in the current chunk.
    let mut n = 1;
    // Most recent (encoded, n) attempt that still fit within MAX_SIZE.
    let mut last = None;
    loop {
        n = n.min(messages.len());
        let chunk = &messages[..n];
        let encoded = encode_compressed(chunk);
        let current = (encoded, n);
        if current.0.len() < Packet::MAX_SIZE && n < messages.len() {
            // Fits and there are more messages: remember this attempt and double.
            last = Some(current);
            n *= 2;
            continue;
        }
        // Reset the outer search size for the next packet before `n` is shadowed.
        n = 1;
        // If the current chunk is too big, use the last chunk.
        let (encoded, n) = last.take().filter(|_| current.0.len() > Packet::MAX_SIZE).unwrap_or(current);
        messages = &messages[n..];
        packets.push(Packet(encoded));
        if messages.is_empty() {
            break;
        }
    }
    packets
}
/// Packs messages by interpolation: estimate the average encoded message size
/// from the previous attempt and pick a chunk count aiming at TARGET_SIZE.
/// Chunks that are too large or too small are retried (up to MAX_ATTEMPTS)
/// with a refined estimate before falling back to the largest fitting attempt.
fn pack_interpolation_search(mut messages: &[Message]) -> Vec<Packet> {
    const SAMPLE: usize = 32; // Tune based on expected message size and variance.
    const PRECISION: usize = 30; // More precision will take longer, but get closer to max packet size.
    const MAX_ATTEMPTS: usize = 4; // Maximum number of attempts before giving up.
    // Aim slightly under MAX_SIZE so estimate error rarely overshoots.
    const TARGET_SIZE: usize = Packet::MAX_SIZE * PRECISION / (PRECISION + 1);
    // Chunks under MIN_SIZE are considered wastefully small and retried.
    const MIN_SIZE: usize = TARGET_SIZE * PRECISION / (PRECISION + 1);
    const DEBUG: bool = false;
    let mut packets = vec![];
    // Running estimate of bytes per message (None until the first encode).
    let mut message_size = None;
    // If we run out of attempts, send the largest attempt so far to avoid infinite loop.
    let mut attempts = 0;
    let mut largest_so_far = None;
    while !messages.is_empty() {
        // Interpolate the chunk count from the size estimate; first pass samples.
        let n = message_size.map(|message_size: f32| {
            (TARGET_SIZE as f32 / message_size).floor() as usize
        }).unwrap_or(SAMPLE);
        let n = n.clamp(1, messages.len());
        let chunk = &messages[..n];
        let encoded = encode_compressed(chunk);
        // Refine the per-message size estimate from this attempt.
        message_size = Some(encoded.len() as f32 / n as f32);
        let too_large = encoded.len() > Packet::MAX_SIZE;
        // The final chunk is allowed to be small — nothing left to add to it.
        let too_small = encoded.len() < MIN_SIZE && n != messages.len();
        let current = (encoded, n);
        let (encoded, n) = if too_large || too_small {
            if attempts < MAX_ATTEMPTS {
                if DEBUG {
                    println!("skipping {n} messages with {} bytes", current.0.len());
                }
                // Track the biggest under-full attempt as a fallback.
                if too_small && n > largest_so_far.as_ref().map_or(0, |(_, n)| *n) {
                    largest_so_far = Some(current);
                }
                attempts += 1;
                continue;
            }
            // We ran out of attempts, if the current chunk is too big, use the largest chunk so far.
            largest_so_far.take().filter(|_| current.0.len() > Packet::MAX_SIZE).unwrap_or(current)
        } else {
            current
        };
        // A chunk was committed: reset the retry state for the next packet.
        attempts = 0;
        largest_so_far = None;
        if DEBUG {
            println!("packed {n} messages with {} bytes", encoded.len());
        }
        messages = &messages[n..];
        packets.push(Packet(encoded));
    }
    // TODO merge tiny packets (caused by single messages > Packet::MAX_SIZE)
    packets
}
/// Uses multiple exponential searches to fill a packet. Has a good worst case runtime and doesn't
/// create any extraneous extension packets.
///
/// Each packet is built from one or more length-prefixed `Section`s so that a
/// packet can hold several independently encoded (and compressed) chunks.
fn pack_multiple_exponential(mut messages: &[Message]) -> Vec<Packet> {
    /// A Vec<u8> prefixed by its length as a u32. Each [`Packet`] contains 1 or more [`Section`]s.
    struct Section(Vec<u8>);
    impl Section {
        // Total wire size: payload plus the 4-byte length prefix.
        fn len(&self) -> usize {
            self.0.len() + std::mem::size_of::<u32>()
        }
        // Appends the length prefix (little-endian u32) followed by the payload.
        fn write(&self, out: &mut Vec<u8>) {
            out.reserve(self.len());
            out.extend_from_slice(&u32::try_from(self.0.len()).unwrap().to_le_bytes()); // TODO use varint.
            out.extend_from_slice(&self.0);
        }
    }
    let mut buffer = bitcode::Buffer::new(); // TODO save between calls.
    let mut packets = vec![];
    while !messages.is_empty() {
        // Bytes still available in the packet currently being filled.
        let mut remaining = Packet::MAX_SIZE;
        let mut bytes = vec![];
        while remaining > 0 && !messages.is_empty() {
            // Exponential search restarted per section: i doubles 1, 2, 4, ...
            let mut i = 0;
            // Last (i, section) attempt that fit within `remaining`.
            let mut previous = None;
            loop {
                i = (i * 2).clamp(1, messages.len());
                const COMPRESS: bool = true;
                let b = Section(if COMPRESS {
                    lz4_flex::compress_prepend_size(&buffer.encode(&messages[..i]))
                } else {
                    buffer.encode(&messages[..i]).to_vec()
                });
                let (i, b) = if b.len() <= remaining {
                    if i == messages.len() {
                        // No more messages.
                        (i, b)
                    } else {
                        // Try to fit more.
                        previous = Some((i, b));
                        continue;
                    }
                } else if let Some((i, b)) = previous {
                    // Current failed, so use previous.
                    (i, b)
                } else {
                    assert_eq!(i, 1);
                    // 1 message doesn't fit. If starting a new packet would result in fewer
                    // fragments, flush the current packet.
                    let flush_fragments = b.len().div_ceil(Packet::MAX_SIZE) - 1;
                    let keep_fragments = (b.len() - remaining).div_ceil(Packet::MAX_SIZE);
                    if flush_fragments < keep_fragments {
                        // TODO try to fill current packet by with packets after the single large packet.
                        packets.push(Packet(std::mem::take(&mut bytes)));
                        remaining = Packet::MAX_SIZE;
                    }
                    (i, b)
                };
                // Commit: consume the packed messages from the front of the slice.
                messages = &messages[i..];
                if bytes.is_empty() && b.len() < Packet::MAX_SIZE {
                    bytes = Vec::with_capacity(Packet::MAX_SIZE); // Assume we'll fill the packet.
                }
                b.write(&mut bytes);
                if b.len() > remaining {
                    // Only reachable for a single oversized message.
                    assert_eq!(i, 1);
                    // TODO fill extension packets. We would need to know where the section ends
                    // within the packet in case previous packets are lost.
                    remaining = 0;
                } else {
                    remaining -= b.len();
                }
                break;
            }
        }
        packets.push(Packet(bytes));
    }
    packets
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment