Last active
May 31, 2021 20:13
-
-
Save xu-cheng/b2cea7111e7efc818497474970089fef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Ref: | |
// https://github.com/mobilecoinofficial/mc-oblivious/blob/master/aligned-cmov/src/cmov_impl_asm.rs | |
#include <cassert> | |
#include <cstddef> | |
#include <cstdint> | |
// Conditional move of one 64-bit word: if `condition` is true, *dst takes the
// value of *src; otherwise *dst is rewritten with its own old value.
// The selection is done with TEST + CMOVNZ inside the asm, so the asm itself
// contains no branch on `condition`.
inline __attribute__((always_inline)) void cmov_u64(const bool condition,
                                                    const uint64_t* src,
                                                    uint64_t* dst)
{
    uint64_t result = *dst;
    asm volatile("test %[cond], %[cond]\n\t"
                 "cmovnz %[src_val], %[acc]"
                 : [acc] "+&r"(result)
                 : [cond] "r"(condition), [src_val] "rm"(*src)
                 : "cc");
    *dst = result;
}
// Conditional move of one 32-bit word: if `condition` is true, *dst takes the
// value of *src; otherwise *dst is rewritten with its own old value.
// The selection is done with TEST + CMOVNZ inside the asm, so the asm itself
// contains no branch on `condition`.
inline __attribute__((always_inline)) void cmov_u32(const bool condition,
                                                    const uint32_t* src,
                                                    uint32_t* dst)
{
    uint32_t result = *dst;
    asm volatile("test %[cond], %[cond]\n\t"
                 "cmovnz %[src_val], %[acc]"
                 : [acc] "+&r"(result)
                 : [cond] "r"(condition), [src_val] "rm"(*src)
                 : "cc");
    *dst = result;
}
// Conditional move of one 16-bit word: if `condition` is true, *dst takes the
// value of *src; otherwise *dst is rewritten with its own old value.
// The selection is done with TEST + CMOVNZ inside the asm, so the asm itself
// contains no branch on `condition`.
inline __attribute__((always_inline)) void cmov_u16(const bool condition,
                                                    const uint16_t* src,
                                                    uint16_t* dst)
{
    uint16_t result = *dst;
    asm volatile("test %[cond], %[cond]\n\t"
                 "cmovnz %[src_val], %[acc]"
                 : [acc] "+&r"(result)
                 : [cond] "r"(condition), [src_val] "rm"(*src)
                 : "cc");
    *dst = result;
}
// cmov 8 bits data | |
inline __attribute__((always_inline)) void cmov_u8(const bool condition, | |
const uint8_t* src, | |
uint8_t* dst) | |
{ | |
uint64_t tmp1 = static_cast<uint64_t>(*src); | |
uint64_t tmp2 = static_cast<uint64_t>(*dst); | |
cmov_u64(condition, &tmp1, &tmp2); | |
*dst = static_cast<uint8_t>(tmp2); | |
} | |
// Returns true when the address of `pointer` is a multiple of `byte_count`.
// `byte_count` is used as a divisor, so it must be non-zero.
inline __attribute__((always_inline)) bool is_aligned(const void* pointer,
                                                      size_t byte_count)
{
    const uintptr_t address = reinterpret_cast<uintptr_t>(pointer);
    return !(address % byte_count);
}
// cmov byte array | |
inline __attribute__((always_inline)) void cmov_bytes(const bool condition, | |
const uint8_t* src, | |
uint8_t* dst, size_t len) | |
{ | |
#ifdef __AVX2__ | |
// move in 512 bits unit (if aligned) | |
if (len >= 64 && is_aligned((void*)src, 64) && is_aligned((void*)dst, 64)) { | |
size_t moves = len / 64; | |
size_t move_len = moves * 64; | |
size_t move_len2 = move_len; | |
uint64_t tmp = static_cast<uint64_t>(condition); | |
asm volatile( | |
R"( | |
neg %0 | |
vmovq %0, %%xmm2 | |
vbroadcastsd %%xmm2, %%ymm1 | |
mov %3, %0 | |
loop_%=: | |
vmovdqa -64(%1, %0), %%ymm2 | |
vpmaskmovq %%ymm2, %%ymm1, -64(%2, %0) | |
vmovdqa -32(%1, %0), %%ymm3 | |
vpmaskmovq %%ymm3, %%ymm1, -32(%2, %0) | |
sub $64, %0 | |
jnz loop_%= | |
)" | |
: "+&r"(tmp) | |
: "r"(src), "r"(dst), "rmi"(move_len2) | |
: "cc", "memory", "xmm2", "ymm1", "ymm2", "ymm3"); | |
src += move_len; | |
dst += move_len; | |
len -= move_len; | |
} | |
#endif | |
assert(is_aligned((void*)src, 8) && "src need to be 8 bytes aligned"); | |
assert(is_aligned((void*)dst, 8) && "dst need to be 8 bytes aligned"); | |
// move in 64 bits unit | |
if (len >= 8) { | |
size_t moves = len / 8; | |
size_t move_len = moves * 8; | |
uint64_t tmp = static_cast<uint64_t>(condition); | |
asm volatile( | |
R"( | |
neg %0 | |
loop_%=: | |
mov -8(%3, %1, 8), %0 | |
cmovc -8(%2, %1, 8), %0 | |
mov %0, -8(%3, %1, 8) | |
dec %1 | |
jnz loop_%= | |
)" | |
: "+&r"(tmp), "+&r"(moves) | |
: "r"(src), "r"(dst) | |
: "cc", "memory"); | |
src += move_len; | |
dst += move_len; | |
len -= move_len; | |
} | |
// move in 32 bits unit | |
if (len >= 4) { | |
uint32_t tmp; | |
asm volatile( | |
R"( | |
test %1, %1 | |
mov (%3), %0 | |
cmovnz (%2), %0 | |
mov %0, (%3) | |
)" | |
: "+&r"(tmp) | |
: "r"(condition), "r"(src), "r"(dst) | |
: "cc", "memory"); | |
src += 4; | |
dst += 4; | |
len -= 4; | |
} | |
// move in 16 bits unit | |
if (len >= 2) { | |
uint16_t tmp; | |
asm volatile( | |
R"( | |
test %1, %1 | |
mov (%3), %0 | |
cmovnz (%2), %0 | |
mov %0, (%3) | |
)" | |
: "+&r"(tmp) | |
: "r"(condition), "r"(src), "r"(dst) | |
: "cc", "memory"); | |
src += 2; | |
dst += 2; | |
len -= 2; | |
} | |
// move in 8 bits unit | |
if (len >= 1) { | |
cmov_u8(condition, src, dst); | |
} | |
} | |
/////////////////// TEST ///////////////// | |
#include <cstring> | |
#include <iomanip> | |
#include <iostream> | |
#include <string> | |
#include <vector> | |
using namespace std;
// Overall test outcome; main() converts it into the process exit status.
static bool FAILED{false};
// A 4096-byte payload aligned to 64 bytes. main() uses it to test cmov_bytes
// on 64-byte-aligned memory.
struct alignas(64) Foo {
    uint8_t data[4096];

    // Set every byte of the payload to `value`.
    Foo(uint8_t value) { memset(data, value, sizeof data); }

    // Two Foo objects are equal when their payloads match byte for byte.
    bool operator==(const Foo& other) const
    {
        return memcmp(data, other.data, sizeof data) == 0;
    }
};
// Streams a byte vector as "[a, b, c]". Each byte is inserted as an int, so
// the stream's current base (e.g. after `hex`) applies. Everything,
// separators included, goes to `out`; the previous version wrote ", " to
// cout regardless of the target stream.
ostream& operator<<(ostream& out, const vector<uint8_t>& v)
{
    out << "[";
    const char* separator = "";
    for (uint8_t x : v) {
        out << separator << static_cast<int>(x);
        separator = ", ";
    }
    out << "]";
    return out;
}
ostream& operator<<(ostream& out, const Foo& v) | |
{ | |
bool flag = false; | |
out << "["; | |
for (uint8_t x : v.data) { | |
if (flag) { | |
cout << ", "; | |
} else { | |
flag = true; | |
} | |
out << (int)x; | |
} | |
out << "]"; | |
return out; | |
} | |
template <class T> | |
void check(const string& msg, const T& lhs, const T& rhs) | |
{ | |
ios::fmtflags flags(cout.flags()); | |
cout << msg << "\t"; | |
if (lhs == rhs) { | |
cout << "[PASS]" << endl; | |
} else { | |
cout << "[FAIL]" << endl; | |
cout << hex; | |
cout << "left:" << lhs << endl; | |
cout << "right:" << rhs << endl; | |
} | |
cout.flags(flags); | |
} | |
int main() | |
{ | |
{ | |
uint8_t v1 = 0xff; | |
uint8_t v2 = 0xee; | |
cmov_u8(false, &v1, &v2); | |
check("cmov_u8 false", v2, uint8_t(0xee)); | |
cmov_u8(true, &v1, &v2); | |
check("cmov_u8 true", v2, uint8_t(0xff)); | |
} | |
{ | |
uint16_t v1 = 0xffff; | |
uint16_t v2 = 0xeeee; | |
cmov_u16(false, &v1, &v2); | |
check("cmov_u16 false", v2, uint16_t(0xeeee)); | |
cmov_u16(true, &v1, &v2); | |
check("cmov_u16 true", v2, uint16_t(0xffff)); | |
} | |
{ | |
uint32_t v1 = 0xffffffff; | |
uint32_t v2 = 0xeeeeeeee; | |
cmov_u32(false, &v1, &v2); | |
check("cmov_u32 false", v2, uint32_t(0xeeeeeeee)); | |
cmov_u32(true, &v1, &v2); | |
check("cmov_u32 true", v2, uint32_t(0xffffffff)); | |
} | |
{ | |
uint64_t v1 = 0xffffffffffffffff; | |
uint64_t v2 = 0xeeeeeeeeeeeeeeee; | |
cmov_u64(false, &v1, &v2); | |
check("cmov_u64 false", v2, uint64_t(0xeeeeeeeeeeeeeeee)); | |
cmov_u64(true, &v1, &v2); | |
check("cmov_u64 true", v2, uint64_t(0xffffffffffffffff)); | |
} | |
{ | |
for (size_t i : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, | |
10, 16, 17, 64, 128, 256, 512, 1024, 2048, 3000}) { | |
string label("cmov_bytes " + to_string(i)); | |
vector<uint8_t> v1(i, 0xff); | |
vector<uint8_t> v2(i, 0xee); | |
cmov_bytes(false, v1.data(), v2.data(), i); | |
check(label + " false", v2, vector<uint8_t>(i, 0xee)); | |
cmov_bytes(true, v1.data(), v2.data(), i); | |
check(label + " true", v2, vector<uint8_t>(i, 0xff)); | |
} | |
} | |
{ | |
Foo v1(0xff); | |
Foo v2(0xee); | |
cmov_bytes(false, (uint8_t*)&v1, (uint8_t*)&v2, sizeof(Foo)); | |
check("cmov_bytes 64B aligned false", v2, Foo(0xee)); | |
cmov_bytes(true, (uint8_t*)&v1, (uint8_t*)&v2, sizeof(Foo)); | |
check("cmov_bytes 64B aligned true", v2, Foo(0xff)); | |
} | |
return FAILED ? 1 : 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment