Skip to content

Instantly share code, notes, and snippets.

@simonlindholm
Last active October 27, 2021 17:08
Show Gist options
  • Save simonlindholm/6ad4ffb124a4c80e53333f66a023faf5 to your computer and use it in GitHub Desktop.
Save simonlindholm/6ad4ffb124a4c80e53333f66a023faf5 to your computer and use it in GitHub Desktop.
Uninitialized memory read instrumentation for MIPS
#!/usr/bin/env python3
import sys
import struct
import argparse
from collections import namedtuple
# MIPS general-purpose register numbers, accessible as REG.zero .. REG.ra.
# Fields are listed in hardware order, so each name maps to its own index.
REG = namedtuple('Struct', (
    "zero", "at", "v0", "v1", "a0", "a1", "a2", "a3",
    "t0", "t1", "t2", "t3", "t4", "t5", "t6", "t7",
    "s0", "s1", "s2", "s3", "s4", "s5", "s6", "s7",
    "t8", "t9", "k0", "k1", "gp", "sp", "fp", "ra",
))(*range(32))

# Fixed instruction words.
INSTR_NOP = b'\x00\x00\x00\x00'
# jal with a zero target field; the real target comes from a relocation.
INSTR_JAL_0 = b'\x0c\x00\x00\x00'


def INSTR_REG(idx, a, b, c, d=0):
    """Encode an R-type instruction as 4 big-endian bytes.

    idx is the function code, a is rd, b is rs, c is rt and d is shamt.
    """
    word = (b << 21) | (c << 16) | (a << 11) | (d << 6) | idx
    return struct.pack('>I', word)


def INSTR_IMM(idx, a, b, imm):
    """Encode an I-type instruction: opcode idx, rt a, rs b, low 16 bits of imm."""
    word = (idx << 26) | (b << 21) | (a << 16) | (imm & 0xffff)
    return struct.pack('>I', word)


def INSTR_ADDIU(a, b, imm):
    """addiu a, b, imm"""
    return INSTR_IMM(9, a, b, imm)


def INSTR_LW(a, b, imm):
    """lw a, imm(b)"""
    return INSTR_IMM(35, a, b, imm)


def INSTR_SW(a, b, imm):
    """sw a, imm(b)"""
    return INSTR_IMM(43, a, b, imm)


def INSTR_LUI(r, imm):
    """lui r, imm"""
    return INSTR_IMM(15, r, 0, imm)


def INSTR_OR(a, b, c):
    """or a, b, c"""
    return INSTR_REG(37, a, b, c)


def INSTR_ORI(a, b, imm):
    """ori a, b, imm"""
    return INSTR_IMM(13, a, b, imm)


def INSTR_BNE(a, b, imm):
    """bne b, a, imm (operand order is irrelevant for an inequality test)."""
    return INSTR_IMM(5, a, b, imm)


def INSTR_LI(reg, imm):
    """li reg, imm, encoded as addiu reg, $zero, imm (the CPU sign-extends imm)."""
    return INSTR_ADDIU(reg, REG.zero, imm)


def INSTR_MOVE(t, s):
    """move t, s, encoded as or t, s, $zero."""
    return INSTR_OR(t, s, REG.zero)
# --- e_ident layout -------------------------------------------------------
EI_NIDENT = 16
EI_CLASS, EI_DATA, EI_VERSION, EI_OSABI, EI_ABIVERSION = range(4, 9)

# --- Special symbol / section indices --------------------------------------
STN_UNDEF = 0
SHN_UNDEF = 0
SHN_ABS = 0xFFF1
SHN_COMMON = 0xFFF2
SHN_XINDEX = 0xFFFF
SHN_LORESERVE = 0xFF00

# --- Symbol types (low nibble of st_info) ----------------------------------
(STT_NOTYPE, STT_OBJECT, STT_FUNC, STT_SECTION,
 STT_FILE, STT_COMMON, STT_TLS) = range(7)

# --- Symbol bindings (high nibble of st_info) ------------------------------
STB_LOCAL, STB_GLOBAL, STB_WEAK = range(3)

# --- Symbol visibility (low bits of st_other) ------------------------------
STV_DEFAULT, STV_INTERNAL, STV_HIDDEN, STV_PROTECTED = range(4)

# --- Section types ---------------------------------------------------------
(SHT_NULL, SHT_PROGBITS, SHT_SYMTAB, SHT_STRTAB, SHT_RELA, SHT_HASH,
 SHT_DYNAMIC, SHT_NOTE, SHT_NOBITS, SHT_REL, SHT_SHLIB, SHT_DYNSYM) = range(12)
(SHT_INIT_ARRAY, SHT_FINI_ARRAY, SHT_PREINIT_ARRAY,
 SHT_GROUP, SHT_SYMTAB_SHNDX) = range(14, 19)
# MIPS processor-specific section types.
SHT_MIPS_DEBUG = 0x70000005
SHT_MIPS_REGINFO = 0x70000006
SHT_MIPS_OPTIONS = 0x7000000D

# --- Section flags ---------------------------------------------------------
SHF_WRITE = 1 << 0
SHF_ALLOC = 1 << 1
SHF_EXECINSTR = 1 << 2
SHF_MERGE = 1 << 4
SHF_STRINGS = 1 << 5
SHF_INFO_LINK = 1 << 6
SHF_LINK_ORDER = 1 << 7
SHF_OS_NONCONFORMING = 1 << 8
SHF_GROUP = 1 << 9
SHF_TLS = 1 << 10

# --- MIPS relocation types -------------------------------------------------
R_MIPS_32 = 2
R_MIPS_26 = 4
R_MIPS_HI16, R_MIPS_LO16 = 5, 6
class Limits:
    """Per-category budgets bounding how many instrumentation sites are emitted.

    ``remaining`` maps a category name to the number of sites still allowed,
    or to None meaning "no limit".
    """

    def __init__(self, remaining):
        self.remaining = remaining

    def consume(self, name, addr, addr2, fn, extra=''):
        """Spend one unit of budget ``name`` and report whether it was available.

        The first time a budget is overdrawn, a diagnostic naming the original
        address ``addr``, the remapped address ``addr2`` and the enclosing
        function ``fn`` is printed, followed by ``extra`` when non-empty.
        """
        budget = self.remaining[name]
        if budget is None:
            return True
        budget -= 1
        self.remaining[name] = budget
        if budget == -1:
            print("hit limit {} just before {} (remapped to {}, near {})".format(
                name, hex(addr), hex(addr2), fn))
            if extra:
                print(extra)
        return budget >= 0
class ElfHeader:
    """
    ELF file header of a 32-bit big-endian MIPS relocatable object.

    typedef struct {
        unsigned char e_ident[EI_NIDENT];
        Elf32_Half e_type;
        Elf32_Half e_machine;
        Elf32_Word e_version;
        Elf32_Addr e_entry;
        Elf32_Off e_phoff;
        Elf32_Off e_shoff;
        Elf32_Word e_flags;
        Elf32_Half e_ehsize;
        Elf32_Half e_phentsize;
        Elf32_Half e_phnum;
        Elf32_Half e_shentsize;
        Elf32_Half e_shnum;
        Elf32_Half e_shstrndx;
    } Elf32_Ehdr;
    """

    # Big-endian layout of the fields that follow e_ident (36 bytes).
    _FORMAT = '>HHIIIIIHHHHHH'
    _FIELDS = ('e_type', 'e_machine', 'e_version', 'e_entry', 'e_phoff',
               'e_shoff', 'e_flags', 'e_ehsize', 'e_phentsize', 'e_phnum',
               'e_shentsize', 'e_shnum', 'e_shstrndx')

    def __init__(self, data):
        """Parse the 52-byte header in ``data`` and reject unsupported files."""
        self.e_ident = data[:EI_NIDENT]
        values = struct.unpack(self._FORMAT, data[EI_NIDENT:])
        for field, value in zip(self._FIELDS, values):
            setattr(self, field, value)
        # Only the narrow subset of ELF this tool understands is accepted.
        assert self.e_ident[EI_CLASS] == 1  # 32-bit
        assert self.e_ident[EI_DATA] == 2  # big-endian
        assert (self.e_flags >> 28) in [0, 1]  # mips1 / mips2
        assert self.e_type == 1  # relocatable
        assert self.e_machine == 8  # MIPS I Architecture
        assert self.e_phoff == 0  # no program header
        assert self.e_shoff != 0  # section header
        assert self.e_shstrndx != SHN_UNDEF

    def to_bin(self):
        """Serialize the (possibly modified) header back to 52 bytes."""
        values = [getattr(self, field) for field in self._FIELDS]
        return self.e_ident + struct.pack(self._FORMAT, *values)
class Symbol:
    """
    A single entry of an ELF symbol table.

    typedef struct {
        Elf32_Word st_name;
        Elf32_Addr st_value;
        Elf32_Word st_size;
        unsigned char st_info;
        unsigned char st_other;
        Elf32_Half st_shndx;
    } Elf32_Sym;
    """

    # Big-endian on-disk layout of Elf32_Sym (16 bytes).
    _FORMAT = '>IIIBBH'

    def __init__(self, data, strtab):
        """Decode the 16 bytes in ``data``; the name is looked up in ``strtab``."""
        (self.st_name, self.st_value, self.st_size,
         self.st_info, self.st_other, self.st_shndx) = struct.unpack(self._FORMAT, data)
        # Extended section indices are not supported.
        assert self.st_shndx != SHN_XINDEX
        # st_info holds the binding in its high nibble and the type in its low nibble.
        self.bind, self.type = divmod(self.st_info, 16)
        self.name = strtab.lookup_str(self.st_name)
        # Only the low two bits of st_other carry the visibility.
        self.visibility = self.st_other & 3

    @classmethod
    def from_parts(cls, st_name, st_value, st_size, bind, type, visibility, st_shndx, strtab):
        """Build a Symbol from individual fields (``visibility`` becomes st_other)."""
        packed = struct.pack(cls._FORMAT, st_name, st_value, st_size,
                             (bind << 4) | type, visibility, st_shndx)
        return cls(packed, strtab)

    def has_target(self):
        """Return whether st_shndx names an ordinary section (not UNDEF or a reserved index)."""
        if self.st_shndx == SHN_UNDEF:
            return False
        return self.st_shndx < SHN_LORESERVE

    def to_bin(self):
        """Serialize the entry back to its 16-byte on-disk form."""
        return struct.pack(self._FORMAT, self.st_name, self.st_value, self.st_size,
                           self.st_info, self.st_other, self.st_shndx)
class Relocation:
    """
    One entry of a SHT_REL or SHT_RELA relocation section.

    typedef struct {
        Elf32_Addr r_offset;
        Elf32_Word r_info;
    } Elf32_Rel;

    typedef struct {
        Elf32_Addr r_offset;
        Elf32_Word r_info;
        Elf32_Sword r_addend;
    } Elf32_Rela;
    """

    def __init__(self, data, sh_type):
        """Decode one entry from ``data``; ``sh_type`` selects REL vs RELA layout."""
        self.sh_type = sh_type
        if sh_type == SHT_REL:
            self.r_offset, self.r_info = struct.unpack('>II', data)
        else:
            # r_addend is an Elf32_Sword, i.e. signed; decode it as such so
            # negative addends don't turn into huge positive offsets.
            self.r_offset, self.r_info, self.r_addend = struct.unpack('>IIi', data)
        self.sym_index = self.r_info >> 8
        self.rel_type = self.r_info & 0xff

    def to_bin(self):
        """Serialize the (possibly modified) entry back to its on-disk form."""
        if self.sh_type == SHT_REL:
            return struct.pack('>II', self.r_offset, self.r_info)
        # Masking accepts both signed and unsigned (two's complement) addends.
        return struct.pack('>III', self.r_offset, self.r_info, self.r_addend & 0xffffffff)
class Instr:
    """A single 32-bit big-endian MIPS instruction word with pre-decoded fields."""

    # rt values under the REGIMM opcode (1) that encode branches: bltz, bgez,
    # bltzl, bgezl, bltzal, bgezal, bltzall, bgezall. The other rt values are
    # traps (tgei, teqi, ...), which must not be treated as branches.
    _REGIMM_BRANCH_RT = (0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13)

    def __init__(self, data):
        self.data = data
        self.dec, = struct.unpack('>I', data)
        self.opcode = self.dec >> 26
        self.fncode = self.dec & 0x3f
        self.rs = (self.dec >> 21) & 0x1f
        self.rt = (self.dec >> 16) & 0x1f
        self.rd = (self.dec >> 11) & 0x1f

    def is_stack_push(self):
        """Return whether this is ``addiu $sp, $sp, imm`` with a negative imm."""
        # Top 17 bits: opcode 9 (addiu), rs = rt = 29 ($sp), then the imm sign bit.
        return (self.dec >> 15) == ((0x27bd << 1) | 1)

    def imm_absolute(self):
        """Return the 16-bit immediate, zero-extended."""
        return self.dec & 0xffff

    def imm_sign_extended(self):
        """Return the 16-bit immediate, sign-extended."""
        x = self.dec & 0xffff
        return x - 0x10000 if x >= 0x8000 else x

    def is_branch(self):
        """Return whether this is a PC-relative conditional branch.

        Covers beq/bne/blez/bgtz, their "likely" variants, the REGIMM branches
        and the COP1 condition branches (bc1f/bc1t and likely variants).
        """
        if self.opcode == 1:  # REGIMM: only some rt values are branches
            return self.rt in self._REGIMM_BRANCH_RT
        if self.opcode == 17:  # COP1: rs == 8 selects the BC (branch) group
            return self.rs == 0x8
        return self.opcode in (4, 5, 6, 7, 20, 21, 22, 23)

    def is_jump(self):
        """Return whether this is j, jal, jr or jalr."""
        return self.opcode in (2, 3) or (self.opcode == 0 and self.fncode in (8, 9))

    def jump_target(self):
        """Return the j/jal 26-bit target field scaled to a byte offset."""
        # actually bit-ored by (pc + 4) & 0xf0000000, but we don't know pc
        return (self.dec & 0x3ffffff) << 2

    def reads_reg(self, reg):
        """Return whether this branch/jump reads general-purpose register ``reg``.

        Only branches and jumps are supported (AssertionError otherwise). This
        is used to decide whether a load in the delay slot may be moved above
        the branch. $zero is reported as never read, since writing it cannot
        change the value the branch observes.
        """
        if self.opcode == 0:
            assert self.fncode in (8, 9)  # jr, jalr
            read = (self.rs,)
        elif self.opcode in (2, 3):  # j, jal: no register operands
            return False
        else:
            assert self.is_branch()
            if self.opcode in (4, 5, 20, 21):  # beq, bne, beql, bnel: rs vs rt
                read = (self.rs, self.rt)
            elif self.opcode == 17:  # bc1f/bc1t: test the FP condition flag only
                return False
            else:  # blez, bgtz, their likely variants and REGIMM: test rs only
                read = (self.rs,)
        return reg != 0 and reg in read
class Section:
    """
    One section of the input ELF object, with the bookkeeping needed to
    rewrite it.

    typedef struct {
        Elf32_Word sh_name;
        Elf32_Word sh_type;
        Elf32_Word sh_flags;
        Elf32_Addr sh_addr;
        Elf32_Off sh_offset;
        Elf32_Word sh_size;
        Elf32_Word sh_link;
        Elf32_Word sh_info;
        Elf32_Word sh_addralign;
        Elf32_Word sh_entsize;
    } Elf32_Shdr;
    """

    def __init__(self, header, data, index):
        """Parse the raw section header ``header`` and slice this section's
        bytes out of the whole-file contents ``data``. ``index`` is the
        section's position in the section header table."""
        self.index = index
        self.header = header
        (self.sh_name, self.sh_type, self.sh_flags, self.sh_addr,
         self.sh_offset, self.sh_size, self.sh_link, self.sh_info,
         self.sh_addralign, self.sh_entsize) = struct.unpack('>IIIIIIIIII', header)
        assert not self.sh_flags & SHF_LINK_ORDER
        if self.sh_entsize != 0:
            assert self.sh_size % self.sh_entsize == 0
        if self.sh_type == SHT_NOBITS:
            # Occupies no file space. Use bytes (not str) so that ``data`` is
            # uniformly a bytes object for every section.
            self.data = b''
        else:
            self.data = data[self.sh_offset:self.sh_offset + self.sh_size]
        # Symbols defined in this section (filled in by init_symbols).
        self.symbols = []
        # Relocation sections whose sh_info points at this section.
        self.relocated_by = []

    def lookup_str(self, index):
        """Return the NUL-terminated string at byte ``index`` of this string table."""
        assert self.sh_type == SHT_STRTAB
        to = self.data.find(b'\0', index)
        assert to != -1
        return self.data[index:to].decode('utf-8')

    def is_rel(self):
        """Return whether this is a SHT_REL or SHT_RELA relocation section."""
        return self.sh_type == SHT_REL or self.sh_type == SHT_RELA

    def header_to_bin(self):
        """Serialize the section header, refreshing sh_size from the current data."""
        if self.sh_type != SHT_NOBITS:
            self.sh_size = len(self.data)
        return struct.pack('>IIIIIIIIII', self.sh_name, self.sh_type, self.sh_flags,
                           self.sh_addr, self.sh_offset, self.sh_size, self.sh_link,
                           self.sh_info, self.sh_addralign, self.sh_entsize)

    def late_init(self, sections):
        """Resolve links to other sections once all of them have been parsed."""
        if self.sh_type == SHT_SYMTAB:
            self.init_symbols(sections)
        elif self.is_rel():
            self.rel_target = sections[self.sh_info]
            self.rel_target.relocated_by.append(self)

    def init_symbols(self, sections):
        """Decode all symbol entries and attach defined ones to their section."""
        assert self.sh_type == SHT_SYMTAB
        assert self.sh_entsize == 16
        self.strtab = sections[self.sh_link]
        entries = []
        for i in range(0, self.sh_size, self.sh_entsize):
            s = Symbol(self.data[i:i + self.sh_entsize], self.strtab)
            entries.append(s)
            if s.has_target():
                sections[s.st_shndx].symbols.append(s)
        self.symbol_entries = entries

    def add_str(self, string):
        """Append the bytes ``string`` to this string table; return its offset."""
        assert self.sh_type == SHT_STRTAB
        assert b'\0' not in string
        ret = len(self.data)
        self.data += string + b'\0'
        return ret

    def add_symbol(self, name):
        """Append an undefined global function symbol ``name``; return its index."""
        assert self.sh_type == SHT_SYMTAB
        name_ind = self.strtab.add_str(name)
        s = Symbol.from_parts(st_name=name_ind, st_value=0, st_size=0, bind=STB_GLOBAL,
                              type=STT_FUNC, visibility=STV_DEFAULT, st_shndx=SHN_UNDEF,
                              strtab=self.strtab)
        ret = len(self.symbol_entries)
        self.symbol_entries.append(s)
        self.data += s.to_bin()
        return ret

    @staticmethod
    def _can_hoist_above(branch, reg):
        """Return whether a load into ``reg`` sitting in ``branch``'s delay slot
        may be moved in front of ``branch`` without changing behavior.

        This is not possible for branch-likely instructions, whose delay slot
        only executes when the branch is taken. It is also not possible when
        the branch or jump reads ``reg``, because moving the load would change
        the condition or target.
        """
        if branch.opcode in (20, 21, 22, 23):  # beql, bnel, blezl, bgtzl
            return False
        if branch.opcode == 1 and branch.rt in (0x02, 0x03, 0x12, 0x13):
            return False  # bltzl, bgezl, bltzall, bgezall
        if branch.opcode == 17 and branch.rs == 8 and (branch.dec >> 17) & 1:
            return False  # bc1fl, bc1tl (nd bit set)
        return not branch.reads_reg(reg)

    @staticmethod
    def _append_load_check(ndata, new_relocs, reg, bad_load_sym):
        """Append code that calls ``bad_load_sym`` if ``reg`` holds 0xbadbadbd.

        $at is used to build the poison value and to hold $ra across the call.
        In case $at is live, it is first preserved on the stack. If the loaded
        register is $at itself, its value is parked in $s0 (also preserved)
        for the comparison. The jal's position is recorded in ``new_relocs``.
        """
        ndata.append(INSTR_ADDIU(REG.sp, REG.sp, -8))
        cmp_reg = REG.s0 if reg == REG.at else reg
        save_reg = REG.s0 if reg == REG.at else REG.at
        ndata.append(INSTR_SW(save_reg, REG.sp, 0))
        if reg == REG.at:
            ndata.append(INSTR_MOVE(cmp_reg, REG.at))
        ndata.append(INSTR_LUI(REG.at, 0xbadb))
        ndata.append(INSTR_ORI(REG.at, REG.at, 0xadbd))
        # Not poison: skip the jal and its delay slot.
        ndata.append(INSTR_BNE(REG.at, cmp_reg, 3))
        ndata.append(INSTR_MOVE(REG.at, REG.ra))  # delay slot: stash $ra
        new_relocs.append((len(ndata), bad_load_sym))
        ndata.append(INSTR_JAL_0)
        ndata.append(INSTR_NOP)
        if reg == REG.at:
            ndata.append(INSTR_MOVE(REG.at, cmp_reg))
        ndata.append(INSTR_LW(save_reg, REG.sp, 0))
        ndata.append(INSTR_ADDIU(REG.sp, REG.sp, 8))

    def instrument(self, symtab, reloc_sections, poison_stack_sym, bad_load_sym, limits):
        """Rewrite this code section with stack poisoning and poisoned-load checks.

        A call to symbol ``poison_stack_sym`` is inserted after each
        ``addiu $sp, $sp, -N``. A check that calls ``bad_load_sym`` when the
        loaded value is 0xbadbadbd is inserted after each ``lw`` (except into
        $ra). ``limits`` caps how many of each are emitted.

        Afterwards the following are remapped to the new instruction layout:
        branch offsets, relocation offsets in this section's relocation
        sections, addends in ``reloc_sections``, and symbol values in
        ``symtab``. R_MIPS_26 relocations for the inserted jal instructions
        are then appended to this section's first relocation section.
        """
        assert self.sh_type == SHT_PROGBITS
        assert len(self.data) % 4 == 0
        poisoned_functions = 0
        instrumented_loads = 0
        skipped_loads = 0
        # Function names to leave alone (edit by hand when bisecting).
        poison_blacklist = []
        check_blacklist = []
        addr_to_fn_symbol = {}
        jump_targets = set()
        for s in self.symbols:
            if s.type == STT_FUNC:
                addr_to_fn_symbol[s.st_value] = s
                jump_targets.add(s.st_value)
        orig_instr = [Instr(self.data[a:a + 4]) for a in range(0, len(self.data), 4)]
        num_instr = len(orig_instr)
        # symbol_renumbering[i]: new index of the first word emitted for original
        # instruction i (index num_instr maps the end of the section).
        symbol_renumbering = [None] * (num_instr + 1)
        # reloc_renumbering[i]: new index of original instruction i itself. This
        # differs from the above when a delay-slot load is moved above its branch.
        reloc_renumbering = [None] * num_instr
        for i, ins in enumerate(orig_instr):
            if ins.is_branch():
                jump_targets.add(4 * (i + ins.imm_sign_extended() + 1))
        ndata = []
        branch_fixups = []
        new_relocs = []
        current_function = ''
        last_was_jump = False
        for i, instr in enumerate(orig_instr):
            symbol_renumbering[i] = len(ndata)
            addr = i * 4
            addr2 = len(ndata) * 4
            fn_sym = addr_to_fn_symbol.get(addr)
            if fn_sym:
                current_function = fn_sym.name
            reloc_renumbering[i] = len(ndata)
            ndata.append(instr.data)

            # Add a stack-poisoning after any instance of addiu $sp, $sp, -size
            # with negative size (doing it before the addiu violates the ABI,
            # for whatever that's worth).
            if (current_function not in poison_blacklist and instr.is_stack_push()
                    and limits.consume('poison', addr, addr2, current_function)):
                poisoned_functions += 1
                assert not last_was_jump
                stack_size = -instr.imm_sign_extended()
                ndata.append(INSTR_MOVE(REG.t6, REG.ra))
                new_relocs.append((len(ndata), poison_stack_sym))
                ndata.append(INSTR_JAL_0)
                # Delay slot: $t7 = frame size (1..0x8000). ori zero-extends, so
                # unlike addiu it is also correct for a 0x8000-byte frame.
                ndata.append(INSTR_ORI(REG.t7, REG.zero, stack_size))

            if instr.is_branch():
                branch_fixups.append((i, instr.imm_sign_extended() + 1 + i))

            # Loading the return register is fine, we don't need to instrument
            # that (if it's poison we'll crash anyway...)
            if (instr.opcode == 35 and instr.rt != REG.ra  # lw
                    and current_function not in check_blacklist):
                reg = instr.rt
                if last_was_jump and not self._can_hoist_above(orig_instr[i - 1], reg):
                    # Moving this load above its branch would change what the
                    # program does, so leave it uninstrumented.
                    skipped_loads += 1
                elif limits.consume('instr', addr, addr2, current_function):
                    instrumented_loads += 1
                    if last_was_jump:
                        # Branch delay slot: emit the load and its check before
                        # the branch, which then gets a fresh nop delay slot.
                        assert addr not in jump_targets
                        cur_instr = ndata.pop()
                        last_instr = ndata.pop()
                        reloc_renumbering[i] = len(ndata)
                        ndata.append(cur_instr)
                    self._append_load_check(ndata, new_relocs, reg, bad_load_sym)
                    if last_was_jump:
                        reloc_renumbering[i - 1] = len(ndata)
                        ndata.append(last_instr)
                        ndata.append(INSTR_NOP)
                        # Might as well preserve function alignment mod 8
                        ndata.append(INSTR_NOP)
            # TODO: loading floats/doubles (lwc1)

            if instr.is_jump() or instr.is_branch():
                # Branches in branch delay slots are too annoying to deal with.
                assert not last_was_jump
                last_was_jump = True
            else:
                last_was_jump = False
        symbol_renumbering[num_instr] = len(ndata)

        for orig_index, target in branch_fixups:
            pos = reloc_renumbering[orig_index]
            rel = symbol_renumbering[target] - pos - 1
            assert -2**15 <= rel < 2**15
            ins_dec, = struct.unpack('>I', ndata[pos])
            ndata[pos] = struct.pack('>I', (ins_dec & 0xffff0000) | (rel & 0xffff))
        for s in self.relocated_by:
            s.fixup_rel_offsets(reloc_renumbering)
        self.data = b''.join(ndata)
        for s in reloc_sections:
            s.fixup_rel_targets(symbol_renumbering, symtab, self.index)
        symtab.fixup_symtab(self, symbol_renumbering)
        if new_relocs:
            # The inserted jal instructions need R_MIPS_26 relocations against
            # the helper symbols; they go in the first relocation section.
            assert len(self.relocated_by) > 0
            self.relocated_by[0].add_jal_relocations(new_relocs)
        print("poisoned {} function stacks".format(poisoned_functions))
        print("instrumented {} loads".format(instrumented_loads))
        if skipped_loads:
            print("skipped {} loads in branch delay slots".format(skipped_loads))

    def fixup_symtab(self, sect, symbol_renumbering):
        """Remap the values of all symbols defined in ``sect`` to its new layout."""
        assert self.sh_type == SHT_SYMTAB
        for s in self.symbol_entries:
            if s.st_shndx != sect.index:
                continue
            assert s.st_value % 4 == 0
            s.st_value = symbol_renumbering[s.st_value // 4] * 4
        self.data = b''.join(s.to_bin() for s in self.symbol_entries)

    def assert_rels_sane(self):
        """Check that no two consecutive relocations share an r_offset."""
        assert self.is_rel()
        last_offset = None
        # Iterate the actual data, which may have grown past the original sh_size.
        for i in range(0, len(self.data), self.sh_entsize):
            entry = Relocation(self.data[i:i + self.sh_entsize], self.sh_type)
            assert entry.r_offset != last_offset
            last_offset = entry.r_offset

    def fixup_rel_offsets(self, reloc_renumbering):
        """Move each relocation's r_offset to where its instruction now lives."""
        assert self.is_rel()
        ndata = []
        for i in range(0, len(self.data), self.sh_entsize):
            entry = Relocation(self.data[i:i + self.sh_entsize], self.sh_type)
            assert entry.r_offset % 4 == 0
            entry.r_offset = reloc_renumbering[entry.r_offset // 4] * 4
            ndata.append(entry.to_bin())
        self.data = b''.join(ndata)

    def fixup_rel_targets(self, symbol_renumbering, symtab, target_index):
        """Remap addends of relocations against symbols in section ``target_index``.

        A zero addend needs no change, because the symbol's own value is
        remapped by fixup_symtab. A non-zero addend is an offset into the
        instrumented section and is translated through ``symbol_renumbering``.
        For SHT_REL the implicit addend stored in the relocated word of
        ``rel_target`` is patched in place.

        Every entry in ``self.data`` is processed. ``sh_size`` is not used
        here because it is stale once add_jal_relocations has appended entries.
        """
        assert self.is_rel()
        ndata = []
        ncode = [self.rel_target.data[i:i + 4] for i in range(0, len(self.rel_target.data), 4)]
        for i in range(0, len(self.data), self.sh_entsize):
            entry = Relocation(self.data[i:i + self.sh_entsize], self.sh_type)
            sym = symtab.symbol_entries[entry.sym_index]
            if sym.st_shndx != target_index:  # (including if entry.sym_index == STN_UNDEF)
                ndata.append(entry.to_bin())
                continue
            assert entry.rel_type in [R_MIPS_LO16, R_MIPS_HI16, R_MIPS_26, R_MIPS_32]
            assert entry.r_offset % 4 == 0
            assert sym.st_value % 4 == 0
            if self.sh_type == SHT_RELA:
                add = entry.r_addend * 4 if entry.rel_type == R_MIPS_26 else entry.r_addend
            else:
                ins = Instr(ncode[entry.r_offset // 4])
                if entry.rel_type == R_MIPS_32:
                    add = ins.dec
                elif entry.rel_type == R_MIPS_26:
                    add = ins.jump_target()
                else:
                    add = ins.imm_absolute()
            if add == 0:
                ndata.append(entry.to_bin())
                continue
            # HI16 relocations need to be fixed up in combination with LO16 ones, which is
            # complex... For simplicity, just assert the hi part is always 0.
            assert entry.rel_type != R_MIPS_HI16
            assert add % 4 == 0
            # A negative target would silently wrap around in the lookup below.
            assert sym.st_value + add >= 0
            # Relative relocation that points into the instrumented section.
            # Handle this using the symbol renumbering table.
            add = (symbol_renumbering[(sym.st_value + add) // 4]
                   - symbol_renumbering[sym.st_value // 4]) * 4
            # Wishful thinking. (Otherwise we might have to bother with sign extension and stuff.)
            assert add >= 0
            if entry.rel_type == R_MIPS_LO16:
                assert add < 0x8000
            if self.sh_type == SHT_RELA:
                entry.r_addend = add // 4 if entry.rel_type == R_MIPS_26 else add
            else:
                dec = ins.dec
                if entry.rel_type == R_MIPS_32:
                    dec = add & 0xffffffff
                elif entry.rel_type == R_MIPS_26:
                    dec = (dec & 0xfc000000) | ((add >> 2) & 0x3ffffff)
                else:
                    dec = (dec & 0xffff0000) | (add & 0xffff)
                ncode[entry.r_offset // 4] = struct.pack('>I', dec)
            ndata.append(entry.to_bin())
        self.data = b''.join(ndata)
        self.rel_target.data = b''.join(ncode)

    def add_jal_relocations(self, relocs):
        """Append an R_MIPS_26 relocation for each (word index, symbol index) pair."""
        assert self.is_rel()
        assert len(set(r[0] for r in relocs)) == len(relocs)
        ndata = []
        for (pos, sym) in relocs:
            assert sym < (1 << 24)
            r_info = sym << 8 | R_MIPS_26
            r_offset = pos * 4
            if self.sh_type == SHT_REL:
                nentry_data = struct.pack('>II', r_offset, r_info)
            else:
                nentry_data = struct.pack('>III', r_offset, r_info, 0)
            ndata.append(nentry_data)
        self.data += b''.join(ndata)
def main():
    """Command-line entry point: instrument a MIPS ELF relocatable object.

    Reads ``input`` and instruments every executable PROGBITS section. The
    result is written to ``output``, or back over ``input`` when no output is
    given. The whole output image is built in memory first, so a failed
    assertion no longer leaves a truncated file behind. This matters most
    when rewriting the input in place.
    """
    parser = argparse.ArgumentParser(description="Add instrumention code for uninitialized reads to MIPS .o files.")
    parser.add_argument('input', help="input file")
    parser.add_argument('output', help="output file (defaults to input file)", nargs='?')
    parser.add_argument('--poison-cap', dest='poison_n', type=int, help="poison the first N function stacks")
    parser.add_argument('--load-cap', dest='instr_n', type=int, help="instrument the first N loads")
    parser.add_argument('--no-drop-debug', dest='drop_debug', help="don't drop debug section", action='store_false')
    parser.add_argument('--no-add-symbols', dest='no_add_symbols', help="don't add symbols (requires caps = 0)", action='store_true')
    args = parser.parse_args()
    if args.no_add_symbols and any(cap is None or cap > 0 for cap in (args.instr_n, args.poison_n)):
        # Without the helper symbols no call may be emitted, so both caps must
        # be given explicitly and be <= 0 (None would mean "unlimited").
        parser.error("--no-add-symbols requires --poison-cap 0 and --load-cap 0")

    limits = Limits({
        'poison': args.poison_n,
        'instr': args.instr_n
    })
    with open(args.input, 'rb') as f:
        data = f.read()
    assert data[:4] == b'\x7fELF'

    elf_header = ElfHeader(data[0:52])
    offset, size = elf_header.e_shoff, elf_header.e_shentsize
    null_section = Section(data[offset:offset + size], data, 0)
    # e_shnum == 0 means the real count is stored in section 0's sh_size.
    num_sections = elf_header.e_shnum or null_section.sh_size
    sections = [null_section]
    for i in range(1, num_sections):
        ind = offset + i * size
        sections.append(Section(data[ind:ind + size], data, i))
    shstr = sections[elf_header.e_shstrndx]
    symtab = None
    for s in sections:
        if s.sh_type == SHT_SYMTAB:
            assert not symtab
            symtab = s
    assert symtab is not None
    for s in sections:
        s.name = shstr.lookup_str(s.sh_name)
        s.late_init(sections)

    if args.no_add_symbols:
        poison_stack_sym = None
        bad_load_sym = None
    else:
        poison_stack_sym = symtab.add_symbol(b'_poison_stack')
        bad_load_sym = symtab.add_symbol(b'_bad_load')
    reloc_sections = [s for s in sections if s.is_rel()]
    for s in reloc_sections:
        s.assert_rels_sane()
    for s in sections:
        if s.sh_type == SHT_PROGBITS and s.sh_flags & SHF_EXECINSTR:
            s.instrument(symtab, reloc_sections, poison_stack_sym, bad_load_sym, limits)

    if args.drop_debug:
        # Removing a section renumbers every later one. sh_link/sh_info,
        # st_shndx and e_shstrndx are not rewritten, so only debug sections
        # at the very end of the table can be dropped safely.
        dropped = 0
        while sections[-1].sh_type == SHT_MIPS_DEBUG:
            sections.pop()
            dropped += 1
        if dropped:
            print("dropped debug section")
        if any(s.sh_type == SHT_MIPS_DEBUG for s in sections):
            print("kept debug section (not at the end of the section table)")
    elf_header.e_shnum = len(sections)

    out = bytearray()

    def write_out(chunk):
        out.extend(chunk)

    def pad_out(align):
        if align and len(out) % align:
            write_out(b'\0' * (align - len(out) % align))

    # Placeholder header; rewritten below once e_shoff is known.
    write_out(elf_header.to_bin())
    for s in sections:
        if s.sh_type != SHT_NOBITS and s.sh_type != SHT_NULL:
            pad_out(s.sh_addralign)
            s.sh_offset = len(out)
            write_out(s.data)
    pad_out(4)
    elf_header.e_shoff = len(out)
    for s in sections:
        write_out(s.header_to_bin())
    header = elf_header.to_bin()
    out[:len(header)] = header

    # Only now touch the output file, which may be the input file itself.
    with open(args.output or args.input, 'wb') as outfile:
        outfile.write(out)


if __name__ == "__main__":
    main()
.set noat # allow manual use of $at
.set noreorder # don't insert nops after branches
.set gp=64
.include "macros.inc"
# spill every register except r0 and sp to the stack
# TODO: save/restore floating point registers (currently not necessary, but still)
# Slots written, relative to $sp (the caller must already have reserved at
# least 144 bytes, as _bad_load does):
#   0..15    left untouched (home space for the arguments of a called function)
#   16..132  $at, $v0-$v1, $a0-$a3, $t0-$t7, $s0-$s7, $t8-$t9, $k0-$k1,
#            $gp, $fp, $ra - one 32-bit word each, in that order
#   136      LO
#   140      HI
# $at is stored first because it is then reused as scratch for mflo/mfhi.
.macro save_all
    sw $at, 16($sp)
    sw $v0, 20($sp)
    sw $v1, 24($sp)
    sw $a0, 28($sp)
    sw $a1, 32($sp)
    sw $a2, 36($sp)
    sw $a3, 40($sp)
    sw $t0, 44($sp)
    sw $t1, 48($sp)
    sw $t2, 52($sp)
    sw $t3, 56($sp)
    sw $t4, 60($sp)
    sw $t5, 64($sp)
    sw $t6, 68($sp)
    sw $t7, 72($sp)
    sw $s0, 76($sp)
    sw $s1, 80($sp)
    sw $s2, 84($sp)
    sw $s3, 88($sp)
    sw $s4, 92($sp)
    sw $s5, 96($sp)
    sw $s6, 100($sp)
    sw $s7, 104($sp)
    sw $t8, 108($sp)
    sw $t9, 112($sp)
    sw $k0, 116($sp)
    sw $k1, 120($sp)
    sw $gp, 124($sp)
    sw $fp, 128($sp)
    sw $ra, 132($sp)
    mflo $at
    sw $at, 136($sp)
    mfhi $at
    sw $at, 140($sp)
.endm
# Inverse of save_all: reload every register from the same $sp-relative slots.
# LO and HI are fetched through $at. Each mtlo/mthi is separated from the lw
# that produced its operand by one unrelated load, apparently to respect the
# MIPS I load delay slot without explicit nops. $at itself is reloaded last,
# once it is no longer needed as scratch.
.macro restore_all
    lw $ra, 132($sp)
    lw $at, 136($sp)
    lw $fp, 128($sp)
    mtlo $at
    lw $at, 140($sp)
    lw $gp, 124($sp)
    mthi $at
    lw $k1, 120($sp)
    lw $k0, 116($sp)
    lw $t9, 112($sp)
    lw $t8, 108($sp)
    lw $s7, 104($sp)
    lw $s6, 100($sp)
    lw $s5, 96($sp)
    lw $s4, 92($sp)
    lw $s3, 88($sp)
    lw $s2, 84($sp)
    lw $s1, 80($sp)
    lw $s0, 76($sp)
    lw $t7, 72($sp)
    lw $t6, 68($sp)
    lw $t5, 64($sp)
    lw $t4, 60($sp)
    lw $t3, 56($sp)
    lw $t2, 52($sp)
    lw $t1, 48($sp)
    lw $t0, 44($sp)
    lw $a3, 40($sp)
    lw $a2, 36($sp)
    lw $a1, 32($sp)
    lw $a0, 28($sp)
    lw $v1, 24($sp)
    lw $v0, 20($sp)
    lw $at, 16($sp)
.endm
.section .text, "ax"
.type _poison_stack function
# _poison_stack: fill the words in [$sp, $sp + $t7) - the caller's newly
# allocated stack frame - with the poison value 0xbadbadbd, working downward
# from the top. The instrumenter calls it right after `addiu $sp, $sp, -size`
# via `move $t6, $ra; jal _poison_stack; <set $t7 = size>`.
# Inputs:
#   $t6 = the caller's original $ra (put back into $ra on return)
#   $t7 = frame size in bytes; must be a positive multiple of 4, otherwise
#         $t7 never becomes equal to $sp and the loop runs away
# Clobbers $at and $t7.
glabel _poison_stack
# t6 = ra
# t7 = size
    lui $at, 0xbadb
    ori $at, $at, 0xadbd
    addu $t7, $t7, $sp
again:
    addiu $t7, $t7, -4
    sw $at, 0($t7)
    bne $t7, $sp, again
    nop
    jr $ra
    move $ra, $t6    # delay slot: restore the caller's $ra
.type _bad_load function
# _bad_load: called from the check the instrumenter inserts after a lw when
# the loaded register equals the poison value 0xbadbadbd. The inserted code
# moves the instrumented function's $ra into $at before `jal _bad_load`.
# All GPRs except $zero/$sp, plus LO/HI, are preserved around the call to
# _print_bad_load, which receives this routine's return address in $a0.
# Floating-point registers are not saved (see the TODO on save_all).
glabel _bad_load
# reserve space for 30 registers, hi, lo, and home space for 4 registers in the called function
# (30 + 2 + 4) * 4 = 144
    addiu $sp, $sp, -144
    save_all
    move $a0, $ra
    jal _print_bad_load
    nop
    restore_all
    addiu $sp, $sp, 144
    jr $ra
    move $ra, $at    # delay slot: restore the instrumented code's $ra
/*
 * _print_bad_load(ra): called from _bad_load with the return address of the
 * `jal _bad_load` in an inserted load check. For a checked `lw` at address A,
 * the instrumenter emits
 *   A+4 addiu $sp, A+8 sw, A+12 lui, A+16 ori, A+20 bne, A+24 move, A+28 jal
 * so ra = A+36. Subtracting 9*4 therefore yields the address of the load in
 * the instrumented code. (The previous 7*4 pointed two instructions past it.)
 * Loads into $at get one extra `move` before the `lui`; for those the result
 * is one instruction past the load.
 */
#define BAD_LOAD_RA_OFFSET (9 * 4)

#ifdef STANDALONE
void print(const char* str);
void phex(unsigned int x);
void _print_bad_load(unsigned int ra) {
    unsigned int addr = ra - BAD_LOAD_RA_OFFSET;
    print("read poison at address ");
    phex(addr);
}
#else
extern unsigned gUninitMemoryReadAddr;
/* Record only the first poisoned read; later ones leave the value alone. */
void _print_bad_load(unsigned int ra) {
    unsigned int addr = ra - BAD_LOAD_RA_OFFSET;
    if (gUninitMemoryReadAddr == 0) {
        gUninitMemoryReadAddr = addr;
    }
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment