Skip to content

Instantly share code, notes, and snippets.

@xen0n
Created July 25, 2023 12:11
Show Gist options
  • Save xen0n/f0b599727a8b54559d385abab02b6f1e to your computer and use it in GitHub Desktop.
Save xen0n/f0b599727a8b54559d385abab02b6f1e to your computer and use it in GitHub Desktop.
The verification code behind LLD's getLoongArchPageDelta
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Ever wondered how the wicked logic in https://github.com/llvm/llvm-project/blob/6084ee742064cf8/lld/ELF/Arch/LoongArch.cpp#L86-L169
// got discovered? This is the experiment I've done...
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <unistd.h>
#include <sys/mman.h>
#include <sys/random.h>
// Operating mode: enable exactly one of these. main() tests them in the order
// CHECK_ONE_CLI, VERIFY_MANY, GENERATE_LLD_TEST_CASES, JIT_VERIFY, uses the
// first one defined, and refuses to build if none is.
//   CHECK_ONE_CLI            check one <0xdest> <0xpc> pair given on argv
//   VERIFY_MANY              check 500000 random pairs with the software
//                            emulation (leave JIT_VERIFY off: check() runs the
//                            JIT buffer whenever JIT_VERIFY is defined)
//   GENERATE_LLD_TEST_CASES  print lit RUN/CHECK lines for the 16 sign cases
//   JIT_VERIFY               assemble the sequence and execute it for real
// #define CHECK_ONE_CLI
// #define VERIFY_MANY
// #define GENERATE_LLD_TEST_CASES
#define JIT_VERIFY

// When true, check() tags every case with its LUHT sign pattern on stdout
// (OK:/BAD:LUHT....) instead of reporting failures as errors. VERIFY_MANY
// also keeps going after a failure so the whole distribution gets tallied.
constexpr bool show_distrib = false;
// Return 64 random bits read from the kernel via getrandom(2).
// getrandom() may legitimately return fewer bytes than requested, so keep
// reading until the whole word is filled; any error aborts the program.
static uint64_t randu64() {
  uint64_t x = 0;
  unsigned char *buf = reinterpret_cast<unsigned char *>(&x);
  size_t filled = 0;
  while (filled < sizeof(x)) {
    ssize_t ret = getrandom(buf + filled, sizeof(x) - filled, 0);
    if (ret < 0)
      std::abort();
    filled += static_cast<size_t>(ret);
  }
  return x;
}
// Round an address down to the start of its LoongArch "page" by clearing the
// low 12 bits (the part a %pc_lo12 immediate supplies).
static uint64_t getLoongArchPage(uint64_t p) {
  const uint64_t pageOffset = p & 0xfff;
  return p - pageOffset;
}
// Calculate the adjusted page delta between dest and PC: the value whose
// hi20 (bits 31:12), lo20 (bits 51:32) and hi12 (bits 63:52) fields, together
// with dest's low 12 bits used as a sign-extended lo12, rebuild dest from pc.
//
// Two sign-extensions have to be compensated:
//  * lo12: the instructions consuming it (ld, st, addi, etc.) sign-extend the
//    immediate, so a page offset >= 0x800 subtracts 0x1000; pre-add one page.
//  * hi20: in patterns like
//
//        pcalau12i A, %foo_hi20(sym)
//        addi.d    T, zero, %foo_lo12(sym)
//        lu32i.d   T, %foo64_lo20(sym)
//        lu52i.d   T, T, %foo64_hi12(sym)
//        ldx.d     A, A, T
//
//    "pc + hi20" is built separately from T (upper 32 bits + lo12), and the
//    two are combined by a signed 64-bit add (ldx/stx). A negative hi20 makes
//    pcalau12i's result 2^32 too small (its 32-bit value is sign-extended).
//    A negative lo12 makes T 2^32 too large: lu32i.d overwrites the ones
//    addi.d sign-extended into the upper half, so T's low word acts as the
//    unsigned 2^32 + lo12. When both happen the errors cancel. Otherwise the
//    upper half of the delta needs +2^32 (only hi20 negative) or -2^32 (only
//    lo12 negative). A plain "signs differ => always add 2^32" is therefore
//    wrong: the direction depends on which of the two is negative.
uint64_t getLoongArchPageDelta(uint64_t dest, uint64_t pc) {
  // Bit 11 set <=> page offset in 0x800..0xfff <=> lo12 reads as negative.
  const bool lo12Negative = (dest & 0x800) != 0;
  uint64_t delta = getLoongArchPage(dest) - getLoongArchPage(pc);
  if (lo12Negative)
    delta += 0x1000;
  // Bit 31 is hi20's sign. The +/-2^32 nudge below cannot change it.
  const bool hi20Negative = (delta & 0x80000000) != 0;
  if (lo12Negative != hi20Negative)
    delta += hi20Negative ? 0x100000000ull          // +2^32
                          : 0xffffffff00000000ull;  // -2^32 (mod 2^64)
  return delta;
}
// Emulate `pcalau12i rd, si20`: rd = (pc & ~0xfff) + SignExtend(si20 << 12).
// Only the low 20 bits of `imm` are used, interpreted as a signed 20-bit
// field, which is exactly how prepare_codebuf() encodes the immediate into the
// real instruction. Values already in -0x80000..0x7ffff are unaffected by this.
// If `verbose`, the operands and the result are logged to stderr.
static uint64_t pcalau12i(uint64_t pc, int32_t imm, bool verbose) {
  // Sign-extend the 20-bit field: flip its sign bit, then subtract its weight.
  const int32_t si20 = ((imm & 0xfffff) ^ 0x80000) - 0x80000;
  const uint64_t page = pc & ~static_cast<uint64_t>(0xfff);
  // Converting the negative int64 to uint64 keeps the sign extension.
  uint64_t result = page + (static_cast<uint64_t>(static_cast<int64_t>(si20)) << 12);
  if (verbose)
    fprintf(stderr, "pcalau12i pc=0x%016lx, %d => 0x%016lx\n", pc, imm, result);
  return result;
}
// Emulate `lu32i.d rd, si20`: keep bits [31:0] of rd, put si20 into bits
// [51:32] and fill bits [63:52] with si20's sign.
// Only the low 20 bits of `imm` are used, interpreted as a signed 20-bit
// field, matching how prepare_codebuf() encodes the real instruction. Values
// already in -0x80000..0x7ffff are unaffected by this.
// If `verbose`, the operands and the result are logged to stderr.
static uint64_t lu32i(uint64_t rd, int32_t imm, bool verbose) {
  // Sign-extend the 20-bit field: flip its sign bit, then subtract its weight.
  const int32_t si20 = ((imm & 0xfffff) ^ 0x80000) - 0x80000;
  // Shifting the sign-extended value left by 32 yields bits [63:32] at once.
  const uint64_t upper = static_cast<uint64_t>(static_cast<int64_t>(si20)) << 32;
  uint64_t result = upper | (rd & 0xffffffffull);
  if (verbose)
    fprintf(stderr, "lu32i.d 0x%016lx, %d => 0x%016lx\n", rd, imm, result);
  return result;
}
// Emulate `lu52i.d rd, rj, si12`: the result keeps bits [51:0] of rj (`rdrj`)
// and takes bits [63:52] from the low 12 bits of `imm`.
// If `verbose`, the operands and the result are logged to stderr.
static uint64_t lu52i(uint64_t rdrj, int32_t imm, bool verbose) {
  const uint64_t low52Mask = (1ull << 52) - 1;
  // Shifting the 64-bit value left by 52 keeps only imm's low 12 bits.
  const uint64_t top12 = static_cast<uint64_t>(imm) << 52;
  const uint64_t result = top12 | (rdrj & low52Mask);
  if (verbose)
    fprintf(stderr, "lu52i.d 0x%016lx, %d => 0x%016lx\n", rdrj, imm, result);
  return result;
}
/*
static int32_t extractSImm(uint64_t x, int lsb, int msb) {
int width = msb - lsb + 1;
int32_t val = (x >> lsb) & ((1 << width) - 1);
int32_t sign = 1 << (width - 1);
return val >= sign ? val - (1 << width) : val;
}
*/
// Character used in the OK:/BAD:LUHT tags: '-' for a negative immediate,
// '+' otherwise (zero counts as '+').
static char sign(int32_t x) {
  return x >= 0 ? '+' : '-';
}
// Overlay splitting a 64-bit value into the four immediates that materialize a
// PC-relative address (the names match the ## lines printed by
// printLLDTestCase and the instructions patched by prepare_codebuf):
//   l -> %pc_lo12   (addi.d / li.d si12)
//   u -> %pc_hi20   (pcalau12i si20)
//   h -> %pc64_lo20 (lu32i.d si20)
//   t -> %pc64_hi12 (lu52i.d si12)
// Callers write `v` (page delta | page offset) and read the fields. They are
// signed bit-fields, so each read yields the sign-extended immediate (callers
// test `< 0` on them).
// NOTE(review): relies on implementation-defined details -- LSB-first
// bit-field allocation on a little-endian target (l = bits 11:0, u = 31:12,
// h = 51:32, t = 63:52), an anonymous struct inside a union (a GNU extension
// in C++), and reading a union member other than the one last written. This
// holds for GCC/Clang on LoongArch64 but is not portable in general.
union luht {
uint64_t v;
struct {
int32_t l: 12; // %pc_lo12
int32_t u: 20; // %pc_hi20
int32_t h: 20; // %pc64_lo20
int32_t t: 12; // %pc64_hi12
};
};
static int luhtSignsFromDiff(uint64_t val, uint64_t dest) {
union luht x = { .v = val | (dest & 0xfff) };
int a = 0;
if (x.l < 0)
a |= 0x1;
if (x.u < 0)
a |= 0x2;
if (x.h < 0)
a |= 0x4;
if (x.t < 0)
a |= 0x8;
return a;
}
// Signature of the generated checker: no arguments, returns the 64-bit value
// left in $a0 (the address the instruction sequence materialized).
typedef uint64_t jit_check_fn(void);
typedef jit_check_fn *jit_check_fn_t;
// Executable buffer holding the generated checker; mapped by main() in
// JIT_VERIFY mode, null otherwise.
jit_check_fn_t codebuf;
// Assemble the checker into `code` with `offset`'s LUHT immediates patched in,
// then execute `ibar 0` (instruction barrier) so the freshly written code is
// what gets fetched when codebuf() is called. Template (objdump with zero
// immediates):
//   0: 1a000004  pcalau12i $a0, 0         <- u
//   4: 02c00005  li.d      $a1, 0         <- l (addi.d $a1, $zero, l)
//   8: 16000005  lu32i.d   $a1, 0         <- h
//   c: 030000a5  lu52i.d   $a1, $a1, 0    <- t
//  10: 00109484  add.d     $a0, $a0, $a1
//  14: 4c000020  ret
// `code` must point at writable memory with room for the 6 instructions.
static void prepare_codebuf(uint32_t *code, union luht offset)
{
  // DSj20 format: si20 in bits [24:5]. DJSk12 format: si12 in bits [21:10].
  const uint32_t insns[] = {
    0x1a000004u | (offset.u & 0xfffff) << 5,  // pcalau12i $a0, u
    0x02c00005u | (offset.l & 0xfff) << 10,   // li.d $a1, l
    0x16000005u | (offset.h & 0xfffff) << 5,  // lu32i.d $a1, h
    0x030000a5u | (offset.t & 0xfff) << 10,   // lu52i.d $a1, $a1, t
    0x00109484u,                              // add.d $a0, $a0, $a1
    0x4c000020u,                              // ret
  };
  for (unsigned i = 0; i < sizeof(insns) / sizeof(insns[0]); i++)
    code[i] = insns[i];
  asm volatile("ibar 0" ::: "memory");
}
// Check that the page delta from getLoongArchPageDelta(), combined with the
// low 12 bits of `dest`, makes the sequence
//   pcalau12i + li.d (addi.d) + lu32i.d + lu52i.d + add.d
// produce `dest` when its pcalau12i sits at `pc`.
//
// With JIT_VERIFY the sequence is assembled into `codebuf` and executed. `pc`
// then only feeds the delta computation, so it must equal the address of
// `codebuf` for the result to be meaningful. Otherwise the sequence is
// emulated with pcalau12i()/lu32i()/lu52i() starting from `pc`.
//
// verbose: dump the delta and its LUHT immediates (and, when emulating, each
// emulated step) to stderr.
// On a mismatch in a non-verbose call, the failure is reported (a BAD:LUHT
// tag on stdout when show_distrib, else an ERROR line on stderr) and the
// check is re-run verbosely to dump the details. With show_distrib, every
// non-verbose call prints exactly one OK:/BAD: tag.
// Returns true iff the sequence produced `dest`.
static bool check(uint64_t dest, uint64_t pc, bool verbose = false) {
  uint64_t pageDelta = getLoongArchPageDelta(dest, pc);
  union luht deltaLUHT = { .v = pageDelta | (dest & 0xfff) };
  if (verbose)
    fprintf(
      stderr,
      " pc = 0x%016lx\n delta = 0x%016lx\n l = %d\n u = %d\n h = %d\n t = %d\n l = 0x%03x\n u = 0x%05x\n h = 0x%05x\n t = 0x%03x\n",
      pc,
      pageDelta,
      deltaLUHT.l,
      deltaLUHT.u,
      deltaLUHT.h,
      deltaLUHT.t,
      deltaLUHT.l & 0xfff,
      deltaLUHT.u & 0xfffff,
      deltaLUHT.h & 0xfffff,
      deltaLUHT.t & 0xfff
    );
#ifdef JIT_VERIFY
  prepare_codebuf(reinterpret_cast<uint32_t *>(codebuf), deltaLUHT);
  uint64_t actual = codebuf();
#else
  uint64_t a = pcalau12i(pc, deltaLUHT.u, verbose);
  // addi.d T, $zero, lo12 sign-extends lo12 into all 64 bits.
  uint64_t t = lu32i((uint64_t)deltaLUHT.l, deltaLUHT.h, verbose);
  t = lu52i(t, deltaLUHT.t, verbose);
  uint64_t actual = a + t;
#endif
  if (actual != dest && !verbose) {
    if (show_distrib) {
      printf("BAD:LUHT%c%c%c%c\n", sign(deltaLUHT.l), sign(deltaLUHT.u), sign(deltaLUHT.h), sign(deltaLUHT.t));
    } else {
      fprintf(
        stderr,
        "ERROR: actual 0x%016lx\n != dest 0x%016lx\n",
        actual,
        dest
      );
    }
    // Re-run verbosely to dump the details of the failing case.
    check(dest, pc, true);
  } else if (show_distrib && !verbose) {
    // Only the top-level call tags the case. The verbose re-run of a failing
    // case also reaches this branch and must not report it as OK as well.
    printf("OK:LUHT%c%c%c%c\n", sign(deltaLUHT.l), sign(deltaLUHT.u), sign(deltaLUHT.h), sign(deltaLUHT.t));
  }
  return actual == dest;
}
static void printLLDTestCase(uint64_t dest, uint64_t pc) {
uint64_t pageDelta = getLoongArchPageDelta(dest, pc);
union luht deltaLUHT = { .v = pageDelta | (dest & 0xfff) };
int caseIdx = luhtSignsFromDiff(pageDelta, dest);
printf("## page delta = 0x%016lx, page offset = 0x%03lx\n", pageDelta, dest & 0xfff);
printf("## %%pc_lo12 = 0x%03x = %d\n", deltaLUHT.l & 0xfff, deltaLUHT.l);
printf("## %%pc_hi20 = 0x%05x = %d\n", deltaLUHT.u & 0xfffff, deltaLUHT.u);
printf("## %%pc64_lo20 = 0x%05x = %d\n", deltaLUHT.h & 0xfffff, deltaLUHT.h);
printf("## %%pc64_hi12 = 0x%03x = %d\n", deltaLUHT.t & 0xfff, deltaLUHT.t);
printf("# RUN: ld.lld %%t/extreme.o --section-start=.rodata=0x%016lx --section-start=.text=0x%016lx -o %%t/extreme%d\n", dest, pc, caseIdx);
printf("# RUN: llvm-objdump -d --no-show-raw-insn %%t/extreme%d | FileCheck %%s --check-prefix=EXTREME%d\n", caseIdx, caseIdx);
printf("# EXTREME%d: addi.d $t0, $zero, %d\n", caseIdx, deltaLUHT.l);
printf("# EXTREME%d-NEXT: pcalau12i $t1, %d\n", caseIdx, deltaLUHT.u);
printf("# EXTREME%d-NEXT: lu32i.d $t0, %d\n", caseIdx, deltaLUHT.h);
printf("# EXTREME%d-NEXT: lu52i.d $t0, $t0, %d\n", caseIdx, deltaLUHT.t);
//uint64_t a = pcalau12i(pc, deltaLUHT.u, true);
//uint64_t t = lu32i((uint64_t)deltaLUHT.l, deltaLUHT.h, true);
//t = lu52i(t, deltaLUHT.t, true);
printf("\n");
}
// Print the LLD test case for one combination of LUHT immediate signs.
// `luhtSigns` is a 4-bit mask laid out like luhtSignsFromDiff()'s result
// (bit 0: lo12, bit 1: hi20, bit 2: lo20, bit 3: hi12 negative).
// Each field gets a fixed, easy-to-spot pattern:
//   - non-negative: 0x111 / 0x22222 / 0x33333 / 0x444
//   - negative:     0x888 / 0x99999 / 0xaaaaa / 0xbbb
// dest is placed that far from a fixed PC, so the output is deterministic.
// getLoongArchPageDelta() only nudges these fields by one unit, which flips
// none of their signs, so the printed case index equals `luhtSigns`.
static void printLLDTestCaseFromLUHTSigns(int luhtSigns) {
  const uint64_t pc = 0x12345678;
  const uint64_t lo12 = (luhtSigns & 0x1) ? 0x888 : 0x111;
  const uint64_t hi20 = (luhtSigns & 0x2) ? 0x99999 : 0x22222;
  const uint64_t lo20 = (luhtSigns & 0x4) ? 0xaaaaa : 0x33333;
  const uint64_t hi12 = (luhtSigns & 0x8) ? 0xbbb : 0x444;
  // Same bit positions as union luht: t[63:52] h[51:32] u[31:12] (l[11:0]).
  const uint64_t pageDiff = (hi12 << 52) | (lo20 << 32) | (hi20 << 12);
  const uint64_t dest = (getLoongArchPage(pc) + pageDiff) | lo12;
  printLLDTestCase(dest, pc);
}
// Parse a hexadecimal 64-bit value from a command-line argument. As with
// strtoull base 16, an optional "0x"/"0X" prefix and leading whitespace are
// accepted. On malformed input, print a diagnostic prefixed with `argv0` and
// exit(2). Malformed means:
//   - no digits at all (e.g. an empty string);
//   - trailing non-hex characters;
//   - a value that does not fit in 64 bits.
static uint64_t parse_u64(const char *argv0, const char *x) {
  char *endptr = NULL;
  errno = 0;
  const unsigned long long value = strtoull(x, &endptr, 16);
  // strtoull signals "no digits" by leaving endptr == x, and overflow via
  // ERANGE (returning ULLONG_MAX); neither is a valid parse.
  if (endptr == x || *endptr != '\0' || errno == ERANGE) {
    fprintf(stderr, "%s: invalid hexadecimal u64: %s\n", argv0, x);
    exit(2);
  }
  return value;
}
// One verification case: the intended destination address and the PC of the
// pcalau12i that has to reach it.
struct test_input {
  uint64_t dest, pc;
};
static struct test_input generate_one_case_pc(uint64_t pc) {
struct test_input ret = {
.dest = randu64(),
.pc = pc,
};
return ret;
}
// Make a fully random test case: a random PC rounded down to a 4-byte
// boundary (instructions are 32-bit words) and a random destination.
static struct test_input generate_one_case() {
  const uint64_t alignedPc = randu64() >> 2 << 2;
  return generate_one_case_pc(alignedPc);
}
static uint64_t getLoongArchPageDelta(struct test_input testcase) {
return getLoongArchPageDelta(testcase.dest, testcase.pc);
}
// Entry point. The behavior is chosen at compile time by the mode macro near
// the top of the file; the first of CHECK_ONE_CLI, VERIFY_MANY,
// GENERATE_LLD_TEST_CASES and JIT_VERIFY that is defined wins.
//   CHECK_ONE_CLI: argv = <0xdest> <0xpc>. Exits 0 if the pair checks out,
//     1 if not, 2 on usage/parse errors.
//   VERIFY_MANY: checks 500000 random (dest, pc) pairs. Exits 1 on the first
//     failure (unless show_distrib), else 0.
//   GENERATE_LLD_TEST_CASES: prints lit test lines for all 16 LUHT sign cases.
//   JIT_VERIFY: maps an RWX page for the generated checker and checks 500000
//     random destinations from it. Exits 100 if the page cannot be set up,
//     2 on the first failure, else 0.
int main(int argc, const char *argv[])
{
#if defined(CHECK_ONE_CLI)
  if (argc != 3) {
    fprintf(stderr, "usage: %s <0xdest> <0xpc>\n", argv[0]);
    return 2;
  }
  uint64_t dest = parse_u64(argv[0], argv[1]);
  uint64_t pc = parse_u64(argv[0], argv[2]);
  // return check(0x900'10000'abcde'f00, 0x800'20345'ba987'654, true) ? 0 : 1;
  return check(dest, pc, true) ? 0 : 1;
#elif defined(VERIFY_MANY)
  for (int i = 0; i < 500000; i++) {
    struct test_input t = generate_one_case();
    // With show_distrib, keep going so every case gets tallied.
    if (!check(t.dest, t.pc) && !show_distrib)
      return 1;
  }
  return 0;
#elif defined(GENERATE_LLD_TEST_CASES)
  // no randomness for cleaner output
#if 0
  struct test_input inputs[16] = {};
  int casesSeen = 0;
  while (casesSeen < 16) {
    struct test_input x = generate_one_case();
    uint64_t pd = getLoongArchPageDelta(x);
    int caseIdx = luhtSignsFromDiff(pd, x.dest);
    if (!inputs[caseIdx].dest) {
      inputs[caseIdx] = x;
      casesSeen++;
    }
  }
  for (struct test_input x : inputs) {
    printLLDTestCase(x.dest, x.pc);
  }
#else
  for (int luhtSigns = 0; luhtSigns < 16; luhtSigns++) {
    printLLDTestCaseFromLUHTSigns(luhtSigns);
  }
#endif
  return 0;
#elif defined(JIT_VERIFY)
  long pagesize = sysconf(_SC_PAGE_SIZE);
  if (pagesize <= 0)
    return 100;
  // Test mmap's result as a void* against MAP_FAILED before turning it into a
  // function pointer: ISO C++ does not allow comparing a function pointer
  // with a void*.
  void *mem = mmap(NULL, static_cast<size_t>(pagesize),
                   PROT_READ | PROT_WRITE | PROT_EXEC,
                   MAP_ANON | MAP_PRIVATE, -1, 0);
  if (mem == MAP_FAILED)
    return 100;
  codebuf = reinterpret_cast<jit_check_fn_t>(mem);
  // pcalau12i is the checker's first instruction, so the PC under test is the
  // start of the buffer.
  const uint64_t pc = reinterpret_cast<uintptr_t>(mem);
  for (int i = 0; i < 500000; i++) {
    struct test_input t = generate_one_case_pc(pc);
    if (!check(t.dest, t.pc))
      return 2;
  }
  return 0;
#else
#error plz define at least one operating mode
#endif
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment