Last active
October 27, 2021 17:08
-
-
Save simonlindholm/6ad4ffb124a4c80e53333f66a023faf5 to your computer and use it in GitHub Desktop.
Uninitialized memory read instrumentation for MIPS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import sys | |
import struct | |
import argparse | |
from collections import namedtuple | |
# Names of the 32 MIPS general-purpose registers, listed in encoding order so
# that each name's position in the tuple is its register number.
_REG_NAMES = (
    "zero", "at", "v0", "v1", "a0", "a1", "a2", "a3",
    "t0", "t1", "t2", "t3", "t4", "t5", "t6", "t7",
    "s0", "s1", "s2", "s3", "s4", "s5", "s6", "s7",
    "t8", "t9", "k0", "k1", "gp", "sp", "fp", "ra",
)
# Attribute-style lookup table, e.g. REG.sp == 29 and REG.ra == 31.
REG = namedtuple('Struct', _REG_NAMES)(*range(len(_REG_NAMES)))
# Pre-encoded big-endian instruction words.
INSTR_NOP = b'\x00\x00\x00\x00'
# "jal 0": the real target is supplied by an R_MIPS_26 relocation.
INSTR_JAL_0 = b'\x0c\x00\x00\x00'


def INSTR_REG(idx, a, b, c, d=0):
    """Encode a register-format word.

    `idx` lands in the low bits, `d` at bit 6, `a` at bit 11, `c` at bit 16
    and `b` at bit 21. Returns 4 big-endian bytes.
    """
    word = (b << 21) | (c << 16) | (a << 11) | (d << 6) | idx
    return struct.pack('>I', word)


def INSTR_IMM(idx, a, b, imm):
    """Encode an immediate-format word.

    `idx` is the opcode (bit 26), `b` goes to bit 21, `a` to bit 16, and
    `imm` is truncated to its low 16 bits. Returns 4 big-endian bytes.
    """
    word = (idx << 26) | (b << 21) | (a << 16) | (imm & 0xffff)
    return struct.pack('>I', word)


def INSTR_ADDIU(a, b, imm):
    """addiu a, b, imm (opcode 9)."""
    return INSTR_IMM(9, a, b, imm)


def INSTR_LW(a, b, imm):
    """lw a, imm(b) (opcode 35)."""
    return INSTR_IMM(35, a, b, imm)


def INSTR_SW(a, b, imm):
    """sw a, imm(b) (opcode 43)."""
    return INSTR_IMM(43, a, b, imm)


def INSTR_LUI(r, imm):
    """lui r, imm (opcode 15)."""
    return INSTR_IMM(15, r, 0, imm)


def INSTR_OR(a, b, c):
    """or a, b, c (function code 37)."""
    return INSTR_REG(37, a, b, c)


def INSTR_ORI(a, b, imm):
    """ori a, b, imm (opcode 13)."""
    return INSTR_IMM(13, a, b, imm)


def INSTR_BNE(a, b, imm):
    """bne with register fields a/b and a 16-bit word offset (opcode 5)."""
    return INSTR_IMM(5, a, b, imm)


def INSTR_LI(reg, imm):
    """Load a 16-bit immediate into reg via addiu reg, $zero, imm."""
    return INSTR_ADDIU(reg, REG.zero, imm)


def INSTR_MOVE(t, s):
    """Copy register s into register t via or t, s, $zero."""
    return INSTR_OR(t, s, REG.zero)
# --- e_ident layout: total length and byte indices into it ---
EI_NIDENT = 16
EI_CLASS = 4
EI_DATA = 5
EI_VERSION = 6
EI_OSABI = 7
EI_ABIVERSION = 8

# --- reserved symbol / section indices ---
STN_UNDEF = 0
SHN_UNDEF = 0
SHN_LORESERVE = 0xff00
SHN_ABS = 0xfff1
SHN_COMMON = 0xfff2
SHN_XINDEX = 0xffff

# --- symbol type (low nibble of st_info) ---
STT_NOTYPE = 0
STT_OBJECT = 1
STT_FUNC = 2
STT_SECTION = 3
STT_FILE = 4
STT_COMMON = 5
STT_TLS = 6

# --- symbol binding (high nibble of st_info) ---
STB_LOCAL = 0
STB_GLOBAL = 1
STB_WEAK = 2

# --- symbol visibility (low two bits of st_other) ---
STV_DEFAULT = 0
STV_INTERNAL = 1
STV_HIDDEN = 2
STV_PROTECTED = 3

# --- section types (sh_type) ---
SHT_NULL = 0
SHT_PROGBITS = 1
SHT_SYMTAB = 2
SHT_STRTAB = 3
SHT_RELA = 4
SHT_HASH = 5
SHT_DYNAMIC = 6
SHT_NOTE = 7
SHT_NOBITS = 8
SHT_REL = 9
SHT_SHLIB = 10
SHT_DYNSYM = 11
SHT_INIT_ARRAY = 14
SHT_FINI_ARRAY = 15
SHT_PREINIT_ARRAY = 16
SHT_GROUP = 17
SHT_SYMTAB_SHNDX = 18
SHT_MIPS_DEBUG = 0x70000005
SHT_MIPS_REGINFO = 0x70000006
SHT_MIPS_OPTIONS = 0x7000000d

# --- section flag bits (sh_flags) ---
SHF_WRITE = 1 << 0
SHF_ALLOC = 1 << 1
SHF_EXECINSTR = 1 << 2
SHF_MERGE = 1 << 4
SHF_STRINGS = 1 << 5
SHF_INFO_LINK = 1 << 6
SHF_LINK_ORDER = 1 << 7
SHF_OS_NONCONFORMING = 1 << 8
SHF_GROUP = 1 << 9
SHF_TLS = 1 << 10

# --- MIPS relocation types (low byte of r_info) ---
R_MIPS_32 = 2
R_MIPS_26 = 4
R_MIPS_HI16 = 5
R_MIPS_LO16 = 6
class Limits:
    """Per-category budgets for how many instrumentation sites may be emitted."""

    def __init__(self, remaining):
        # Maps a category name to the number of sites still allowed,
        # or to None when that category is unlimited.
        self.remaining = remaining

    def consume(self, name, addr, addr2, fn, extra=''):
        """Try to spend one unit of the `name` budget.

        Returns True when the category is unlimited or budget was left, and
        False once it is exhausted. The first refused request prints a notice
        with the original address, the remapped address and the function
        name, followed by `extra` if given. The counter keeps decreasing on
        later refusals.
        """
        left = self.remaining[name]
        if left is None:
            return True
        left -= 1
        self.remaining[name] = left
        if left == -1:
            print("hit limit {} just before {} (remapped to {}, near {})".format(name, hex(addr), hex(addr2), fn))
            if extra:
                print(extra)
        return left >= 0
class ElfHeader:
    """
    Parsed ELF file header, restricted to what this tool supports:
    32-bit, big-endian, relocatable MIPS objects.

    typedef struct {
        unsigned char e_ident[EI_NIDENT];
        Elf32_Half e_type;
        Elf32_Half e_machine;
        Elf32_Word e_version;
        Elf32_Addr e_entry;
        Elf32_Off e_phoff;
        Elf32_Off e_shoff;
        Elf32_Word e_flags;
        Elf32_Half e_ehsize;
        Elf32_Half e_phentsize;
        Elf32_Half e_phnum;
        Elf32_Half e_shentsize;
        Elf32_Half e_shnum;
        Elf32_Half e_shstrndx;
    } Elf32_Ehdr;
    """

    # struct layout and attribute names of everything that follows e_ident
    _FORMAT = '>HHIIIIIHHHHHH'
    _FIELDS = (
        'e_type', 'e_machine', 'e_version', 'e_entry', 'e_phoff', 'e_shoff',
        'e_flags', 'e_ehsize', 'e_phentsize', 'e_phnum', 'e_shentsize',
        'e_shnum', 'e_shstrndx',
    )

    def __init__(self, data):
        """Decode the raw header bytes and assert the file is a supported object."""
        self.e_ident = data[:EI_NIDENT]
        unpacked = struct.unpack(self._FORMAT, data[EI_NIDENT:])
        for field, value in zip(self._FIELDS, unpacked):
            setattr(self, field, value)
        assert self.e_ident[EI_CLASS] == 1 # 32-bit
        assert self.e_ident[EI_DATA] == 2 # big-endian
        assert (self.e_flags >> 28) in [0, 1] # mips1 / mips2
        assert self.e_type == 1 # relocatable
        assert self.e_machine == 8 # MIPS I Architecture
        assert self.e_phoff == 0 # no program header
        assert self.e_shoff != 0 # section header
        assert self.e_shstrndx != SHN_UNDEF

    def to_bin(self):
        """Serialize the header back to bytes from the current field values."""
        values = [getattr(self, field) for field in self._FIELDS]
        return self.e_ident + struct.pack(self._FORMAT, *values)
class Symbol:
    """
    One symbol table entry.

    typedef struct {
        Elf32_Word st_name;
        Elf32_Addr st_value;
        Elf32_Word st_size;
        unsigned char st_info;
        unsigned char st_other;
        Elf32_Half st_shndx;
    } Elf32_Sym;
    """

    # big-endian layout of a 16-byte Elf32_Sym
    _LAYOUT = '>IIIBBH'

    def __init__(self, data, strtab):
        """Decode one raw entry; `strtab` resolves st_name via lookup_str."""
        fields = struct.unpack(self._LAYOUT, data)
        (self.st_name, self.st_value, self.st_size,
         self.st_info, self.st_other, self.st_shndx) = fields
        assert self.st_shndx != SHN_XINDEX
        info = self.st_info
        self.bind = info >> 4
        self.type = info & 15
        self.name = strtab.lookup_str(self.st_name)
        self.visibility = self.st_other & 3

    @classmethod
    def from_parts(cls, st_name, st_value, st_size, bind, type, visibility, st_shndx, strtab):
        """Build a symbol from individual fields by packing and re-parsing them."""
        packed = struct.pack(
            cls._LAYOUT, st_name, st_value, st_size,
            (bind << 4) | type, visibility, st_shndx,
        )
        return cls(packed, strtab)

    def has_target(self):
        """True when the symbol is defined in an ordinary (non-reserved) section."""
        if self.st_shndx == SHN_UNDEF:
            return False
        return self.st_shndx < SHN_LORESERVE

    def to_bin(self):
        """Serialize the entry from its raw st_* fields."""
        return struct.pack(
            self._LAYOUT, self.st_name, self.st_value, self.st_size,
            self.st_info, self.st_other, self.st_shndx,
        )
class Relocation:
    """One relocation entry from a SHT_REL (8-byte) or SHT_RELA (12-byte) section."""

    def __init__(self, data, sh_type):
        """Decode a raw entry; `sh_type` selects the REL or RELA layout."""
        self.sh_type = sh_type
        if sh_type == SHT_REL:
            self.r_offset, self.r_info = struct.unpack('>II', data)
        else:
            self.r_offset, self.r_info, self.r_addend = struct.unpack('>III', data)
        # r_info packs the symbol index (upper 24 bits) and the type (low byte).
        info = self.r_info
        self.sym_index = info >> 8
        self.rel_type = info & 0xff

    def to_bin(self):
        """Serialize the entry from r_offset, r_info and (for RELA) r_addend."""
        if self.sh_type != SHT_REL:
            return struct.pack('>III', self.r_offset, self.r_info, self.r_addend)
        return struct.pack('>II', self.r_offset, self.r_info)
class Instr:
    """A single decoded 32-bit big-endian MIPS instruction word."""

    def __init__(self, data):
        """Decode the 4 raw bytes in `data` into opcode and register fields."""
        self.data = data
        self.dec, = struct.unpack('>I', data)
        self.opcode = self.dec >> 26
        self.fncode = self.dec & 0x3f
        self.rs = (self.dec >> 21) & 0x1f
        self.rt = (self.dec >> 16) & 0x1f
        self.rd = (self.dec >> 11) & 0x1f

    def is_stack_push(self):
        """True for `addiu $sp, $sp, imm` with a negative immediate."""
        # Upper 16 bits 0x27bd encode addiu $sp, $sp; the extra low bit is the
        # immediate's sign bit.
        return (self.dec >> 15) == ((0x27bd << 1) | 1)

    def imm_absolute(self):
        """Return the low 16 bits as an unsigned value."""
        return self.dec & 0xffff

    def imm_sign_extended(self):
        """Return the low 16 bits as a signed two's-complement value."""
        x = self.dec & 0xffff
        return x - 0x10000 if x >= 0x8000 else x

    def is_branch(self):
        """True for PC-relative branches, including coprocessor-1 condition branches."""
        is_float_branch = (self.opcode == 17 and ((self.dec >> 21) & 0x1f) == 0x8)
        return self.opcode in [1,4,5,6,7,20,21,22,23] or is_float_branch

    def is_jump(self):
        """True for j/jal (opcodes 2, 3) and jr/jalr (opcode 0, function 8, 9)."""
        return self.opcode in [2,3] or (self.opcode == 0 and self.fncode in [8,9])

    def jump_target(self):
        """Return the 26-bit jump field scaled to a byte address."""
        # actually bit-ored by (pc + 4) & 0xf0000000, but we don't know pc
        return (self.dec & 0x3ffffff) << 2

    def reads_reg(self, reg):
        """Return whether this branch/jump reads general-purpose register `reg`.

        Only branches and jumps are supported; any other instruction trips an
        assertion. Callers use this to decide whether a load in the delay slot
        may be hoisted in front of the branch/jump.
        """
        # For simplicity we can currently assume that we have a branch/jump
        if self.opcode == 0:
            if self.fncode in [8,9]: # jr, jalr
                return self.rs == reg
        elif self.opcode in [2,3]: # j, jal
            return False
        else:
            assert self.is_branch()
            if self.opcode == 17: # bc1f/bc1t test the FP condition flag only
                return False
            if self.opcode in [4,5,20,21]: # beq, bne, beql, bnel compare rs and rt
                return reg in (self.rs, self.rt)
            # regimm (opcode 1) and blez/bgtz(/l) test rs only; the rt field is
            # a sub-opcode or zero there, not a register operand.
            return self.rs == reg
        assert False
class Section:
    """
    One ELF section: its parsed header, its raw contents, and the rewriting
    passes used to instrument code and keep symbols/relocations consistent.

    typedef struct {
        Elf32_Word sh_name;
        Elf32_Word sh_type;
        Elf32_Word sh_flags;
        Elf32_Addr sh_addr;
        Elf32_Off sh_offset;
        Elf32_Word sh_size;
        Elf32_Word sh_link;
        Elf32_Word sh_info;
        Elf32_Word sh_addralign;
        Elf32_Word sh_entsize;
    } Elf32_Shdr;
    """

    def __init__(self, header, data, index):
        """Parse a raw section header and slice the contents out of the file image.

        `header` is the raw Elf32_Shdr, `data` the whole input file and
        `index` this section's position in the section header table.
        """
        self.index = index
        self.header = header
        (self.sh_name, self.sh_type, self.sh_flags, self.sh_addr,
         self.sh_offset, self.sh_size, self.sh_link, self.sh_info,
         self.sh_addralign, self.sh_entsize) = struct.unpack('>IIIIIIIIII', header)
        assert not self.sh_flags & SHF_LINK_ORDER
        if self.sh_entsize != 0:
            assert self.sh_size % self.sh_entsize == 0
        if self.sh_type == SHT_NOBITS:
            # NOBITS sections occupy no file space; keep the type consistent
            # with every other section (bytes, not str).
            self.data = b''
        else:
            self.data = data[self.sh_offset:self.sh_offset + self.sh_size]
        # Symbols defined in this section (filled in by the symtab's init_symbols).
        self.symbols = []
        # Relocation sections whose sh_info points at this section.
        self.relocated_by = []

    def lookup_str(self, index):
        """Return the NUL-terminated string starting at byte `index` of a string table."""
        assert self.sh_type == SHT_STRTAB
        to = self.data.find(b'\0', index)
        assert to != -1
        return self.data[index:to].decode('utf-8')

    def is_rel(self):
        """True for SHT_REL and SHT_RELA sections."""
        return self.sh_type == SHT_REL or self.sh_type == SHT_RELA

    def header_to_bin(self):
        """Serialize the header, first refreshing sh_size from the current data."""
        if self.sh_type != SHT_NOBITS:
            self.sh_size = len(self.data)
        return struct.pack('>IIIIIIIIII', self.sh_name, self.sh_type, self.sh_flags, self.sh_addr, self.sh_offset, self.sh_size, self.sh_link, self.sh_info, self.sh_addralign, self.sh_entsize)

    def late_init(self, sections):
        """Resolve cross-section links once every section has been constructed."""
        if self.sh_type == SHT_SYMTAB:
            self.init_symbols(sections)
        elif self.is_rel():
            self.rel_target = sections[self.sh_info]
            self.rel_target.relocated_by.append(self)

    def init_symbols(self, sections):
        """Parse all symbol entries and attach each defined symbol to its section."""
        assert self.sh_type == SHT_SYMTAB
        assert self.sh_entsize == 16
        self.strtab = sections[self.sh_link]
        entries = []
        for i in range(0, self.sh_size, self.sh_entsize):
            s = Symbol(self.data[i:i+self.sh_entsize], self.strtab)
            entries.append(s)
            if s.has_target():
                sections[s.st_shndx].symbols.append(s)
        self.symbol_entries = entries

    def add_str(self, string):
        """Append a bytes string to a string table and return its offset."""
        assert self.sh_type == SHT_STRTAB
        assert b'\0' not in string
        ret = len(self.data)
        self.data += string + b'\0'
        return ret

    def add_symbol(self, name):
        """Append an undefined global function symbol `name` (bytes); return its index."""
        assert self.sh_type == SHT_SYMTAB
        name_ind = self.strtab.add_str(name)
        s = Symbol.from_parts(st_name=name_ind, st_value=0, st_size=0, bind=STB_GLOBAL,
                              type=STT_FUNC, visibility=STV_DEFAULT, st_shndx=SHN_UNDEF,
                              strtab=self.strtab)
        ret = len(self.symbol_entries)
        self.symbol_entries.append(s)
        self.data += s.to_bin()
        return ret

    def _rel_entries(self):
        """Parse every relocation entry currently held in self.data."""
        # Walk the live data rather than sh_size: sh_size is only refreshed in
        # header_to_bin, so it is stale once add_jal_relocations has appended
        # entries, and a later rewrite pass would silently drop them.
        step = self.sh_entsize
        return [Relocation(self.data[i:i+step], self.sh_type)
                for i in range(0, len(self.data), step)]

    def instrument(self, symtab, reloc_sections, poison_stack_sym, bad_load_sym, limits):
        """Rewrite this executable section with poison/check instrumentation.

        After every stack push (`addiu $sp, $sp, -N`) a call to the symbol
        `poison_stack_sym` is inserted, and after every `lw` (except loads of
        $ra) a comparison against 0xbadbadbd that calls `bad_load_sym` on a
        match. Branch offsets, this section's relocations, relocations that
        point into this section and symbol values are all renumbered to match
        the new layout.

        symtab: the SHT_SYMTAB section.
        reloc_sections: every relocation section in the file.
        poison_stack_sym / bad_load_sym: symbol indices for the runtime hooks
            (may be None when `limits` forbids all instrumentation).
        limits: a Limits object consulted with the 'poison' and 'instr' keys.
        """
        assert self.sh_type == SHT_PROGBITS
        assert len(self.data) % 4 == 0
        poisoned_functions = 0
        instrumented_loads = 0
        addr_to_fn_symbol = {}
        # Function names whose stack must not be poisoned (empty by default).
        poison_blacklist = []
        jump_targets = set()
        for s in self.symbols:
            if s.type == STT_FUNC:
                addr_to_fn_symbol[s.st_value] = s
                jump_targets.add(s.st_value)
        orig_instr = [Instr(self.data[addr:addr+4]) for addr in range(0, len(self.data), 4)]
        num_instr = len(orig_instr)
        # symbol_renumbering[i]: new index where old instruction slot i begins
        # (used for symbols and branch targets); has one extra end-of-section slot.
        symbol_renumbering = [None] * (num_instr + 1)
        # reloc_renumbering[i]: new index of old instruction i itself
        # (used for relocation offsets and branch sources).
        reloc_renumbering = [None] * num_instr
        for i in range(num_instr):
            if orig_instr[i].is_branch():
                jump_targets.add(4 * (i + orig_instr[i].imm_sign_extended() + 1))
        ndata = []
        branch_fixups = []
        new_relocs = []
        current_function = ''
        last_was_jump = False
        for i in range(num_instr):
            symbol_renumbering[i] = len(ndata)
            addr = i*4
            addr2 = len(ndata)*4
            instr = orig_instr[i]
            fn_sym = addr_to_fn_symbol.get(addr, None)
            if fn_sym:
                current_function = fn_sym.name
                # print("function", current_function)
            reloc_renumbering[i] = len(ndata)
            ndata.append(instr.data)
            # Add a stack-poisoning after any instance of addiu $sp, $sp, -size
            # with negative size (doing it before the addiu violates the ABI,
            # for whatever that's worth).
            if current_function not in poison_blacklist and instr.is_stack_push() and limits.consume('poison', addr, addr2, current_function):
                poisoned_functions += 1
                # print("poisoning stack in function", current_function)
                assert not last_was_jump
                stack_size = -instr.imm_sign_extended()
                ndata.append(INSTR_MOVE(REG.t6, REG.ra))
                new_relocs.append((len(ndata), poison_stack_sym))
                ndata.append(INSTR_JAL_0)
                # Delay slot of the jal: pass the frame size in $t7.
                ndata.append(INSTR_LI(REG.t7, stack_size))
            if instr.is_branch():
                target = instr.imm_sign_extended() + 1 + i
                branch_fixups.append((i, target))
            if instr.opcode == 35: # lw
                reg = (instr.dec & 0x1f0000) >> 16
                # Loading the return register is fine, we don't need to instrument that jump
                # (if it's poison we'll crash anyway...)
                if reg != REG.ra and limits.consume('instr', addr, addr2, current_function):
                    last_instr = None
                    if last_was_jump:
                        # Branch delay slot, need to reorder stuff
                        assert addr not in jump_targets
                        assert not orig_instr[i-1].reads_reg(reg)
                        cur_instr = ndata.pop()
                        last_instr = ndata.pop()
                        reloc_renumbering[i] = len(ndata)
                        ndata.append(cur_instr)
                    # Use $at to construct the poison value, and to store $ra
                    # during the function call. And just in case $at is live,
                    # preserve it on the stack (ugh).
                    # If $at is the register compared against, temporarily use
                    # $s0 to store its value.
                    instrumented_loads += 1
                    ndata.append(INSTR_ADDIU(REG.sp, REG.sp, -8))
                    cmp_reg = REG.s0 if reg == REG.at else reg
                    save_reg = REG.s0 if reg == REG.at else REG.at
                    ndata.append(INSTR_SW(save_reg, REG.sp, 0))
                    if reg == REG.at:
                        ndata.append(INSTR_MOVE(cmp_reg, REG.at))
                    ndata.append(INSTR_LUI(REG.at, 0xbadb))
                    ndata.append(INSTR_ORI(REG.at, REG.at, 0xadbd))
                    # Skip the call (jal + its delay slot) when the value is not poison.
                    ndata.append(INSTR_BNE(REG.at, cmp_reg, 3))
                    ndata.append(INSTR_MOVE(REG.at, REG.ra))
                    new_relocs.append((len(ndata), bad_load_sym))
                    ndata.append(INSTR_JAL_0)
                    ndata.append(INSTR_NOP)
                    if reg == REG.at:
                        ndata.append(INSTR_MOVE(REG.at, cmp_reg))
                    ndata.append(INSTR_LW(save_reg, REG.sp, 0))
                    ndata.append(INSTR_ADDIU(REG.sp, REG.sp, 8))
                    if last_was_jump:
                        # Re-emit the hoisted-over branch/jump with an empty delay slot.
                        reloc_renumbering[i - 1] = len(ndata)
                        ndata.append(last_instr)
                        ndata.append(INSTR_NOP)
                        # Might as well preserve function alignment mod 8
                        ndata.append(INSTR_NOP)
            if instr.is_jump() or instr.is_branch():
                # Branches in branch delay slots are too annoying to deal with.
                assert not last_was_jump
                last_was_jump = True
            else:
                last_was_jump = False
            # TODO: loading floats/doubles (lwc1)
        symbol_renumbering[num_instr] = len(ndata)
        # Re-encode every branch offset against the new instruction positions.
        for (addr, target) in branch_fixups:
            addr = reloc_renumbering[addr]
            target = symbol_renumbering[target]
            rel = target - addr - 1
            assert -2**15 <= rel < 2**15
            ins_dec, = struct.unpack('>I', ndata[addr])
            ins_dec = (ins_dec & 0xffff0000) | (rel & 0xffff)
            ndata[addr] = struct.pack('>I', ins_dec)
        for s in self.relocated_by:
            s.fixup_rel_offsets(reloc_renumbering)
        self.data = b''.join(ndata)
        for s in reloc_sections:
            s.fixup_rel_targets(symbol_renumbering, symtab, self.index)
        symtab.fixup_symtab(self, symbol_renumbering)
        if new_relocs:
            # A relocation section is only required when there is something to add.
            assert len(self.relocated_by) > 0
            self.relocated_by[0].add_jal_relocations(new_relocs)
        print("poisoned {} function stacks".format(poisoned_functions))
        print("instrumented {} loads".format(instrumented_loads))

    def fixup_symtab(self, sect, symbol_renumbering):
        """Move every symbol defined in `sect` to its renumbered address and re-serialize."""
        assert self.sh_type == SHT_SYMTAB
        for s in self.symbol_entries:
            if s.st_shndx != sect.index:
                continue
            assert s.st_value % 4 == 0
            s.st_value = symbol_renumbering[s.st_value // 4] * 4
        self.data = b''.join(s.to_bin() for s in self.symbol_entries)

    def assert_rels_sane(self):
        """Assert that no two consecutive relocation entries share an offset."""
        assert self.is_rel()
        last_offset = None
        for entry in self._rel_entries():
            assert entry.r_offset != last_offset
            last_offset = entry.r_offset

    def fixup_rel_offsets(self, reloc_renumbering):
        """Rewrite each entry's r_offset to the new position of the instruction it patches."""
        assert self.is_rel()
        ndata = []
        for entry in self._rel_entries():
            assert entry.r_offset % 4 == 0
            entry.r_offset = reloc_renumbering[entry.r_offset // 4] * 4
            ndata.append(entry.to_bin())
        self.data = b''.join(ndata)

    def fixup_rel_targets(self, symbol_renumbering, symtab, target_index):
        """Adjust addends of relocations whose symbol lives in section `target_index`.

        The addend (explicit for RELA, embedded in the patched word for REL)
        is an offset from the symbol into the instrumented section, so it is
        recomputed through `symbol_renumbering`. Entries referring to other
        sections, or with a zero addend, are passed through unchanged.
        """
        assert self.is_rel()
        ndata = []
        ncode = [self.rel_target.data[i:i+4] for i in range(0, len(self.rel_target.data), 4)]
        for entry in self._rel_entries():
            sym = symtab.symbol_entries[entry.sym_index]
            if sym.st_shndx != target_index: # (including if entry.sym_index == STN_UNDEF)
                ndata.append(entry.to_bin())
                continue
            assert entry.rel_type in [R_MIPS_LO16, R_MIPS_HI16, R_MIPS_26, R_MIPS_32]
            assert entry.r_offset % 4 == 0
            assert sym.st_value % 4 == 0
            if self.sh_type == SHT_RELA:
                add = entry.r_addend * 4 if entry.rel_type == R_MIPS_26 else entry.r_addend
            else:
                ins = Instr(ncode[entry.r_offset // 4])
                if entry.rel_type == R_MIPS_32:
                    add = ins.dec
                elif entry.rel_type == R_MIPS_26:
                    add = ins.jump_target()
                else:
                    add = ins.imm_absolute()
            if add == 0:
                ndata.append(entry.to_bin())
                continue
            # HI16 relocations need to be fixed up in combination with LO16 ones, which is
            # complex... For simplicity, just assert the hi part is always 0.
            assert entry.rel_type != R_MIPS_HI16
            assert add % 4 == 0
            # Relative relocation that points into the instrumented section.
            # Handle this using the symbol renumbering table.
            add = (symbol_renumbering[(sym.st_value + add) // 4] - symbol_renumbering[sym.st_value // 4]) * 4
            # Wishful thinking. (Otherwise we might have to bother with sign extension and stuff.)
            assert add >= 0
            if entry.rel_type == R_MIPS_LO16:
                assert add < 0x8000
            if self.sh_type == SHT_RELA:
                entry.r_addend = add // 4 if entry.rel_type == R_MIPS_26 else add
            else:
                dec = ins.dec
                if entry.rel_type == R_MIPS_32:
                    dec = add & 0xffffffff
                elif entry.rel_type == R_MIPS_26:
                    dec = (dec & 0xfc000000) | ((add >> 2) & 0x3ffffff)
                else:
                    dec = (dec & 0xffff0000) | (add & 0xffff)
                ncode[entry.r_offset // 4] = struct.pack('>I', dec)
            ndata.append(entry.to_bin())
        self.data = b''.join(ndata)
        self.rel_target.data = b''.join(ncode)

    def add_jal_relocations(self, relocs):
        """Append one R_MIPS_26 entry per (instruction index, symbol index) pair."""
        assert self.is_rel()
        assert len(set(r[0] for r in relocs)) == len(relocs)
        ndata = []
        for (pos, sym) in relocs:
            assert sym < (1 << 24)
            r_info = sym << 8 | R_MIPS_26
            r_offset = pos * 4
            if self.sh_type == SHT_REL:
                nentry_data = struct.pack('>II', r_offset, r_info)
            else:
                nentry_data = struct.pack('>III', r_offset, r_info, 0)
            ndata.append(nentry_data)
        self.data += b''.join(ndata)
def main():
    """Command-line entry point: instrument a MIPS .o file.

    Reads the input object, instruments every executable PROGBITS section,
    optionally drops the MIPS debug section, and writes the rebuilt object to
    the output path (the input path when no output is given).

    The output image is assembled in memory and only written once everything
    has succeeded, so a failed run never truncates the input file when
    rewriting in place.
    """
    parser = argparse.ArgumentParser(description="Add instrumention code for uninitialized reads to MIPS .o files.")
    parser.add_argument('input', help="input file")
    parser.add_argument('output', help="output file (defaults to input file)", nargs='?')
    parser.add_argument('--poison-cap', dest='poison_n', type=int, help="poison the first N function stacks")
    parser.add_argument('--load-cap', dest='instr_n', type=int, help="instrument the first N loads")
    parser.add_argument('--no-drop-debug', dest='drop_debug', help="don't drop debug section", action='store_false')
    parser.add_argument('--no-add-symbols', dest='no_add_symbols', help="don't add symbols (requires caps = 0)", action='store_true')
    args = parser.parse_args()

    if args.no_add_symbols:
        # Without the hook symbols no call may be emitted, so both caps must be
        # given explicitly and be non-positive (an omitted cap means "unlimited").
        caps = (args.instr_n, args.poison_n)
        if any(cap is None or cap > 0 for cap in caps):
            parser.error("--no-add-symbols requires --poison-cap 0 and --load-cap 0")

    limits = Limits({
        'poison': args.poison_n,
        'instr': args.instr_n
    })

    with open(args.input, 'rb') as f:
        data = f.read()
    assert data[:4] == b'\x7fELF'

    # In-memory output image; its length is the current output offset.
    out = bytearray()

    def write_out(chunk):
        out.extend(chunk)

    def pad_out(align):
        if align and len(out) % align:
            write_out(b'\0' * (align - len(out) % align))

    elf_header = ElfHeader(data[0:52])
    offset, size = elf_header.e_shoff, elf_header.e_shentsize
    null_section = Section(data[offset:offset + size], data, 0)
    # e_shnum == 0 means the real count is stored in the null section's sh_size.
    num_sections = elf_header.e_shnum or null_section.sh_size
    sections = [null_section]
    for i in range(1, num_sections):
        ind = offset + i * size
        sections.append(Section(data[ind:ind + size], data, i))

    shstr = sections[elf_header.e_shstrndx]
    symtab = None
    for s in sections:
        if s.sh_type == SHT_SYMTAB:
            assert not symtab
            symtab = s
    assert symtab is not None

    for s in sections:
        s.name = shstr.lookup_str(s.sh_name)
        s.late_init(sections)

    if args.no_add_symbols:
        poison_stack_sym = None
        bad_load_sym = None
    else:
        poison_stack_sym = symtab.add_symbol(b'_poison_stack')
        bad_load_sym = symtab.add_symbol(b'_bad_load')

    reloc_sections = [s for s in sections if s.is_rel()]
    for s in reloc_sections:
        s.assert_rels_sane()

    for s in sections:
        if s.sh_type == SHT_PROGBITS and s.sh_flags & SHF_EXECINSTR:
            s.instrument(symtab, reloc_sections, poison_stack_sym, bad_load_sym, limits)

    if args.drop_debug:
        print("dropped debug section")
        # NOTE(review): section indices stored elsewhere (sh_link, sh_info,
        # e_shstrndx, st_shndx) are not renumbered here, so this looks safe
        # only when the dropped section follows every referenced one - confirm.
        sections = [s for s in sections if s.sh_type != SHT_MIPS_DEBUG]
        elf_header.e_shnum = len(sections)

    # Reserve space for the header; it is patched below once e_shoff is known.
    write_out(elf_header.to_bin())
    for s in sections:
        if s.sh_type != SHT_NOBITS and s.sh_type != SHT_NULL:
            pad_out(s.sh_addralign)
            s.sh_offset = len(out)
            write_out(s.data)

    pad_out(4)
    elf_header.e_shoff = len(out)
    for s in sections:
        write_out(s.header_to_bin())

    header_bin = elf_header.to_bin()
    out[:len(header_bin)] = header_bin

    with open(args.output or args.input, 'wb') as outfile:
        outfile.write(out)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
.set noat # allow manual use of $at | |
.set noreorder # don't insert nops after branches | |
.set gp=64 | |
.include "macros.inc" | |
# spill every register except r0 and sp to the stack
# TODO: save/restore floating point registers (currently not necessary, but still)
#
# Expects the caller to have already lowered $sp; slots 16($sp)..140($sp) are
# written, leaving 0($sp)..12($sp) untouched. Clobbers $at after saving it
# (it is used as scratch to read lo/hi).
.macro save_all
    sw $at, 16($sp)
    sw $v0, 20($sp)
    sw $v1, 24($sp)
    sw $a0, 28($sp)
    sw $a1, 32($sp)
    sw $a2, 36($sp)
    sw $a3, 40($sp)
    sw $t0, 44($sp)
    sw $t1, 48($sp)
    sw $t2, 52($sp)
    sw $t3, 56($sp)
    sw $t4, 60($sp)
    sw $t5, 64($sp)
    sw $t6, 68($sp)
    sw $t7, 72($sp)
    sw $s0, 76($sp)
    sw $s1, 80($sp)
    sw $s2, 84($sp)
    sw $s3, 88($sp)
    sw $s4, 92($sp)
    sw $s5, 96($sp)
    sw $s6, 100($sp)
    sw $s7, 104($sp)
    sw $t8, 108($sp)
    sw $t9, 112($sp)
    sw $k0, 116($sp)
    sw $k1, 120($sp)
    sw $gp, 124($sp)
    sw $fp, 128($sp)
    sw $ra, 132($sp)
    # lo and hi cannot be stored directly; go through $at (already saved above)
    mflo $at
    sw $at, 136($sp)
    mfhi $at
    sw $at, 140($sp)
.endm
# Reload everything written by save_all from the same stack slots.
# lo/hi are restored first via $at, and $at itself is reloaded last.
# Each lw into $at is separated from the mtlo/mthi that consumes it by an
# unrelated load - presumably to respect the load delay slot (file is
# assembled with .set noreorder); confirm.
.macro restore_all
    lw $ra, 132($sp)
    lw $at, 136($sp)
    lw $fp, 128($sp)
    mtlo $at
    lw $at, 140($sp)
    lw $gp, 124($sp)
    mthi $at
    lw $k1, 120($sp)
    lw $k0, 116($sp)
    lw $t9, 112($sp)
    lw $t8, 108($sp)
    lw $s7, 104($sp)
    lw $s6, 100($sp)
    lw $s5, 96($sp)
    lw $s4, 92($sp)
    lw $s3, 88($sp)
    lw $s2, 84($sp)
    lw $s1, 80($sp)
    lw $s0, 76($sp)
    lw $t7, 72($sp)
    lw $t6, 68($sp)
    lw $t5, 64($sp)
    lw $t4, 60($sp)
    lw $t3, 56($sp)
    lw $t2, 52($sp)
    lw $t1, 48($sp)
    lw $t0, 44($sp)
    lw $a3, 40($sp)
    lw $a2, 36($sp)
    lw $a1, 32($sp)
    lw $a0, 28($sp)
    lw $v1, 24($sp)
    lw $v0, 20($sp)
    # $at last: it was the scratch register for lo/hi above
    lw $at, 16($sp)
.endm
.section .text, "ax" | |
# _poison_stack: fill the t7 bytes starting at $sp with the word 0xbadbadbd.
# Called by the instrumentation right after a stack push.
# Clobbers $at and $t7; restores the caller's $ra from $t6 on return.
# (glabel comes from macros.inc - presumably a global-label macro.)
.type _poison_stack function
glabel _poison_stack
    # t6 = ra
    # t7 = size
    lui $at, 0xbadb
    ori $at, $at, 0xadbd
    # t7 = one past the highest address to poison
    addu $t7, $t7, $sp
again:
    # walk downwards one word at a time until $sp itself has been written
    addiu $t7, $t7, -4
    sw $at, 0($t7)
    bne $t7, $sp, again
    nop
    jr $ra
    # delay slot: hand the original return address back to the caller
    move $ra, $t6
# _bad_load: called by the instrumentation when a loaded word equals the
# poison value. Saves all registers, passes its own return address to
# _print_bad_load, restores everything and returns.
# On entry $at holds the instrumented code's original $ra; it is put back
# into $ra in the delay slot of the final jr.
.type _bad_load function
glabel _bad_load
    # reserve space for 30 registers, hi, lo, and home space for 4 registers in the called function
    # (30 + 2 + 4) * 4 = 144
    addiu $sp, $sp, -144
    save_all
    # argument: return address into the instrumented code
    move $a0, $ra
    jal _print_bad_load
    nop
    restore_all
    addiu $sp, $sp, 144
    jr $ra
    # delay slot: restore the caller's original $ra (restore_all reloaded $at)
    move $ra, $at
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifdef STANDALONE | |
void print(const char* str); | |
void phex(unsigned int x); | |
/*
 * Standalone reporting hook: print the address derived from the return
 * address `ra` that _bad_load passes in.
 *
 * NOTE(review): the Python instrumenter emits the jal seven instructions
 * after the checked lw (eight when the loaded register is $at), so ra
 * (jal + 8) is 9 or 10 words past the lw. `ra - 7 * 4` therefore appears to
 * point two or three instructions after the load - confirm whether that is
 * intended.
 */
void _print_bad_load(unsigned int ra) {
    print("read poison at address ");
    phex(ra - 7 * 4);
}
#else | |
extern unsigned gUninitMemoryReadAddr; | |
void _print_bad_load(unsigned int ra) { | |
unsigned int addr = ra - 7 * 4; | |
if (gUninitMemoryReadAddr == 0) { | |
gUninitMemoryReadAddr = addr; | |
} | |
} | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment