Last active
September 27, 2024 15:30
-
-
Save a10y/5c073c0777e3625050ab12e5fa5fd53c to your computer and use it in GitHub Desktop.
Rust bit-packing/unpacking for u8/u3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Packs 1024 bytes into 384 bytes, keeping only the low 3 bits of every input byte.
///
/// `input` is read as 8 rows of 128 lanes (`input[128 * row + lane]`). For each lane the eight
/// 3-bit values `a..h` (rows 0..8) are concatenated, row 0 in the lowest bits, into 24 bits that
/// are written as 3 rows of 128 lanes into `packed` (`packed[128 * row + lane]`):
///
/// ```text
/// packed row 0 = a       | b << 3 | c << 6           (low two bits of c)
/// packed row 1 = c >> 2  | d << 1 | e << 4 | f << 7  (top bit of c, low bit of f)
/// packed row 2 = f >> 1  | g << 2 | h << 5           (top two bits of f)
/// ```
///
/// Bits above the low 3 of each input byte are discarded. Only the first 1024 bytes of `input`
/// are read and only the first 384 bytes of `packed` are written.
///
/// # Panics
///
/// Panics if `input` holds fewer than 1024 bytes or `packed` holds fewer than 384 bytes. The
/// check happens before anything is written, so `packed` is left untouched on failure.
pub fn pack_u8_u3(input: &[u8], packed: &mut [u8]) {
    const MASK: u8 = 0b111;
    // 1024 bits per row / 8 bits per element == 128 lanes; each lane supplies 8 values.
    const LANES: usize = 1024 / 8;
    const INPUT_LEN: usize = 8 * LANES;
    const PACKED_LEN: usize = 3 * LANES;

    assert!(
        input.len() >= INPUT_LEN,
        "pack_u8_u3: input must hold at least {} bytes, got {}",
        INPUT_LEN,
        input.len()
    );
    assert!(
        packed.len() >= PACKED_LEN,
        "pack_u8_u3: packed must hold at least {} bytes, got {}",
        PACKED_LEN,
        packed.len()
    );

    // Fixed-length views let the compiler drop the per-access bounds checks in the loop.
    let input = &input[..INPUT_LEN];
    let packed = &mut packed[..PACKED_LEN];

    for lane in 0..LANES {
        let a = input[lane] & MASK;
        let b = input[LANES + lane] & MASK;
        let c = input[2 * LANES + lane] & MASK;
        let d = input[3 * LANES + lane] & MASK;
        let e = input[4 * LANES + lane] & MASK;
        let f = input[5 * LANES + lane] & MASK;
        let g = input[6 * LANES + lane] & MASK;
        let h = input[7 * LANES + lane] & MASK;

        // `<<` on u8 silently drops bits shifted past bit 7, so `c << 6` keeps only the low
        // two bits of c and `f << 7` keeps only the low bit of f; the dropped high bits are
        // carried into the next output byte via `c >> 2` and `f >> 1`.
        packed[lane] = a | (b << 3) | (c << 6);
        packed[LANES + lane] = (c >> 2) | (d << 1) | (e << 4) | (f << 7);
        packed[2 * LANES + lane] = (f >> 1) | (g << 2) | (h << 5);
    }
}
/// Unpacks 384 bytes of 3-bit values back into 1024 bytes, one value (0..=7) per output byte.
///
/// This is the inverse of [`pack_u8_u3`]: `input` is read as 3 rows of 128 lanes
/// (`input[128 * row + lane]`) and each lane's 24 bits are projected back out into 8 rows of
/// 128 lanes in `unpacked` (`unpacked[128 * row + lane]`), row 0 coming from the lowest bits.
///
/// Only the first 384 bytes of `input` are read and only the first 1024 bytes of `unpacked` are
/// written.
///
/// # Panics
///
/// Panics if `input` holds fewer than 384 bytes or `unpacked` holds fewer than 1024 bytes. The
/// check happens before anything is written, so `unpacked` is left untouched on failure.
pub fn unpack_u3_u8(input: &[u8], unpacked: &mut [u8]) {
    const MASK: u8 = 0b111;
    const LANES: usize = 1024 / 8;
    const PACKED_LEN: usize = 3 * LANES;
    const UNPACKED_LEN: usize = 8 * LANES;

    assert!(
        input.len() >= PACKED_LEN,
        "unpack_u3_u8: input must hold at least {} bytes, got {}",
        PACKED_LEN,
        input.len()
    );
    assert!(
        unpacked.len() >= UNPACKED_LEN,
        "unpack_u3_u8: unpacked must hold at least {} bytes, got {}",
        UNPACKED_LEN,
        unpacked.len()
    );

    // Fixed-length views let the compiler drop the per-access bounds checks in the loop.
    let input = &input[..PACKED_LEN];
    let unpacked = &mut unpacked[..UNPACKED_LEN];

    for lane in 0..LANES {
        // value0 = a        | (b << 3) | (c << 6)
        // value1 = (c >> 2) | (d << 1) | (e << 4) | (f << 7)
        // value2 = (f >> 1) | (g << 2) | (h << 5)
        let value0 = input[lane];
        let value1 = input[LANES + lane];
        let value2 = input[2 * LANES + lane];

        unpacked[lane] = value0 & MASK;
        unpacked[LANES + lane] = (value0 >> 3) & MASK;
        // c straddles value0/value1: low two bits from the top of value0, top bit from bit 0
        // of value1.
        unpacked[2 * LANES + lane] = ((value0 >> 6) & MASK) | ((value1 << 2) & MASK);
        unpacked[3 * LANES + lane] = (value1 >> 1) & MASK;
        unpacked[4 * LANES + lane] = (value1 >> 4) & MASK;
        // f straddles value1/value2: low bit from bit 7 of value1, top two bits from the
        // bottom of value2.
        unpacked[5 * LANES + lane] = (value1 >> 7) | ((value2 << 1) & MASK);
        unpacked[6 * LANES + lane] = (value2 >> 2) & MASK;
        unpacked[7 * LANES + lane] = (value2 >> 5) & MASK;
    }
}
#[cfg(test)]
mod test {
    use super::{pack_u8_u3, unpack_u3_u8};

    /// Values that already fit in 3 bits must survive a pack/unpack round trip unchanged.
    #[test]
    fn test_roundtrip() {
        let values: Vec<u8> = (0u8..8u8).cycle().take(1024).collect();
        let mut packed: Vec<u8> = vec![0; 384];
        let mut unpacked: Vec<u8> = vec![0; 1024];

        pack_u8_u3(&values, &mut packed);
        unpack_u3_u8(&packed, &mut unpacked);

        assert_eq!(values, unpacked);
    }

    /// Round trip with values that differ between the rows of a lane and that have bits set
    /// above the low 3: the result must be each input byte masked to its low 3 bits.
    ///
    /// The plain `i % 8` pattern gives every row of a lane the same value (128 is a multiple
    /// of 8), so it cannot detect rows being mixed up; this pattern can.
    #[test]
    fn test_roundtrip_masks_high_bits() {
        let values: Vec<u8> = (0usize..1024)
            // Truncation to the low byte is intentional; `& 0xFF` makes the cast lossless.
            .map(|i| ((i * 31 + (i / 128) * 5) & 0xFF) as u8)
            .collect();
        let expected: Vec<u8> = values.iter().map(|v| v & 0b111).collect();
        let mut packed: Vec<u8> = vec![0; 384];
        let mut unpacked: Vec<u8> = vec![0; 1024];

        pack_u8_u3(&values, &mut packed);
        unpack_u3_u8(&packed, &mut unpacked);

        assert_eq!(expected, unpacked);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment