Last active
November 19, 2020 06:08
-
-
Save louchenyao/c7f31255608b47c0687f3f82ec36ccec to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <utility>
// Compare-exchange step of the sorting network: swaps `a` and `b` when
// `a > b`, and leaves them untouched otherwise.
template <typename T>
void cmp(T &a, T &b) {
    if (!(a > b)) {
        return;  // already in the required order
    }
    std::swap(a, b);
}
// Returns the smallest power of two that is >= ceil(v / 2).
//
// Merge uses the result as the size of the power-of-two-sized part when it
// splits a range of arbitrary length.
//
// Special case: v == 0 yields 0 (the decrement wraps to all ones and the
// final increment wraps back to 0).
constexpr uint32_t div2_roundup_to_power_of_two(uint32_t v) {
    // ceil(v / 2), computed without forming v + 1: that sum wraps to 0 for
    // v == UINT32_MAX and would make the function return 0 instead of 2^31.
    v = v / 2 + (v & 1u);
    // Classic round-up-to-power-of-two: copy the highest set bit of (v - 1)
    // into every lower position, then add one.
    v--;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v++;
    return v;
}
template <typename T, int BEGIN, int END, bool REV> | |
struct Merge { | |
static void merge(T *v) { | |
const int N = END - BEGIN; | |
if (N <= 1) return; | |
if (!REV) { | |
const int RIGHT = div2_roundup_to_power_of_two(N); | |
const int LEFT = N - RIGHT; | |
#pragma unroll | |
for (int i = END - 1; i >= END - LEFT; i--) { | |
cmp(v[i-RIGHT], v[i]); | |
} | |
Merge<T, BEGIN, BEGIN + LEFT, REV>::merge(v); | |
Merge<T, BEGIN+LEFT, END, REV>::merge(v); | |
} else { | |
const int LEFT = div2_roundup_to_power_of_two(N); | |
const int RIGHT = N - LEFT; | |
#pragma unroll | |
for (int i = BEGIN; i < BEGIN + RIGHT; i++) { | |
cmp(v[i+LEFT], v[i]); | |
} | |
Merge<T, BEGIN, BEGIN + LEFT, REV>::merge(v); | |
Merge<T, BEGIN+LEFT, END, REV>::merge(v); | |
} | |
} | |
}; | |
// this constructs a bitonic sorter | |
template <typename T, int BEGIN, int END, bool REV> | |
struct ThreadSort { | |
static void sort(T *v) { | |
const int N = END - BEGIN; | |
if (N <= 1) return; | |
const int LEFT = N / 2; | |
const int RIGHT = N - LEFT; | |
ThreadSort<T, BEGIN, BEGIN+LEFT, false>::sort(v); | |
ThreadSort<T, BEGIN+LEFT, END, true>::sort(v); | |
Merge<T, BEGIN, END, REV>::merge(v); | |
} | |
}; | |
template <int N> | |
void test() { | |
for (int CASE = 0; CASE < (1 << N); CASE++) { | |
int a[N]; | |
for (int i = 0; i < N; ++i) { | |
a[i] = (CASE >> i) & 1; | |
} | |
ThreadSort<int, 0, N, false>::sort(a); | |
for (int i = 1; i < N; ++i) { | |
if (a[i] < a[i - 1]) { | |
printf("N = %d, CASE = %d, ERROR\n", N, CASE); | |
exit(1); | |
} | |
} | |
} | |
} | |
int main() { | |
// int v[5] = {1, 7, 3, 2, 4}; | |
// ThreadSort<int, 0, 5, false>::sort(v); | |
// for (int i = 0; i < 5; i++) { | |
// printf("%d ", v[i]); | |
// } | |
// printf("\n"); | |
test<4>(); | |
test<5>(); | |
test<6>(); | |
test<7>(); | |
test<8>(); | |
test<9>(); | |
test<10>(); | |
test<11>(); | |
test<12>(); | |
test<13>(); | |
test<14>(); | |
test<15>(); | |
test<16>(); | |
printf("PASS\n"); | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm> | |
#include <cstdio> | |
#include <cstdlib> | |
// Key/value compare-exchange: when k[i] > k[j], swaps the keys at i and j
// and swaps the values at the same two positions so each value stays with
// its key. Does nothing otherwise.
template <typename K, typename V, int N>
void cmp(K (&k)[N], V (&v)[N], int i, int j) {
    if (!(k[i] > k[j])) {
        return;  // keys already in the required order
    }
    std::swap(k[i], k[j]);
    std::swap(v[i], v[j]);
}
// Returns the smallest power of two that is >= ceil(v / 2).
//
// Merge uses the result as the size of the power-of-two-sized part when it
// splits a range of arbitrary length.
//
// Special case: v == 0 yields 0 (the decrement wraps to all ones and the
// final increment wraps back to 0).
constexpr uint32_t div2_roundup_to_power_of_two(uint32_t v) {
    // ceil(v / 2), computed without forming v + 1: that sum wraps to 0 for
    // v == UINT32_MAX and would make the function return 0 instead of 2^31.
    v = v / 2 + (v & 1u);
    // Classic round-up-to-power-of-two: copy the highest set bit of (v - 1)
    // into every lower position, then add one.
    v--;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v++;
    return v;
}
template <typename K, typename V, int N, int BEGIN, int END, bool REV> | |
struct Merge { | |
static void merge(K (&k)[N], V (&v)[N]) { | |
const int LEN = END - BEGIN; | |
if (LEN <= 1) return; | |
if (!REV) { | |
const int RIGHT = div2_roundup_to_power_of_two(LEN); | |
const int LEFT = LEN - RIGHT; | |
#pragma unroll | |
for (int i = END - 1; i >= END - LEFT; i--) { | |
cmp(k, v, i-RIGHT, i); | |
} | |
Merge<K, V, N, BEGIN, BEGIN + LEFT, REV>::merge(k, v); | |
Merge<K, V, N, BEGIN+LEFT, END, REV>::merge(k, v); | |
} else { | |
const int LEFT = div2_roundup_to_power_of_two(LEN); | |
const int RIGHT = LEN - LEFT; | |
#pragma unroll | |
for (int i = BEGIN; i < BEGIN + RIGHT; i++) { | |
cmp(k, v, i+LEFT, i); | |
} | |
Merge<K, V, N, BEGIN, BEGIN + LEFT, REV>::merge(k, v); | |
Merge<K, V, N, BEGIN+LEFT, END, REV>::merge(k, v); | |
} | |
} | |
}; | |
// this constructs a bitonic sorter | |
template <typename K, typename V, int N, int BEGIN, int END, bool REV> | |
struct ThreadSort { | |
static void sort(K (&k)[N], V (&v)[N]) { | |
const int LEN = END - BEGIN; | |
if (LEN <= 1) return; | |
const int LEFT = LEN / 2; | |
const int RIGHT = LEN - LEFT; | |
ThreadSort<K, V, N, BEGIN, BEGIN+LEFT, false>::sort(k, v); | |
ThreadSort<K, V, N, BEGIN+LEFT, END, true>::sort(k, v); | |
Merge<K, V, N, BEGIN, END, REV>::merge(k, v); | |
} | |
}; | |
template <int N> | |
void test() { | |
for (int CASE = 0; CASE < (1 << N); CASE++) { | |
int k[N]; | |
int v[N]; | |
for (int i = 0; i < N; ++i) { | |
k[i] = (CASE >> i) & 1; | |
v[i] = i; | |
} | |
ThreadSort<int, int, N, 0, N, false>::sort(k, v); | |
for (int i = 1; i < N; ++i) { | |
if (k[i] < k[i - 1]) { | |
printf("N = %d, CASE = %d, ERROR\n", N, CASE); | |
exit(1); | |
} | |
} | |
} | |
} | |
int main() { | |
int k[5] = {1, 7, 3, 2, 4}; | |
char v[5] = {'a', 'b', 'c', 'd', 'e'}; | |
ThreadSort<int, char, 5, 0, 5, false>::sort(k, v); | |
for (int i = 0; i < 5; i++) { | |
printf("(%d %c) ", k[i], v[i]); | |
} | |
printf("\n"); | |
test<4>(); | |
test<5>(); | |
test<6>(); | |
test<7>(); | |
test<8>(); | |
test<9>(); | |
test<10>(); | |
test<11>(); | |
test<12>(); | |
test<13>(); | |
test<14>(); | |
test<15>(); | |
test<16>(); | |
printf("PASS\n"); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment