Skip to content

Instantly share code, notes, and snippets.

@a10y
Last active September 27, 2024 15:30
Show Gist options
  • Save a10y/5c073c0777e3625050ab12e5fa5fd53c to your computer and use it in GitHub Desktop.
Save a10y/5c073c0777e3625050ab12e5fa5fd53c to your computer and use it in GitHub Desktop.
Rust bit-packing/unpacking for u8/u3
/// Packs the low 3 bits of 1024 `u8` values into 384 bytes.
///
/// `input` is read as 8 rows of 128 lanes (`input[row * 128 + lane]`). For each
/// lane, the eight 3-bit values are concatenated into a 24-bit little-endian
/// bit stream (row 0 in the lowest bits) and the three resulting bytes are
/// written to `packed[lane]`, `packed[128 + lane]` and `packed[256 + lane]`.
///
/// Bits above the low three of every input value are discarded. Only the first
/// 1024 elements of `input` are read and only the first 384 elements of
/// `packed` are written; any extra elements are left untouched.
///
/// # Panics
///
/// Panics if `input` holds fewer than 1024 elements or `packed` fewer than 384.
/// The check happens before any output is written.
pub fn pack_u8_u3(input: &[u8], packed: &mut [u8]) {
    const MASK: u8 = 0b111;
    // 1024 bits per block / 8 bits per element == 128 lanes.
    const LANES: usize = 1024 / 8;
    const UNPACKED_LEN: usize = LANES * 8;
    const PACKED_LEN: usize = LANES * 3;

    assert!(
        input.len() >= UNPACKED_LEN,
        "pack_u8_u3: input must hold at least 1024 values"
    );
    assert!(
        packed.len() >= PACKED_LEN,
        "pack_u8_u3: packed must hold at least 384 bytes"
    );

    // Re-slicing to the exact sizes lets the compiler hoist the bounds checks
    // out of the loop.
    let input = &input[..UNPACKED_LEN];
    let packed = &mut packed[..PACKED_LEN];

    for lane in 0..LANES {
        // Fetch the 3-bit value stored in `row` of the current lane.
        let value = |row: usize| input[row * LANES + lane] & MASK;

        // Byte 0: a, b and the low two bits of c (the shift drops c's top bit).
        let a = value(0);
        let b = value(1);
        let c = value(2);
        packed[lane] = a | (b << 3) | (c << 6);

        // Byte 1: top bit of c, then d, e and the low bit of f.
        let carry = c >> 2;
        let d = value(3);
        let e = value(4);
        let f = value(5);
        packed[LANES + lane] = carry | (d << 1) | (e << 4) | (f << 7);

        // Byte 2: top two bits of f, then g and h.
        let carry = f >> 1;
        let g = value(6);
        let h = value(7);
        packed[2 * LANES + lane] = carry | (g << 2) | (h << 5);
    }
}
/// Unpacks 384 bytes of 3-bit values back into 1024 `u8` values.
///
/// This is the inverse of the 3-bit packing layout: for each of the 128 lanes,
/// the bytes `input[lane]`, `input[128 + lane]` and `input[256 + lane]` form a
/// 24-bit little-endian bit stream holding eight 3-bit values. Value `row` of a
/// lane is written to `unpacked[row * 128 + lane]`, projecting each value back
/// into its original row.
///
/// Only the first 384 elements of `input` are read and only the first 1024
/// elements of `unpacked` are written; any extra elements are left untouched.
///
/// # Panics
///
/// Panics if `input` holds fewer than 384 elements or `unpacked` fewer than
/// 1024. The check happens before any output is written.
pub fn unpack_u3_u8(input: &[u8], unpacked: &mut [u8]) {
    // 1024 bits per block / 8 bits per element == 128 lanes.
    const LANES: usize = 1024 / 8;
    const MASK: u8 = 0b111;
    const PACKED_LEN: usize = LANES * 3;
    const UNPACKED_LEN: usize = LANES * 8;

    assert!(
        input.len() >= PACKED_LEN,
        "unpack_u3_u8: input must hold at least 384 bytes"
    );
    assert!(
        unpacked.len() >= UNPACKED_LEN,
        "unpack_u3_u8: unpacked must hold at least 1024 values"
    );

    // Re-slicing to the exact sizes lets the compiler hoist the bounds checks
    // out of the loop.
    let input = &input[..PACKED_LEN];
    let unpacked = &mut unpacked[..UNPACKED_LEN];

    for lane in 0..LANES {
        // byte0 = a | (b << 3) | (c << 6)
        let byte0 = input[lane];
        // byte1 = (c >> 2) | (d << 1) | (e << 4) | (f << 7)
        let byte1 = input[LANES + lane];
        // byte2 = (f >> 1) | (g << 2) | (h << 5)
        let byte2 = input[2 * LANES + lane];

        unpacked[lane] = byte0 & MASK;
        unpacked[LANES + lane] = (byte0 >> 3) & MASK;
        // c straddles byte0/byte1: low two bits from byte0, top bit from byte1.
        unpacked[2 * LANES + lane] = ((byte0 >> 6) & MASK) | ((byte1 << 2) & MASK);
        unpacked[3 * LANES + lane] = (byte1 >> 1) & MASK;
        unpacked[4 * LANES + lane] = (byte1 >> 4) & MASK;
        // f straddles byte1/byte2: low bit from byte1, top two bits from byte2.
        unpacked[5 * LANES + lane] = (byte1 >> 7) | ((byte2 << 1) & MASK);
        unpacked[6 * LANES + lane] = (byte2 >> 2) & MASK;
        unpacked[7 * LANES + lane] = (byte2 >> 5) & MASK;
    }
}
#[cfg(test)]
mod test {
    use super::{pack_u8_u3, unpack_u3_u8};

    /// Packs `values` into a fresh 384-byte buffer and unpacks it again.
    fn roundtrip(values: &[u8]) -> Vec<u8> {
        let mut packed: Vec<u8> = vec![0; 384];
        let mut unpacked: Vec<u8> = vec![0; 1024];
        pack_u8_u3(values, &mut packed);
        unpack_u3_u8(&packed, &mut unpacked);
        unpacked
    }

    /// Round-trips the repeating pattern 0..8. Because 128 is a multiple of 8,
    /// every lane holds the same value in all eight rows here.
    #[test]
    fn test_roundtrip() {
        let values: Vec<u8> = (0u8..8u8).cycle().take(1024).collect();
        assert_eq!(values, roundtrip(&values));
    }

    /// Round-trips a pattern whose eight rows differ within each lane, so a
    /// value landing in the wrong row is detected.
    #[test]
    fn test_roundtrip_distinct_rows() {
        let values: Vec<u8> = (0u8..8)
            .flat_map(|row| (0u8..128).map(move |lane| (row + lane) % 8))
            .collect();
        assert_eq!(values, roundtrip(&values));
    }

    /// Inputs wider than 3 bits come back as just their low 3 bits.
    #[test]
    fn test_roundtrip_masks_high_bits() {
        let values: Vec<u8> = (0u8..=255).cycle().take(1024).collect();
        let expected: Vec<u8> = values.iter().map(|v| v & 0b111).collect();
        assert_eq!(expected, roundtrip(&values));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment