Skip to content

Instantly share code, notes, and snippets.

@leok7v
Last active October 10, 2024 00:34
Show Gist options
  • Save leok7v/9e09be9319a77a3d7e2da4d780d9905f to your computer and use it in GitHub Desktop.
Save leok7v/9e09be9319a77a3d7e2da4d780d9905f to your computer and use it in GitHub Desktop.
Range coder with simple adaptive frequency model
#ifndef rc_header_included
#define rc_header_included
// Copyright (c) 2024, "Leo" Dmitry Kuznetsov
// https://github.com/leok7v/rc
// This code and the accompanying materials are made available under the terms
// of BSD-3 license, which accompanies this distribution. The full text of the
// license may be found at https://opensource.org/license/bsd-3-clause
#include <stdint.h>
// Symbol alphabet: every coded symbol is rc_sym_bits wide, which gives
// rc_sym_count distinct symbols (256 for 8 bits).
#define rc_sym_bits 8
#define rc_sym_count (1uLL << rc_sym_bits)
// Ceiling for the model's total frequency: pm_update() stops adapting once
// the total has reached this value (2^56 for 8-bit symbols).
#define pm_max_freq (1uLL << (64 - rc_sym_bits))
// NOTE(review): ft_max_bits is not referenced anywhere in the visible source.
#define ft_max_bits 31
// Error codes stored in range_coder.error (0 means "no error"). Each value
// is annotated with the errno constant it is modelled on.
// NOTE(review): errno numbering differs between platforms - confirm whether
// these values are required to match the local <errno.h>.
#define rc_err_io 5 // EIO : I/O error
#define rc_err_too_big 7 // E2BIG : Argument list too long
#define rc_err_no_memory 12 // ENOMEM: Out of memory
#define rc_err_invalid 22 // EINVAL: Invalid argument
#define rc_err_range 34 // ERANGE: Result too large
#define rc_err_data 42 // EILSEQ: Illegal byte sequence
#define rc_err_unsupported 40 // ENOSYS: Functionality not supported
#define rc_err_no_space 55 // ENOBUFS: No buffer space available
// Adaptive frequency model over rc_sym_count symbols: the raw per-symbol
// counts plus a Fenwick tree built from those counts for prefix-sum queries.
struct prob_model { // probability model
uint64_t freq[rc_sym_count]; // frequency of each symbol, indexed by symbol
uint64_t tree[rc_sym_count]; // Fenwick Tree over freq[]; last slot is read as the total
};
// State shared by the encoder and the decoder plus the byte-level I/O
// callbacks supplied by the caller.
struct range_coder {
uint64_t low; // lower bound of the current coding interval
uint64_t range; // width of the current coding interval
uint64_t code; // decoder only: sliding 8-byte window over the input stream
void (*write)(struct range_coder*, uint8_t); // byte sink, called by the encoder
uint8_t (*read)(struct range_coder*); // byte source, called by the decoder
int32_t error; // 0 or one of the rc_err_* codes
};
// Resets the model: the first n symbols get frequency 1, all others 0, and
// the Fenwick tree is rebuilt from those counts.
void pm_init(struct prob_model* pm, uint32_t n); // n <= 256
// Adds inc to the frequency of sym; does nothing once the model's total
// frequency has reached pm_max_freq.
void pm_update(struct prob_model* pm, uint8_t sym, uint64_t inc);
// Resets low to 0, range to UINT64_MAX and error to 0, and stores code.
// The write/read callbacks are not touched; the caller has to set them.
void rc_init(struct range_coder* rc, uint64_t code);
// Encodes sym against pm, emitting bytes through rc->write, and then raises
// the frequency of sym by 1.
void rc_encode(struct range_coder* rc, struct prob_model* pm, uint8_t sym);
// The decoder has to be primed before the first call: pass the first 8 bytes
// of the encoded stream to rc_init() as code, first stream byte in the most
// significant position.
// Returns the decoded symbol; on failure sets rc->error and returns 0.
uint8_t rc_decode(struct range_coder* rc, struct prob_model* pm);
#endif // rc_header_included
#ifdef rc_implementation
#include <stdbool.h>
#include <stddef.h>
// Isolates the lowest set bit of i (0 when i is 0).
static inline int32_t ft_lsb(int32_t i) {
    const int32_t negated = -i; // two's complement negation, same as ~i + 1
    return i & negated;
}
// Builds a Fenwick tree of n elements in tree[] from the plain values a[].
// tree[] and a[] are both indexed from 0; the tree arithmetic is 1-based.
static void ft_init(uint64_t tree[], size_t n, uint64_t a[]) {
    const int32_t count = (int32_t)n;
    // Start from a straight copy of the source values.
    for (int32_t k = 0; k < count; k++) {
        tree[k] = a[k];
    }
    // Fold every node into the next node that covers it.
    for (int32_t pos = 1; pos <= count; pos++) {
        const int32_t up = pos + ft_lsb(pos);
        if (up > count) { continue; } // no covering node inside the tree
        tree[up - 1] += tree[pos - 1];
    }
}
// Adds inc to element i (0-based) by updating every tree node that covers it.
static void ft_update(uint64_t tree[], size_t n, int32_t i, uint64_t inc) {
    const int32_t limit = (int32_t)n;
    for (int32_t k = i; k < limit; k += ft_lsb(k + 1)) {
        tree[k] += inc;
    }
}
// Returns the inclusive prefix sum of elements 0..i (0-based).
// A negative i yields 0; nodes at or beyond n are skipped, not read.
static uint64_t ft_query(const uint64_t tree[], size_t n, int32_t i) {
    const int32_t limit = (int32_t)n;
    uint64_t acc = 0;
    for (int32_t k = i; k >= 0; k -= ft_lsb(k + 1)) {
        if (k >= limit) { continue; }
        acc += tree[k];
    }
    return acc;
}
// Binary search over the Fenwick tree: returns the largest 0-based index
// whose inclusive prefix sum is <= sum, or -1 when there is no such index.
// Also returns -1 when the tree is empty or when sum >= tree[n - 1].
//
// The search starts with a step of n / 2 and treats tree[n - 1] as the grand
// total; both hold when n is a power of two (the only caller in this file
// passes rc_sym_count, i.e. 256).
static int32_t ft_index_of(uint64_t tree[], size_t n, uint64_t const sum) {
    // The n == 0 test also keeps tree[n - 1] inside the array.
    if (n == 0 || sum >= tree[n - 1]) { return -1; }
    uint64_t rest = sum; // part of sum not yet covered by accepted nodes
    uint32_t pos = 0;    // number of leading elements already accepted
    for (uint32_t step = (uint32_t)(n >> 1); step != 0; step >>= 1) {
        const uint32_t next = pos + step;
        if (next <= n && rest >= tree[next - 1]) {
            pos = next;
            rest -= tree[next - 1];
        }
    }
    // pos counts accepted elements, so pos == 0 maps to -1 ("none").
    return (int32_t)pos - 1;
}
// Number of elements in a true array. Only valid for array objects: applied
// to a pointer (including an array function parameter) it gives a wrong value.
#ifndef countof
#define countof(a) (sizeof(a) / sizeof((a)[0]))
#endif
// Returns the cumulative frequency of all symbols strictly below sym,
// i.e. the start of sym's slot in the cumulative range; 0 for symbol 0.
static inline uint64_t pm_sum_of(struct prob_model* pm, uint32_t sym) {
    // Subtract in signed arithmetic so that symbol 0 yields -1 directly,
    // instead of wrapping to UINT32_MAX and depending on how that value is
    // narrowed to the int32_t parameter of ft_query().
    const int32_t last = (int32_t)sym - 1;
    return ft_query(pm->tree, countof(pm->tree), last);
}
// Returns the sum of all symbol frequencies, read from the last tree slot.
static inline uint64_t pm_total_freq(struct prob_model* pm) {
    const uint64_t* top = pm->tree + (countof(pm->tree) - 1);
    return *top;
}
// Maps a cumulative frequency value to the symbol whose slot contains it.
// Returns -1 when sum is not below the model's total frequency.
static inline int32_t pm_index_of(struct prob_model* pm, uint64_t sum) {
    // ft_index_of() returns -1 both for "sum lies before the first prefix"
    // (which is symbol 0) and for "sum is out of range". Adding 1 would turn
    // the second case into symbol 0 as well, so reject it here first.
    if (sum >= pm->tree[countof(pm->tree) - 1]) { return -1; }
    return ft_index_of(pm->tree, countof(pm->tree), sum) + 1;
}
// Resets the model: symbols below n start with frequency 1, the remaining
// ones with 0, then the Fenwick tree is rebuilt from the counts.
void pm_init(struct prob_model* pm, uint32_t n) {
    const size_t count = countof(pm->freq);
    const size_t live = n < count ? n : count; // n beyond the table is clamped
    size_t k = 0;
    for (; k < live; k++)  { pm->freq[k] = 1; }
    for (; k < count; k++) { pm->freq[k] = 0; }
    ft_init(pm->tree, countof(pm->tree), pm->freq);
}
// Raises the frequency of sym by inc in both freq[] and the Fenwick tree.
// Once the total frequency has reached pm_max_freq the model is left as is.
void pm_update(struct prob_model* pm, uint8_t sym, uint64_t inc) {
    const uint64_t total = pm->tree[countof(pm->tree) - 1];
    if (total >= pm_max_freq) { return; } // saturated: stop adapting
    pm->freq[sym] += inc;
    ft_update(pm->tree, countof(pm->tree), sym, inc);
}
// Sends the most significant byte of low to rc->write and shifts both low
// and range up by one byte.
static void rc_emit(struct range_coder* rc) {
    rc->write(rc, (uint8_t)(rc->low >> 56));
    rc->low   <<= 8;
    rc->range <<= 8;
}
// True when low and low + range agree in their most significant byte.
// The addition is unsigned 64-bit arithmetic and therefore wraps on overflow.
static inline bool rc_leftmost_byte_is_same(struct range_coder* rc) {
    const uint64_t bottom_top_byte = rc->low >> 56;
    const uint64_t upper_top_byte  = (rc->low + rc->range) >> 56;
    return bottom_top_byte == upper_top_byte;
}
// Puts the coder into its start state: full-width interval starting at 0,
// no error, and the given decoder window. The write/read callbacks are not
// modified here.
void rc_init(struct range_coder* rc, uint64_t code) {
    rc->error = 0;
    rc->code  = code;
    rc->low   = 0;
    rc->range = UINT64_MAX;
}
// Finishes encoding by emitting all sizeof(rc->low) bytes of low, most
// significant byte first. range is reset to UINT64_MAX before every emit.
static void rc_flush(struct range_coder* rc) {
    size_t remaining = sizeof(rc->low);
    while (remaining-- > 0) {
        rc->range = UINT64_MAX;
        rc_emit(rc);
    }
}
// Pulls one byte from rc->read into the low end of the code window and
// shifts low and range up by one byte to match.
static void rc_consume(struct range_coder* rc) {
    const uint64_t incoming = rc->read(rc); // read first: it runs before code is touched
    rc->code = (rc->code << 8) + incoming;
    rc->low   <<= 8;
    rc->range <<= 8;
}
void rc_encode(struct range_coder* rc, struct prob_model* pm,
uint8_t sym) {
uint64_t total = pm_total_freq(pm);
uint64_t start = pm_sum_of(pm, sym);
uint64_t size = pm->freq[sym];
rc->range /= total;
rc->low += start * rc->range;
rc->range *= size;
pm_update(pm, sym, 1);
while (rc_leftmost_byte_is_same(rc)) { rc_emit(rc); }
if (rc->range < total + 1) {
rc_emit(rc);
rc_emit(rc);
rc->range = UINT64_MAX - rc->low;
}
}
// Records error code e in the coder and returns 0, so that a failing decode
// path can write "return rc_err(rc, code);".
static uint8_t rc_err(struct range_coder* rc, int32_t e) {
    const uint8_t no_symbol = 0;
    rc->error = e;
    return no_symbol;
}
// Decodes one symbol from the code window, mirrors the encoder's interval
// and model update, and pulls in every input byte that is already settled.
//
// Returns the decoded symbol. On failure sets rc->error and returns 0:
//   rc_err_invalid  the model is empty
//   rc_err_data     the interval is narrower than the model total, or the
//                   code value maps to no symbol with a non-zero frequency
uint8_t rc_decode(struct range_coder* rc, struct prob_model* pm) {
    uint64_t total = pm_total_freq(pm);
    if (total < 1) { return rc_err(rc, rc_err_invalid); }
    // Counterpart of the encoder's two-byte push when the interval became
    // too narrow for the current total.
    if (rc->range < total) {
        rc_consume(rc);
        rc_consume(rc);
        rc->range = UINT64_MAX - rc->low;
    }
    // range / total is the divisor below; it must not be 0, so this has to
    // be checked before the division rather than after it.
    if (rc->range < total) { return rc_err(rc, rc_err_data); }
    uint64_t sum = (rc->code - rc->low) / (rc->range / total);
    int32_t sym = pm_index_of(pm, sum);
    if (sym < 0 || pm->freq[sym] == 0) { return rc_err(rc, rc_err_data); }
    uint64_t start = pm_sum_of(pm, sym);
    uint64_t size = pm->freq[sym];
    rc->range /= total;
    rc->low += start * rc->range;
    rc->range *= size;
    pm_update(pm, (uint8_t)sym, 1);
    // Bytes on which both ends of the interval agree are consumed.
    while (rc_leftmost_byte_is_same(rc)) { rc_consume(rc); }
    return (uint8_t)sym;
}
#endif // rc_implementation
// Copyright (c) 2024, "Leo" Dmitry Kuznetsov
// https://github.com/leok7v/rc
// This code and the accompanying materials are made available under the terms
// of BSD-3 license, which accompanies this distribution. The full text of the
// license may be found at https://opensource.org/license/bsd-3-clause
#include "range_coder.h"
#define rc_implementation
#include "range_coder.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Fixed-size in-memory byte stream shared by io_write() and io_read(),
// together with a running checksum of every byte that passed through them.
struct { // in memory io
uint8_t data[1024];
size_t bytes; // number of bytes read by io_read()
size_t written; // number of bytes written by io_write()
uint64_t checksum; // FNV hash
} io;
// Restarts the running checksum from the FNV offset basis.
static void checksum_init(void) {
    io.checksum = UINT64_C(0xCBF29CE484222325); // FNV offset basis
}
// Mixes one byte into the running checksum: an FNV xor-and-multiply step,
// then a fold of the upper half into the lower half and a rotate left by 7.
static void checksum_append(const uint8_t byte) {
    uint64_t h = io.checksum;
    h ^= byte;
    h *= 0x100000001B3uLL; // FNV prime
    h ^= h >> 32;
    h = (h << 7) | (h >> (64 - 7));
    io.checksum = h;
}
static void io_write(struct range_coder* rc, uint8_t b) {
if (rc->error == 0) {
if (io.written < countof(io.data)) {
checksum_append(b);
io.data[io.written++] = b;
} else {
rc->error = rc_err_too_big;
}
}
}
static uint8_t io_read(struct range_coder* rc) {
if (rc->error == 0) {
if (io.bytes >= io.written) {
rc->error = rc_err_io;
} else {
assert(io.bytes < countof(io.data));
checksum_append(io.data[io.bytes]);
return io.data[io.bytes++];
}
}
return 0;
}
// Verifies a round trip: the encoder checksum ecs must equal the current
// io.checksum and the first n bytes of in[] and out[] must match.
// Prints the first difference found. Returns 0 on success and rc_err_data
// otherwise (in builds with assertions enabled a mismatch aborts first).
static int32_t compare(const uint8_t in[], const uint8_t out[],
                       size_t n, uint64_t ecs) {
    bool equal = ecs == io.checksum;
    if (!equal) {
        // %llX expects unsigned long long; uint64_t may be a different type.
        printf("checksum encoder: %016llX != decoder: %016llX\n",
               (unsigned long long)ecs, (unsigned long long)io.checksum);
    } else {
        for (size_t i = 0; i < n; i++) {
            if (in[i] != out[i]) {
                printf("[%d]: %d != %d\n", (int)i, in[i], out[i]);
                equal = false;
                break;
            }
        }
    }
    assert(equal); // break early for debugging
    // equal already implies that the checksums matched.
    return equal ? 0 : rc_err_data;
}
static uint64_t encode(struct prob_model* pm, struct range_coder* rc,
const uint8_t a[], size_t n, uint32_t symbols) {
pm_init(pm, symbols);
rc_init(rc, 0);
for (size_t i = 0; i < n && rc->error == 0; i++) {
rc_encode(rc, pm, a[i]);
}
rc_flush(rc);
assert(rc->error == 0);
return io.checksum;
}
static size_t decode(struct prob_model* pm, struct range_coder* rc,
uint8_t a[], size_t n, uint32_t symbols) {
io.bytes = 0;
checksum_init();
pm_init(pm, symbols);
rc->code = 0;
for (size_t i = 0; i < sizeof(rc->code); i++) {
rc->code = (rc->code << 8) + rc->read(rc);
}
rc_init(rc, rc->code);
size_t i = 0;
while (i < n && rc->error == 0) {
uint8_t sym = rc_decode(rc, pm);
a[i++] = (uint8_t)sym;
}
return i;
}
static double entropy(const uint64_t a[], size_t n) {
double total = 0;
for (size_t i = 0; i < n; i++) {
if (a[i] > 1) { total += a[i]; }
}
double e = 0;
for (size_t i = 0; i < n; i++) {
if (a[i] > 1) {
double p = a[i] / total;
e -= p * log2(p);
}
}
return e;
}
// Round-trip self-test: encodes a fixed text into the in-memory stream,
// prints the compression statistics, decodes the stream again and checks
// that both the bytes and the checksums agree.
// Returns 0 on success and rc_err_data on a mismatch.
int main(void) {
    static struct range_coder coder;
    static struct prob_model model;
    struct range_coder* rc = &coder;
    struct prob_model* pm = &model;
    static const char text[] =
        "Lorem ipsum dolor sit amet, consectetur adipiscing elit, "
        "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. "
        "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris "
        "nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in "
        "reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla "
        "pariatur. Excepteur sint occaecat cupidatat non proident, sunt in "
        "culpa qui officia deserunt mollit anim id est laborum.";
    enum { bits = 8 };
    enum { symbols = 1 << bits };
    enum { n = countof(text) - 1 }; // drop the terminator text[countof(text) - 1] == 0
    io.bytes = 0;
    io.written = 0;
    rc->write = io_write;
    rc->read = io_read;
    checksum_init();
    const uint8_t* in = (const uint8_t*)text;
    uint64_t ecs = encode(pm, rc, in, n, symbols);
    const double e = entropy(pm->freq, symbols);
    const double bps = io.written * 8.0 / n;
    const double percent = 100.0 * io.written * 8 / ((int64_t)n * bits);
    // Both counts are unsigned: print them with %llu and pass arguments of
    // exactly that type.
    printf("%llu to %llu bytes. %.1f%% bps: %.3f Shannon H: %.3f\n",
           (unsigned long long)n * bits / 8,
           (unsigned long long)io.written, percent, bps, e);
    uint8_t out[n];
    size_t k = decode(pm, rc, out, n, symbols);
    assert(rc->error == 0 && k == n && ecs == io.checksum);
    int32_t r = compare(in, out, n, ecs);
    printf("decode(): %s\n", r == 0 ? "ok" : "failed");
    assert(r == 0);
    return r;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment