Last active
October 10, 2024 00:34
-
-
Save leok7v/9e09be9319a77a3d7e2da4d780d9905f to your computer and use it in GitHub Desktop.
Range coder with simple adaptive frequency model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef rc_header_included
#define rc_header_included
// Copyright (c) 2024, "Leo" Dmitry Kuznetsov
// https://github.com/leok7v/rc
// This code and the accompanying materials are made available under the terms
// of BSD-3 license, which accompanies this distribution. The full text of the
// license may be found at https://opensource.org/license/bsd-3-clause
//
// Range coder with an adaptive frequency model (single-header library).
// Include this file for the declarations; in exactly one translation unit
// #define rc_implementation before including it to compile the definitions.
#include <stddef.h> // size_t is used by the implementation section and
                    // <stdint.h> is not required to declare it
#include <stdint.h>

#define rc_sym_bits 8                          // bits per coded symbol
#define rc_sym_count (1uLL << rc_sym_bits)     // alphabet size: 256
// pm_update() stops adapting once the total frequency reaches this value.
#define pm_max_freq (1uLL << (64 - rc_sym_bits))
#define ft_max_bits 31

// Values stored in range_coder.error. They are fixed numbers chosen to echo
// common errno codes; they are not taken from the host's <errno.h>.
#define rc_err_io 5           // EIO : I/O error
#define rc_err_too_big 7      // E2BIG : Argument list too long
#define rc_err_no_memory 12   // ENOMEM: Out of memory
#define rc_err_invalid 22     // EINVAL: Invalid argument
#define rc_err_range 34       // ERANGE: Result too large
#define rc_err_data 42        // EILSEQ: Illegal byte sequence
#define rc_err_unsupported 40 // ENOSYS: Functionality not supported
#define rc_err_no_space 55    // ENOBUFS: No buffer space available

// Adaptive probability model: per-symbol frequencies plus a Fenwick tree over
// them, so cumulative frequencies can be queried and updated in O(log n).
struct prob_model {
    uint64_t freq[rc_sym_count];
    uint64_t tree[rc_sym_count]; // Fenwick tree over freq[]
};

// Coder state, shared by the encoder and the decoder.
struct range_coder {
    uint64_t low;   // lower bound of the current interval
    uint64_t range; // width of the current interval
    uint64_t code;  // decoder only: current 64-bit window of the input
    void (*write)(struct range_coder*, uint8_t); // encoder byte sink
    uint8_t (*read)(struct range_coder*);        // decoder byte source
    int32_t error;  // 0, or an rc_err_* value set by the coder or callbacks
};

// Gives symbols [0, n) frequency 1 and all others frequency 0; n <= 256.
void pm_init(struct prob_model* pm, uint32_t n);
// Adds inc to the frequency of sym (ignored once the total hits pm_max_freq).
void pm_update(struct prob_model* pm, uint8_t sym, uint64_t inc);
// Resets the interval and error. `code` is the decoder's initial input window
// (the first 8 bytes of the stream, first byte most significant); an encoder
// passes 0.
void rc_init(struct range_coder* rc, uint64_t code);
// Encodes sym under pm, then adapts pm. Output goes through rc->write. After
// the last symbol the coder state must still be flushed (the implementation
// provides a static rc_flush() for this).
void rc_encode(struct range_coder* rc, struct prob_model* pm, uint8_t sym);
// Decodes one symbol and adapts pm the same way rc_encode() did. On failure
// sets rc->error and returns 0.
uint8_t rc_decode(struct range_coder* rc, struct prob_model* pm);
#endif // rc_header_included
#ifdef rc_implementation | |
#include <stdbool.h> | |
// Isolates the lowest set bit of i (0 for i == 0), e.g. 12 -> 4.
// Done in unsigned arithmetic, where negation wraps instead of overflowing.
static inline int32_t ft_lsb(int32_t i) {
    const uint32_t bits = (uint32_t)i;
    return (int32_t)(bits & (0u - bits));
}
// Builds a Fenwick tree over a[0..n-1] in O(n). It copies the values, then
// adds each node's partial sum into its parent node. Slot k of tree[] holds
// 1-based Fenwick node k + 1.
static void ft_init(uint64_t tree[], size_t n, uint64_t a[]) {
    const int32_t count = (int32_t)n;
    int32_t k = 0;
    while (k < count) {
        tree[k] = a[k];
        k++;
    }
    for (int32_t node = 1; node <= count; node++) {
        const int32_t parent = node + ft_lsb(node);
        if (parent > count) { continue; }
        tree[parent - 1] += tree[node - 1];
    }
}
// Adds inc to element i (0-based) by walking up through every Fenwick node
// that covers it.
static void ft_update(uint64_t tree[], size_t n, int32_t i, uint64_t inc) {
    const int32_t limit = (int32_t)n;
    for (int32_t k = i; k < limit; k += ft_lsb(k + 1)) {
        tree[k] += inc;
    }
}
// Returns the prefix sum a[0] + ... + a[i] (0 when i < 0).
// Nodes at slot >= n are skipped.
static uint64_t ft_query(const uint64_t tree[], size_t n, int32_t i) {
    const int32_t limit = (int32_t)n;
    uint64_t acc = 0;
    for (int32_t k = i; k >= 0; k -= ft_lsb(k + 1)) {
        if (k < limit) { acc += tree[k]; }
    }
    return acc;
}
// Finds the element whose cumulative-frequency slot contains `sum`, using
// binary lifting over the Fenwick tree in O(log n). Works for any n, not only
// powers of two.
//
// Returns (element index - 1), so element 0 yields -1. It also returns -1
// when sum >= the total of all elements, or when n == 0. Callers that must
// tell these cases apart have to range-check `sum` themselves.
static int32_t ft_index_of(const uint64_t tree[], size_t n,
                           uint64_t const sum) {
    if (n == 0) { return -1; }
    // Start lifting from the highest power of two <= n. Starting at n >> 1
    // was only correct when n itself is a power of two.
    size_t mask = 1;
    while (mask <= n / 2) { mask <<= 1; }
    uint64_t value = sum;
    size_t i = 0; // count of leading elements whose prefix sum is <= sum
    for (; mask != 0; mask >>= 1) {
        const size_t t = i + mask;
        if (t <= n && value >= tree[t - 1]) {
            i = t;
            value -= tree[t - 1];
        }
    }
    // Every element fit below `sum`: sum is at or past the total.
    if (i >= n) { return -1; }
    // Subtract in signed arithmetic: (uint32_t)0 - 1 converted to int32_t
    // is implementation-defined.
    return (int32_t)i - 1;
}
#ifndef countof
// Element count of a true array; does not work on pointers or parameters.
#define countof(a) (sizeof(a) / sizeof(*(a)))
#endif
static inline uint64_t pm_sum_of(struct prob_model* pm, uint32_t sym) { | |
return ft_query(pm->tree, countof(pm->tree), sym - 1); | |
} | |
static inline uint64_t pm_total_freq(struct prob_model* pm) { | |
return pm->tree[countof(pm->tree) - 1]; | |
} | |
static inline int32_t pm_index_of(struct prob_model* pm, uint64_t sum) { | |
return ft_index_of(pm->tree, countof(pm->tree), sum) + 1; | |
} | |
void pm_init(struct prob_model* pm, uint32_t n) { | |
for (size_t i = 0; i < countof(pm->freq); i++) { | |
pm->freq[i] = i < n ? 1 : 0; | |
} | |
ft_init(pm->tree, countof(pm->tree), pm->freq); | |
} | |
void pm_update(struct prob_model* pm, uint8_t sym, uint64_t inc) { | |
if (pm->tree[countof(pm->tree) - 1] < pm_max_freq) { | |
pm->freq[sym] += inc; | |
ft_update(pm->tree, countof(pm->tree), sym, inc); | |
} | |
} | |
static void rc_emit(struct range_coder* rc) { | |
const uint8_t byte = (uint8_t)(rc->low >> 56); | |
rc->write(rc, byte); | |
rc->low <<= 8; | |
rc->range <<= 8; | |
} | |
static inline bool rc_leftmost_byte_is_same(struct range_coder* rc) { | |
return (rc->low >> 56) == ((rc->low + rc->range) >> 56); | |
} | |
void rc_init(struct range_coder* rc, uint64_t code) { | |
rc->low = 0; | |
rc->range = UINT64_MAX; | |
rc->code = code; | |
rc->error = 0; | |
} | |
static void rc_flush(struct range_coder* rc) { | |
for (int i = 0; i < sizeof(rc->low); i++) { | |
rc->range = UINT64_MAX; | |
rc_emit(rc); | |
} | |
} | |
static void rc_consume(struct range_coder* rc) { | |
const uint8_t byte = rc->read(rc); | |
rc->code = (rc->code << 8) + byte; | |
rc->low <<= 8; | |
rc->range <<= 8; | |
} | |
void rc_encode(struct range_coder* rc, struct prob_model* pm, | |
uint8_t sym) { | |
uint64_t total = pm_total_freq(pm); | |
uint64_t start = pm_sum_of(pm, sym); | |
uint64_t size = pm->freq[sym]; | |
rc->range /= total; | |
rc->low += start * rc->range; | |
rc->range *= size; | |
pm_update(pm, sym, 1); | |
while (rc_leftmost_byte_is_same(rc)) { rc_emit(rc); } | |
if (rc->range < total + 1) { | |
rc_emit(rc); | |
rc_emit(rc); | |
rc->range = UINT64_MAX - rc->low; | |
} | |
} | |
static uint8_t rc_err(struct range_coder* rc, int32_t e) { | |
rc->error = e; | |
return 0; | |
} | |
uint8_t rc_decode(struct range_coder* rc, struct prob_model* pm) { | |
uint64_t total = pm_total_freq(pm); | |
if (total < 1) { return rc_err(rc, rc_err_invalid); } | |
if (rc->range < total) { | |
rc_consume(rc); | |
rc_consume(rc); | |
rc->range = UINT64_MAX - rc->low; | |
} | |
uint64_t sum = (rc->code - rc->low) / (rc->range / total); | |
int32_t sym = pm_index_of(pm, sum); | |
if (sym < 0 || pm->freq[sym] == 0) { return rc_err(rc, rc_err_data); } | |
uint64_t start = pm_sum_of(pm, sym); | |
uint64_t size = pm->freq[sym]; | |
if (size == 0 || rc->range < total) { return rc_err(rc, rc_err_data); } | |
rc->range /= total; | |
rc->low += start * rc->range; | |
rc->range *= size; | |
pm_update(pm, (uint8_t)sym, 1); | |
while (rc_leftmost_byte_is_same(rc)) { rc_consume(rc); } | |
return (uint8_t)sym; | |
} | |
#endif // rc_implementation |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright (c) 2024, "Leo" Dmitry Kuznetsov | |
// https://github.com/leok7v/rc | |
// This code and the accompanying materials are made available under the terms | |
// of BSD-3 license, which accompanies this distribution. The full text of the | |
// license may be found at https://opensource.org/license/bsd-3-clause | |
#include "range_coder.h" | |
#define rc_implementation | |
#include "range_coder.h" | |
#include <assert.h> | |
#include <math.h> | |
#include <stdio.h> | |
#include <stdlib.h> | |
#include <string.h> | |
// In-memory byte stream for the test. The encoder appends to data[] and the
// decoder reads it back from the start. `checksum` accumulates every byte
// written (while encoding) or read (while decoding). File-local: nothing
// outside this test program needs it.
static struct {
    uint8_t data[1024];
    size_t bytes;      // number of bytes read by io_read()
    size_t written;    // number of bytes written by io_write()
    uint64_t checksum; // FNV-1a-derived hash, see checksum_append()
} io;
static void checksum_init(void) { | |
io.checksum = 0xCBF29CE484222325uLL; // FNV offset basis | |
} | |
// Folds one byte into io.checksum. First an FNV-1a step (xor the byte, then
// multiply by the 64-bit FNV prime). Then extra mixing: xor in the high half
// and rotate left by 7.
static void checksum_append(const uint8_t byte) {
    const uint64_t fnv_prime = 0x100000001B3uLL;
    uint64_t h = (io.checksum ^ byte) * fnv_prime;
    h ^= h >> 32;
    io.checksum = (h << 7) | (h >> 57);
}
static void io_write(struct range_coder* rc, uint8_t b) { | |
if (rc->error == 0) { | |
if (io.written < countof(io.data)) { | |
checksum_append(b); | |
io.data[io.written++] = b; | |
} else { | |
rc->error = rc_err_too_big; | |
} | |
} | |
} | |
static uint8_t io_read(struct range_coder* rc) { | |
if (rc->error == 0) { | |
if (io.bytes >= io.written) { | |
rc->error = rc_err_io; | |
} else { | |
assert(io.bytes < countof(io.data)); | |
checksum_append(io.data[io.bytes]); | |
return io.data[io.bytes++]; | |
} | |
} | |
return 0; | |
} | |
static int32_t compare(const uint8_t in[], const uint8_t out[], | |
size_t n, uint64_t ecs) { | |
bool equal = ecs == io.checksum; | |
if (!equal) { | |
printf("checksum encoder: %016llX != decoder: %016llX\n", | |
ecs, io.checksum); | |
} else { | |
for (size_t i = 0; i < n; i++) { | |
if (in[i] != out[i]) { | |
printf("[%d]: %d != %d\n", (int)i, in[i], out[i]); | |
equal = false; | |
break; | |
} | |
} | |
} | |
assert(equal); // break early for debugging | |
return equal && ecs == io.checksum ? 0 : rc_err_data; | |
} | |
static uint64_t encode(struct prob_model* pm, struct range_coder* rc, | |
const uint8_t a[], size_t n, uint32_t symbols) { | |
pm_init(pm, symbols); | |
rc_init(rc, 0); | |
for (size_t i = 0; i < n && rc->error == 0; i++) { | |
rc_encode(rc, pm, a[i]); | |
} | |
rc_flush(rc); | |
assert(rc->error == 0); | |
return io.checksum; | |
} | |
static size_t decode(struct prob_model* pm, struct range_coder* rc, | |
uint8_t a[], size_t n, uint32_t symbols) { | |
io.bytes = 0; | |
checksum_init(); | |
pm_init(pm, symbols); | |
rc->code = 0; | |
for (size_t i = 0; i < sizeof(rc->code); i++) { | |
rc->code = (rc->code << 8) + rc->read(rc); | |
} | |
rc_init(rc, rc->code); | |
size_t i = 0; | |
while (i < n && rc->error == 0) { | |
uint8_t sym = rc_decode(rc, pm); | |
a[i++] = (uint8_t)sym; | |
} | |
return i; | |
} | |
static double entropy(const uint64_t a[], size_t n) { | |
double total = 0; | |
for (size_t i = 0; i < n; i++) { | |
if (a[i] > 1) { total += a[i]; } | |
} | |
double e = 0; | |
for (size_t i = 0; i < n; i++) { | |
if (a[i] > 1) { | |
double p = a[i] / total; | |
e -= p * log2(p); | |
} | |
} | |
return e; | |
} | |
int main(void) { | |
static struct range_coder coder; | |
static struct prob_model model; | |
struct range_coder* rc = &coder; | |
struct prob_model* pm = &model; | |
static const char text[] = | |
"Lorem ipsum dolor sit amet, consectetur adipiscing elit, " | |
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " | |
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris " | |
"nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in " | |
"reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla " | |
"pariatur. Excepteur sint occaecat cupidatat non proident, sunt in " | |
"culpa qui officia deserunt mollit anim id est laborum."; | |
enum { bits = 8 }; | |
enum { symbols = 1 << bits }; | |
enum { n = countof(text) - 1 }; // last byte text[countof(text)] == 0 | |
io.bytes = 0; | |
io.written = 0; | |
rc->write = io_write; | |
rc->read = io_read; | |
checksum_init(); | |
const uint8_t* in = (const uint8_t*)text; | |
uint64_t ecs = encode(pm, rc, in, n, symbols); | |
const double e = entropy(pm->freq, symbols); | |
const double bps = io.written * 8.0 / n; | |
const double percent = 100.0 * io.written * 8 / ((int64_t)n * bits); | |
printf("%lld to %lld bytes. %.1f%% bps: %.3f Shannon H: %.3f\n", | |
((uint64_t)n * bits / 8), (uint64_t)io.written, percent, bps, e); | |
uint8_t out[n]; | |
size_t k = decode(pm, rc, out, n, symbols); | |
assert(rc->error == 0 && k == n && ecs == io.checksum); | |
int32_t r = compare(in, out, n, ecs); | |
printf("decode(): %s\n", r == 0 ? "ok" : "failed"); | |
assert(r == 0); | |
return r; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment