/* Gist vurtun/1c84e0956ab02685e8d74e01aadbdf3b -- last active January 11, 2023 15:22.
 * Hosting-page notice: this file contains bidirectional Unicode text that may be
 * interpreted or compiled differently than what appears below. To review, open the
 * file in an editor that reveals hidden Unicode characters. */
// ref: http://www.codercorner.com/RadixSortRevisited.htm
// http://stereopsis.com/radix.html
// int/float: https://github.com/lshamis/FloatRadixSort
// string: https://github.com/rantala/string-sorting/blob/master/src/msd_ce.cpp
/* Length-delimited string view; this is the key type of the string sort. */
struct str {
  const char *str; /* first character */
  const char *end; /* one past the last character (str + len) */
  int len;         /* number of characters between str and end */
};
/* ---------------------------------------------------------------------------
 * Sort
 * ---------------------------------------------------------------------------
 */
#include <assert.h>
#include <stdio.h>
#include <string.h>
/* cast(t, p): explicit conversion of p to type t. */
#define cast(t, p) ((t)(p))
/* for_nstep(i,n,s): declares `int i` and runs it from 0 up to (excluding) n in steps of s. */
#define for_nstep(i,n,s) for (int i = 0; i < (n); i += (s))
/* for_cnt(i,n): for_nstep with a step of one. */
#define for_cnt(i,n) for_nstep(i,n,1)
/* fori_cnt(i,n): like for_cnt but reuses an index variable declared by the caller. */
#define fori_cnt(i,n) for ((i) = 0; (i) < (n); (i) += 1)
// clang-format off
/* Turns the bytes of a key at p into an unsigned whose numeric order is the sort order. */
typedef unsigned(*sort_conv_f)(const void *p);
/* Optional accessor: receives the address of a record field plus the user pointer
 * and returns the address of the actual key. */
typedef void*(sort_access_f)(const void *data, void *usr);
/* Key of record `a` at byte offset `off`, routed through `access` when one is given.
 * Every argument is parenthesized so that expressions can be passed safely. */
#define sort__access(a,usr,access,conv,off) ((access) ? (conv)((access)((a) + (off), (usr))) : (conv)((a) + (off)))
/* Character d of string view s as 0..255, or -1 past the end. Going through
 * unsigned char keeps bytes >= 0x80 from colliding with the -1 sentinel. */
#define sort__char_at(s,d) (((d) < (s)->len) ? (int)(unsigned char)(s)->str[d] : -1)
/* Resolves record pointer `a` to its struct str, through `access` when one is given. */
#define sort__str_get(a,access,usr) ((struct str*)((access) ? ((access)((a), (usr))) : (a)))
/* Key conversion for unsigned short: the value itself is already in sort order.
 * The bytes are copied out with memcpy because p is an arbitrary byte address
 * (record base + offset) and need not be aligned for unsigned short. */
static inline unsigned
sort__cast_ushort(const void *p) {
  unsigned short bits;
  memcpy(&bits, p, sizeof bits);
  return bits;
}
/* Key conversion for short: flipping the sign bit of the 16-bit pattern makes
 * negative values order below non-negative ones when compared as unsigned.
 * memcpy is used because p need not be aligned for short. */
static inline unsigned
sort__cast_short(const void *p) {
  unsigned short bits;
  memcpy(&bits, p, sizeof bits);
  return bits ^ (1u << 15u);
}
/* Key conversion for unsigned: the value itself is already in sort order.
 * memcpy is used because p need not be aligned for unsigned. */
static inline unsigned
sort__cast_uint(const void *p) {
  unsigned bits;
  memcpy(&bits, p, sizeof bits);
  return bits;
}
/* Key conversion for int: flipping the sign bit makes negative values order
 * below non-negative ones when compared as unsigned.
 * memcpy is used because p need not be aligned for int. */
static inline unsigned
sort__cast_int(const void *p) {
  unsigned bits;
  memcpy(&bits, p, sizeof bits);
  return bits ^ (1u << 31u);
}
/* Key conversion for float (assumes float and unsigned are both 32 bits wide).
 * Non-negative values get their sign bit set so they sort above all negatives;
 * negative values are replaced by the two's-complement negation of their bit
 * pattern, which reverses their order so larger magnitudes sort first.
 * -0.0f and +0.0f end up with the same key.
 * memcpy is used because p need not be aligned for float. */
static inline unsigned
sort__cast_flt(const void *p) {
  unsigned bits;
  memcpy(&bits, p, sizeof bits);
  if ((bits >> 31u) == 1u) {
    return 0u - bits;
  }
  return bits ^ (1u << 31u);
}
/* Sort records by one field without an accessor callback.
 *   out - unsigned buffer; on return out[0..n) holds the record indices in
 *         ascending key order. It needs 2*n+512 entries for the 16-bit sorts
 *         and 2*n+1024 entries for the 32-bit sorts.
 *   a   - record array, siz - record size in bytes, n - record count,
 *   off - byte offset of the key inside a record. */
#define sort_short(out,a,siz,n,off) sort_radix16(out,a,siz,n,off,0,0,sort__cast_short)
#define sort_ushort(out,a,siz,n,off) sort_radix16(out,a,siz,n,off,0,0,sort__cast_ushort)
#define sort_int(out,a,siz,n,off) sort_radix32(out,a,siz,n,off,0,0,sort__cast_int)
#define sort_uint(out,a,siz,n,off) sort_radix32(out,a,siz,n,off,0,0,sort__cast_uint)
#define sort_flt(out,a,siz,n,off) sort_radix32(out,a,siz,n,off,0,0,sort__cast_flt)
/* String sort: rnk and rnk2 are int buffers and oracle a short buffer, each with
 * at least n entries. Evaluates to the pointer that holds the sorted indices. */
#define sort_str(rnk,rnk2,oracle,a,n,siz,off) sort__str(rnk,rnk2,oracle,a,n,siz,off,0,0)
/* Convenience forms for plain arrays of the key type. */
#define sort_shorts(out,a,n) sort_short(out,a,(int)sizeof(short),n,0)
#define sort_ushorts(out,a,n) sort_ushort(out,a,(int)sizeof(unsigned short),n,0)
#define sort_ints(out,a,n) sort_int(out,a,(int)sizeof(int),n,0)
#define sort_uints(out,a,n) sort_uint(out,a,(int)sizeof(unsigned),n,0)
#define sort_flts(out,a,n) sort_flt(out,a,(int)sizeof(float),n,0)
// clang-format on
static void | |
sort__radix16(unsigned *restrict out, const void *a, int siz, int n, int off, | |
void *usr, sort_access_f access, sort_conv_f conv) { | |
assert(a); | |
assert(out); | |
/* <!> out needs to be at least size: 2*n+512 <!> */ | |
unsigned *buf = out + 2 * n; | |
unsigned *restrict h[] = {buf, buf + 256}; | |
const unsigned char *b = cast(const unsigned char*, a); | |
const unsigned char *e = cast(const unsigned char*, a) + n * siz; | |
/* build histogram */ | |
int is_sorted = 1; | |
memset(buf, 0, 512 * sizeof(unsigned)); | |
unsigned last = sort__access(b, usr, access, conv, off); | |
for (const unsigned char *it = b; it < e; it += siz) { | |
unsigned k = sort__access(it, usr, access, conv, off); | |
is_sorted = (k < last) ? 0 : is_sorted; | |
h[0][k & 0xff]++; | |
h[1][(k >> 8) & 0xff]++; | |
last = k; | |
} | |
if (is_sorted) { | |
return; /* already sorted so early out */ | |
} | |
/* convert histogram into offset table */ | |
unsigned sum[2] = {0}; | |
for_cnt(i,256) { | |
unsigned t0 = h[0][i] + sum[0]; h[0][i] = sum[0], sum[0] = t0; | |
unsigned t1 = h[1][i] + sum[1]; h[1][i] = sum[1], sum[1] = t1; | |
} | |
/* sort 8-bit at a time */ | |
unsigned *restrict idx[] = {out, out + n}; | |
for (int p = 0, d = 1, s = 0; p < 2; ++p, d = !d, s = !s) { | |
for (unsigned i = 0u; i != cast(unsigned,n); ++i) { | |
unsigned at = idx[s][i]; | |
unsigned k = sort__access(b + at * (unsigned)siz, usr, access, conv, off); | |
idx[d][h[p][(k>>(8*p))&0xff]++] = at; | |
} | |
} | |
} | |
static void | |
sort_radix16(unsigned *restrict out, const void *a, int siz, int n, int off, | |
void *usr, sort_access_f access, sort_conv_f conv) { | |
assert(a); | |
assert(out); | |
for_cnt(i,n) {out[i] = cast(unsigned,i);} | |
sort__radix16(out, a, siz, n, off, usr, access, conv); | |
} | |
static void | |
sort__radix32(unsigned *restrict out, const void *a, int siz, int n, int off, | |
void *usr, sort_access_f access, sort_conv_f conv) { | |
assert(a); | |
assert(out); | |
/* <!> out needs to be at least size: 2*n+1024 <!> */ | |
unsigned *buf = out + 2 * n; | |
unsigned *restrict h[] = {buf, buf + 256, buf + 512, buf + 768}; | |
const unsigned char *b = cast(const unsigned char*, a); | |
const unsigned char *e = cast(const unsigned char*, a) + n * siz; | |
/* build histogram */ | |
int is_sorted = 1; | |
memset(buf,0,1024*sizeof(unsigned)); | |
unsigned last = sort__access(b, usr, access, conv, off); | |
for (const unsigned char *it = b; it < e; it += siz) { | |
unsigned k = sort__access(it, usr, access, conv, off); | |
is_sorted = (k < last) ? 0 : is_sorted; | |
h[0][(k & 0xff)]++; | |
h[1][(k >> 8) & 0xff]++; | |
h[2][(k >> 16) & 0xff]++; | |
h[3][(k >> 24)]++; | |
last = k; | |
} | |
if (is_sorted) { | |
return; /* already sorted so early out */ | |
} | |
/* convert histogram into offset table */ | |
unsigned sum[4] = {0}; | |
for_cnt(i,256) { | |
unsigned t0 = h[0][i] + sum[0]; h[0][i] = sum[0]; sum[0] = t0; | |
unsigned t1 = h[1][i] + sum[1]; h[1][i] = sum[1]; sum[1] = t1; | |
unsigned t2 = h[2][i] + sum[2]; h[2][i] = sum[2]; sum[2] = t2; | |
unsigned t3 = h[3][i] + sum[3]; h[3][i] = sum[3]; sum[3] = t3; | |
} | |
/* sort 8-bit at a time */ | |
unsigned *restrict idx[] = {out, out + n}; | |
for (int p = 0, d = 1, s = 0; p < 4; ++p, d = !d, s = !s) { | |
for (unsigned i = 0u; i != cast(unsigned,n); ++i) { | |
unsigned at = idx[s][i]; | |
unsigned k = sort__access(b + at * (unsigned)siz, usr, access, conv, off); | |
idx[d][h[p][(k>>(8*p))&0xff]++] = at; | |
} | |
} | |
} | |
static void | |
sort_radix32(unsigned *restrict out, const void *a, int siz, int n, int off, | |
void *usr, sort_access_f access, sort_conv_f conv) { | |
assert(a); | |
assert(out); | |
for_cnt(i,n) {out[i] = cast(unsigned,i);} | |
sort__radix32(out, a, siz, n, off, usr, access, conv); | |
} | |
static char | |
sort__str_at(unsigned char *p, int d, sort_access_f access, void *usr) { | |
struct str * s = sort__str_get(p, access, usr); | |
return sort__char_at(s,d); | |
} | |
static void | |
sort__str_q3s(int *rnk, void *a, int lo, int hi, int d, int siz, int off, | |
sort_access_f access, void *usr) { | |
if (hi <= lo) return; | |
unsigned char *p = a; | |
int lt = lo, gt = hi, i = lo + 1; | |
int v = sort__str_at(p + rnk[lo] * siz + off, d, access, usr); | |
while (i <= gt) { | |
int t = sort__str_at(p + rnk[i] * siz + off, d, access, usr); | |
if (t < v) {int tmp = rnk[lt]; rnk[lt++] = rnk[i]; rnk[i++] = tmp;} | |
else if(t > v) {int tmp = rnk[i]; rnk[i] = rnk[gt]; rnk[gt--] = tmp;} | |
else i++; | |
} | |
sort__str_q3s(rnk, a, lo, lt-1, d, siz, off, access, usr); | |
if (v >= 0) sort__str_q3s(rnk, a, lt, gt, d + 1, siz, off, access, usr); | |
sort__str_q3s(rnk, a, gt+1, hi, d, siz, off, access, usr); | |
} | |
static int* | |
sort__str_base(int *rnk, int *rnk2, short *oracle, | |
void *a, int n, int siz, int off, int lo, int hi, | |
sort_access_f access, void *usr, int d) { | |
unsigned char * p = a; | |
if (n < 32) { | |
sort__str_q3s(rnk, a, lo, hi, d, siz, off, access, usr); | |
return rnk; | |
} | |
int c[257] = {0}; | |
for (int i = 0; i < n; ++i) | |
oracle[i] = sort__str_at(p + rnk[i] * siz + off, d, access, usr); | |
for (int i = 0; i < n; ++i) | |
++c[oracle[i] + 1]; | |
int idx[257]; idx[0] = idx[1] = 0; | |
for (int i = 1; i < 256; ++i) | |
idx[i+1] = idx[i] + c[i]; | |
for (int i = 0; i < n; ++i) | |
rnk2[idx[oracle[i]+1]++] = rnk[i]; | |
int *tmp = rnk; rnk = rnk2; rnk2 = tmp; | |
int bsum = c[1]; | |
for (int i = 1; i < 256; ++i) { | |
if (c[i + 1] == 0) continue; | |
int *res = sort__str_base(rnk, rnk2, oracle, a, c[i+1], siz, off, bsum, bsum + c[i+1]-1, access, usr, d+1); | |
if (res != rnk) { | |
tmp = rnk; rnk = rnk2; rnk2 = tmp; | |
} | |
bsum += c[i+1]; | |
} | |
return rnk; | |
} | |
static int* | |
sort__str(int *rnk, int *rnk2, short *oracle, | |
void *a, int n, int siz, int off, | |
sort_access_f access, void *usr) { | |
for (int i = 0; i < n; ++i) rnk[i] = i; | |
return sort__str_base(rnk, rnk2, oracle, a, n, siz, off, 0, n-1, access, usr, 0); | |
} | |
/* ---------------------------------------------------------------------------
 * Test
 * ---------------------------------------------------------------------------
 */
#include <stdio.h>
/* unused(a): evaluates a and discards the result. */
#define unused(a) ((void)(a))
#define cast(t, p) ((t)(p))
/* szof(a): sizeof as an int. */
#define szof(a) ((int)sizeof(a))
/* cntof(a): element count of a true array (not of a pointer). */
#define cntof(a) ((int)(sizeof(a) / sizeof((a)[0])))
/* str(s,n): struct str view of the n characters starting at s.
 * NOTE: s is expanded twice, so pass an expression without side effects. */
#define str(s,n) ((struct str){(s), (s) + (n), (n)})
/* strv(s): struct str view of a character array, excluding its terminating NUL. */
#define strv(s) str(s, cntof(s)-1)
/* Demo driver: sorts a small fixed data set and prints it in ascending order.
 * The preprocessor conditions select the int, float or string example. */
int main(void) {
#if 1
  int arr[] = { 256, -36789, 170, 45, 75, 1765987, 90, 802, -24, 2, -66, 17895 };
  /* index buffer is unsigned to match the sorter's parameter type;
   * 2*n entries for the index ping-pong plus 1024 for the histograms */
  unsigned sorted[cntof(arr) * 2 + 1024];
  sort_ints(sorted, arr, cntof(arr));
  for (int i = 0; i < cntof(arr); ++i) {
    printf("%d\n", arr[sorted[i]]);
  }
#elif 0
  float arr[] = { 170.0f, 0.001f, -0.05f, 20.0f, -30.0f, 802.0f, 2.5f, 2000.65f, -12.54f, 66.0f };
  unsigned sorted[cntof(arr) * 2 + 1024];
  sort_flts(sorted, arr, cntof(arr));
  for (int i = 0; i < cntof(arr); ++i) {
    printf("%.3f\n", arr[sorted[i]]);
  }
#else
  struct str arr[] = {
    strv("aaaba"),
    strv("dfjasdlifjai"),
    strv("jiifjeogiejogp"),
    strv("aabaaaa"),
    strv("gsgj"),
    strv("gerph"),
    strv("aaaaaaa"),
    strv("htjltjlrth"),
    strv("joasdjfisdjfdo"),
    strv("hthe"),
    strv("aaaaaba"),
    strv("jrykpjl"),
    strv("hkoptjltp"),
    strv("aaaaaa"),
    strv("lprrjt")
  };
  short tmp3[cntof(arr)];
  int tmp[cntof(arr)], tmp2[cntof(arr)];
  /* two index buffers plus one character buffer, no accessor callback */
  int *idx = sort__str(tmp, tmp2, tmp3, arr, cntof(arr), szof(arr[0]), 0, 0, 0);
  for (int i = 0; i < cntof(arr); ++i) {
    printf("%.*s\n", arr[idx[i]].len, arr[idx[i]].str);
  }
#endif
  return 0;
}
/* Hosting-page footer: Sign up for free to join this conversation on GitHub.
 * Already have an account? Sign in to comment. */