Created
July 25, 2023 12:11
-
-
Save xen0n/f0b599727a8b54559d385abab02b6f1e to your computer and use it in GitHub Desktop.
The verification code behind LLD's getLoongArchPageDelta
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Ever wondered how the wicked logic in https://github.com/llvm/llvm-project/blob/6084ee742064cf8/lld/ELF/Arch/LoongArch.cpp#L86-L169
// was discovered? This is the experiment I did to find it.
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#include <sys/mman.h>
#include <sys/random.h>
#include <unistd.h>
// Operating mode: define exactly one of these. main() tests them in the order
// CHECK_ONE_CLI, VERIFY_MANY, GENERATE_LLD_TEST_CASES, JIT_VERIFY and uses
// the first one that is defined.
//
//   CHECK_ONE_CLI           -- verify a single "<dest> <pc>" pair from argv
//   VERIFY_MANY             -- verify many random (dest, pc) pairs
//   GENERATE_LLD_TEST_CASES -- print lit test stanzas for ld.lld
//   JIT_VERIFY              -- verify random pairs by executing the real
//                              instruction sequence (LoongArch64 hosts only);
//                              this macro also switches check() from
//                              software emulation to native execution.
//
// #define CHECK_ONE_CLI
// #define VERIFY_MANY
// #define GENERATE_LLD_TEST_CASES
#define JIT_VERIFY

// When true, check() prints an "OK:LUHT...."/"BAD:LUHT...." sign-pattern line
// per case, and the VERIFY_MANY loop keeps going past failures.
constexpr bool show_distrib = false;
// Return 64 random bits obtained from getrandom(2).
//
// getrandom() can fail with EINTR if a signal arrives while it is still
// blocking for the entropy pool to initialize, so that case is retried. Any
// other error -- and any short read, which would leave part of `x`
// uninitialized -- aborts the process: this is a test harness with no
// sensible fallback.
static uint64_t randu64() {
    uint64_t x;
    ssize_t ret;
    do {
        ret = getrandom(&x, sizeof(x), 0);
    } while (ret < 0 && errno == EINTR);
    if (ret != static_cast<ssize_t>(sizeof(x)))
        std::abort();
    return x;
}
// Round `p` down to the start of its 4 KiB page by clearing the low 12 bits.
static uint64_t getLoongArchPage(uint64_t p) {
    constexpr uint64_t kPageOffsetMask = 0xfffull;
    return p - (p & kPageOffsetMask);
}
// Calculate the adjusted page delta between dest and PC.
//
// The result is split into the immediates of a sequence such as
//
//   pcalau12i A, %foo_hi20(sym)       ; bits 12..31 of the result
//   addi.d    T, zero, %foo_lo12(sym) ; low 12 bits of dest
//   lu32i.d   T, %foo64_lo20(sym)     ; bits 32..51 of the result
//   lu52i.d   T, T, %foo64_hi12(sym)  ; bits 52..63 of the result
//   ldx.d     A, A, T
//
// where "pc + hi20" is built separately from the rest, so every signed
// immediate needs compensation:
//
//  * addi.d sign-extends lo12. When bit 11 of dest is set, lo12 acts as
//    (dest & 0xfff) - 0x1000, so hi20 must be one page higher. lu32i.d then
//    keeps only the low 32 bits of that sign-extended value, so T's low half
//    is worth 2^32 more than lo12 itself; take 2^32 back out of the upper
//    half.
//  * pcalau12i sign-extends (hi20 << 12) from bit 31, i.e. subtracts 2^32
//    when bit 31 is set; add 2^32 back through the upper half.
//
// Adding multiples of 2^32 never changes bit 31, so testing bit 31 after the
// first adjustment is the same as testing it before the 2^32 part is applied.
uint64_t getLoongArchPageDelta(uint64_t dest, uint64_t pc) {
    uint64_t delta = (dest & ~0xfffull) - (pc & ~0xfffull);

    if (dest & 0x800)
        delta += 0x1000 - 0x100000000ull;  // wraps mod 2^64 on purpose

    if (delta & 0x80000000)
        delta += 0x100000000ull;

    return delta;
}
// Emulate `pcalau12i rd, imm`: the 4 KiB page of `pc` plus (imm << 12), with
// imm taken as an already sign-extended 32-bit value and the addition wrapping
// modulo 2^64. Logs the step to stderr when `verbose` is set.
static uint64_t pcalau12i(uint64_t pc, int32_t imm, bool verbose) {
    const uint64_t page = pc & ~0xfffull;
    const uint64_t offset = static_cast<uint64_t>(imm) << 12;
    const uint64_t result = page + offset;
    if (verbose) {
        fprintf(stderr, "pcalau12i pc=0x%016lx, %d => 0x%016lx\n", pc, imm, result);
    }
    return result;
}
// Emulate `lu32i.d rd, imm`: keep rd[31:0], load the low 20 bits of imm into
// rd[51:32], and set rd[63:52] to all ones when imm is negative (all zeros
// otherwise). Logs the step to stderr when `verbose` is set.
static uint64_t lu32i(uint64_t rd, int32_t imm, bool verbose) {
    const uint64_t keptLow = rd & 0xffffffffull;
    uint64_t upper = (static_cast<uint64_t>(imm) & 0xfffffull) << 32;
    if (imm < 0) {
        upper |= 0xfff0000000000000ull;
    }
    const uint64_t result = upper | keptLow;
    if (verbose) {
        fprintf(stderr, "lu32i.d 0x%016lx, %d => 0x%016lx\n", rd, imm, result);
    }
    return result;
}
// Emulate `lu52i.d rd, rj, imm`: keep rj[51:0] and replace bits 52..63 with
// the low 12 bits of imm. Logs the step to stderr when `verbose` is set.
static uint64_t lu52i(uint64_t rdrj, int32_t imm, bool verbose) {
    constexpr uint64_t kLow52Mask = (1ull << 52) - 1;
    const uint64_t top12 = static_cast<uint64_t>(imm & 0xfff) << 52;
    const uint64_t result = top12 | (rdrj & kLow52Mask);
    if (verbose) {
        fprintf(stderr, "lu52i.d 0x%016lx, %d => 0x%016lx\n", rdrj, imm, result);
    }
    return result;
}
/* | |
static int32_t extractSImm(uint64_t x, int lsb, int msb) { | |
int width = msb - lsb + 1; | |
int32_t val = (x >> lsb) & ((1 << width) - 1); | |
int32_t sign = 1 << (width - 1); | |
return val >= sign ? val - (1 << width) : val; | |
} | |
*/ | |
// Return '-' for a negative value and '+' otherwise (zero is reported as '+').
static char sign(int32_t x) {
    return x < 0 ? '-' : '+';
}
// Overlay of a 64-bit value with the four signed immediate fields used by the
// pcalau12i + addi.d + lu32i.d + lu52i.d sequence:
//
//   l: bits  0..11 (%pc_lo12)    u: bits 12..31 (%pc_hi20)
//   h: bits 32..51 (%pc64_lo20)  t: bits 52..63 (%pc64_hi12)
//
// Reading a field yields its value sign-extended, which is how the
// instructions interpret the immediates.
// NOTE(review): these bit positions assume GCC/Clang bit-field allocation on
// a little-endian target (LoongArch64, x86-64). Bit-field layout is
// implementation-defined, union type punning relies on compiler support in
// C++, and the anonymous struct is a compiler extension in C++.
union luht {
    uint64_t v;
    struct {
        int32_t l : 12, u : 20;  // low 32-bit word
        int32_t h : 20, t : 12;  // high 32-bit word
    };
};
static int luhtSignsFromDiff(uint64_t val, uint64_t dest) { | |
union luht x = { .v = val | (dest & 0xfff) }; | |
int a = 0; | |
if (x.l < 0) | |
a |= 0x1; | |
if (x.u < 0) | |
a |= 0x2; | |
if (x.h < 0) | |
a |= 0x4; | |
if (x.t < 0) | |
a |= 0x8; | |
return a; | |
} | |
// Signature of the generated checker routine: no arguments, returns the
// 64-bit address computed by the instruction sequence.
using jit_check_fn_t = uint64_t (*)(void);
// Executable page holding the checker routine; main() maps it in JIT_VERIFY
// mode. Null until then.
jit_check_fn_t codebuf = nullptr;
static void prepare_codebuf(uint32_t *code, union luht offset) | |
{ | |
// 0000000000000000 <checker>: | |
// 0: 1a000004 pcalau12i $a0, 0 | |
// 4: 02c00005 li.d $a1, 0 | |
// 8: 16000005 lu32i.d $a1, 0 | |
// c: 030000a5 lu52i.d $a1, $a1, 0 | |
// 10: 00109484 add.d $a0, $a0, $a1 | |
// 14: 4c000020 ret | |
*code++ = 0x1a000004 | (offset.u & 0xfffff) << 5; // DSj20 | |
*code++ = 0x02c00005 | (offset.l & 0xfff) << 10; // DJSk12 | |
*code++ = 0x16000005 | (offset.h & 0xfffff) << 5; // DSj20 | |
*code++ = 0x030000a5 | (offset.t & 0xfff) << 10; // DJSk12 | |
*code++ = 0x00109484; | |
*code++ = 0x4c000020; | |
asm volatile("ibar 0":::"memory"); | |
} | |
static bool check(uint64_t dest, uint64_t pc, bool verbose = false) { | |
uint64_t pageDelta = getLoongArchPageDelta(dest, pc); | |
/* | |
int32_t lo12 = extractSImm(dest, 0, 11); | |
int32_t hi20 = extractSImm(pageDelta, 12, 31); | |
int32_t higher = extractSImm(pageDelta, 32, 51); | |
int32_t top = extractSImm(pageDelta, 52, 63); | |
*/ | |
union luht deltaLUHT = { .v = pageDelta | (dest & 0xfff) }; | |
if (verbose) | |
fprintf( | |
stderr, | |
" pc = 0x%016lx\n delta = 0x%016lx\n l = %d\n u = %d\n h = %d\n t = %d\n l = 0x%03x\n u = 0x%05x\n h = 0x%05x\n t = 0x%03x\n", | |
pc, | |
pageDelta, | |
deltaLUHT.l, | |
deltaLUHT.u, | |
deltaLUHT.h, | |
deltaLUHT.t, | |
deltaLUHT.l & 0xfff, | |
deltaLUHT.u & 0xfffff, | |
deltaLUHT.h & 0xfffff, | |
deltaLUHT.t & 0xfff | |
); | |
#ifdef JIT_VERIFY | |
prepare_codebuf(reinterpret_cast<uint32_t *>(codebuf), deltaLUHT); | |
uint64_t actual = codebuf(); | |
#else | |
uint64_t a = pcalau12i(pc, deltaLUHT.u, verbose); | |
uint64_t t = lu32i((uint64_t)deltaLUHT.l, deltaLUHT.h, verbose); | |
t = lu52i(t, deltaLUHT.t, verbose); | |
uint64_t actual = a + t; | |
#endif | |
if (actual != dest && !verbose) { | |
if (show_distrib) | |
printf("BAD:LUHT%c%c%c%c\n", sign(deltaLUHT.l), sign(deltaLUHT.u), sign(deltaLUHT.h), sign(deltaLUHT.t)); | |
else | |
fprintf( | |
stderr, | |
"ERROR: actual 0x%016lx\n != dest 0x%016lx\n", | |
actual, | |
dest | |
); | |
check(dest, pc, true); | |
} else if (show_distrib) { | |
printf("OK:LUHT%c%c%c%c\n", sign(deltaLUHT.l), sign(deltaLUHT.u), sign(deltaLUHT.h), sign(deltaLUHT.t)); | |
} | |
return actual == dest; | |
} | |
static void printLLDTestCase(uint64_t dest, uint64_t pc) { | |
uint64_t pageDelta = getLoongArchPageDelta(dest, pc); | |
union luht deltaLUHT = { .v = pageDelta | (dest & 0xfff) }; | |
int caseIdx = luhtSignsFromDiff(pageDelta, dest); | |
printf("## page delta = 0x%016lx, page offset = 0x%03lx\n", pageDelta, dest & 0xfff); | |
printf("## %%pc_lo12 = 0x%03x = %d\n", deltaLUHT.l & 0xfff, deltaLUHT.l); | |
printf("## %%pc_hi20 = 0x%05x = %d\n", deltaLUHT.u & 0xfffff, deltaLUHT.u); | |
printf("## %%pc64_lo20 = 0x%05x = %d\n", deltaLUHT.h & 0xfffff, deltaLUHT.h); | |
printf("## %%pc64_hi12 = 0x%03x = %d\n", deltaLUHT.t & 0xfff, deltaLUHT.t); | |
printf("# RUN: ld.lld %%t/extreme.o --section-start=.rodata=0x%016lx --section-start=.text=0x%016lx -o %%t/extreme%d\n", dest, pc, caseIdx); | |
printf("# RUN: llvm-objdump -d --no-show-raw-insn %%t/extreme%d | FileCheck %%s --check-prefix=EXTREME%d\n", caseIdx, caseIdx); | |
printf("# EXTREME%d: addi.d $t0, $zero, %d\n", caseIdx, deltaLUHT.l); | |
printf("# EXTREME%d-NEXT: pcalau12i $t1, %d\n", caseIdx, deltaLUHT.u); | |
printf("# EXTREME%d-NEXT: lu32i.d $t0, %d\n", caseIdx, deltaLUHT.h); | |
printf("# EXTREME%d-NEXT: lu52i.d $t0, $t0, %d\n", caseIdx, deltaLUHT.t); | |
//uint64_t a = pcalau12i(pc, deltaLUHT.u, true); | |
//uint64_t t = lu32i((uint64_t)deltaLUHT.l, deltaLUHT.h, true); | |
//t = lu52i(t, deltaLUHT.t, true); | |
printf("\n"); | |
} | |
static void printLLDTestCaseFromLUHTSigns(int luhtSigns) { | |
const uint64_t pc = 0x12345678; | |
union luht delta = { | |
.l = (luhtSigns & 0x1) ? 0x888 : 0x111, | |
.u = (luhtSigns & 0x2) ? 0x99999 : 0x22222, | |
.h = (luhtSigns & 0x4) ? 0xaaaaa : 0x33333, | |
.t = (luhtSigns & 0x8) ? 0xbbb : 0x444, | |
}; | |
uint64_t pageDiff = getLoongArchPage(delta.v); | |
uint64_t dest = (getLoongArchPage(pc) + pageDiff) | (delta.l & 0xfff); | |
printLLDTestCase(dest, pc); | |
} | |
// Parse `x` as a hexadecimal uint64_t. strtoull accepts an optional "0x"/"0X"
// prefix and leading whitespace. On malformed input, print a diagnostic
// prefixed with `argv0` and exit with status 2. Malformed means:
//   * no digits at all (e.g. the empty string),
//   * trailing characters after the number,
//   * a value that does not fit in 64 bits.
static uint64_t parse_u64(const char *argv0, const char *x) {
    char *endptr = nullptr;
    errno = 0;  // strtoull reports overflow only through errno
    const unsigned long long value = strtoull(x, &endptr, 16);
    if (endptr == x || *endptr != '\0' || errno == ERANGE) {
        fprintf(stderr, "%s: invalid hexadecimal u64: %s\n", argv0, x);
        exit(2);
    }
    return static_cast<uint64_t>(value);
}
// One verification input: the target address and the PC of the pcalau12i
// that starts the instruction sequence.
struct test_input {
    uint64_t dest;  // address the sequence must reconstruct
    uint64_t pc;    // address of the pcalau12i instruction
};
static struct test_input generate_one_case_pc(uint64_t pc) { | |
struct test_input ret = { | |
.dest = randu64(), | |
.pc = pc, | |
}; | |
return ret; | |
} | |
static struct test_input generate_one_case() { | |
return generate_one_case_pc(randu64() & ~0b11ull); | |
} | |
static uint64_t getLoongArchPageDelta(struct test_input testcase) { | |
return getLoongArchPageDelta(testcase.dest, testcase.pc); | |
} | |
int main(int argc, const char *argv[]) | |
{ | |
#if defined(CHECK_ONE_CLI) | |
if (argc != 3) { | |
fprintf(stderr, "usage: %s <0xdest> <0xpc>\n", argv[0]); | |
return 2; | |
} | |
uint64_t dest = parse_u64(argv[0], argv[1]); | |
uint64_t pc = parse_u64(argv[0], argv[2]); | |
// return check(0x900'10000'abcde'f00, 0x800'20345'ba987'654, true) ? 0 : 1; | |
return check(dest, pc, true) ? 0 : 1; | |
#elif defined(VERIFY_MANY) | |
for (int i = 0; i < 500000; i++) { | |
struct test_input t = generate_one_case(); | |
if (!check(t.dest, t.pc)) | |
if (show_distrib) | |
continue; | |
else | |
return 1; | |
} | |
return 0; | |
#elif defined(GENERATE_LLD_TEST_CASES) | |
// no randomness for cleaner output | |
#if 0 | |
struct test_input inputs[16] = {}; | |
int casesSeen = 0; | |
while (casesSeen < 16) { | |
struct test_input x = generate_one_case(); | |
uint64_t pd = getLoongArchPageDelta(x); | |
int caseIdx = luhtSignsFromDiff(pd, x.dest); | |
if (!inputs[caseIdx].dest) { | |
inputs[caseIdx] = x; | |
casesSeen++; | |
} | |
} | |
for (struct test_input x : inputs) { | |
printLLDTestCase(x.dest, x.pc); | |
} | |
#else | |
for (int luhtSigns = 0; luhtSigns < 16; luhtSigns++) { | |
printLLDTestCaseFromLUHTSigns(luhtSigns); | |
} | |
#endif | |
return 0; | |
#elif defined(JIT_VERIFY) | |
size_t pagesize = sysconf(_SC_PAGE_SIZE); | |
codebuf = (jit_check_fn_t)mmap(NULL, pagesize, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_ANON | MAP_PRIVATE, -1, 0); | |
if (codebuf == (void *)-1) | |
return 100; | |
for (int i = 0; i < 500000; i++) { | |
struct test_input t = generate_one_case_pc((uint64_t)codebuf); | |
if (!check(t.dest, t.pc)) | |
return 2; | |
} | |
return 0; | |
#else | |
#error plz define at least one operating mode | |
#endif | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment