Skip to content

Instantly share code, notes, and snippets.

@jamchamb
Last active April 14, 2023 13:59
Show Gist options
  • Save jamchamb/243e6973aeb5c9a2e302a4d4f57f16e1 to your computer and use it in GitHub Desktop.
Save jamchamb/243e6973aeb5c9a2e302a4d4f57f16e1 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import struct
# Word layout used throughout this tool: little endian, 32-bit unsigned words.
# NOTE: endianness is hard-coded; it could instead be detected from the
# 0x01020304 / 0x04030201 magic word.
ENDI = '<'
WORD_PAT = f'{ENDI}I'
WORD_SZ = struct.Struct(WORD_PAT).size
def get_word(buf, pos):
    """Return the unsigned 32-bit word stored at byte offset *pos* of *buf*.

    The word is decoded with ``WORD_PAT``. Raises ``struct.error`` when fewer
    than ``WORD_SZ`` bytes are available at *pos*.
    """
    raw = buf[pos:pos + WORD_SZ]
    (value,) = struct.unpack(WORD_PAT, raw)
    return value
def update_word(buf, pos, word):
    """Overwrite the 32-bit word at byte offset *pos* of *buf* with *word*.

    *buf* must support slice assignment (e.g. a ``bytearray``); *word* is
    encoded with ``WORD_PAT``.
    """
    packed = struct.pack(WORD_PAT, word)
    buf[pos:pos + WORD_SZ] = packed
def find_words(image_data, start, end, condition, limit=0):
    """Scan 32-bit words of *image_data* and collect those matching *condition*.

    Words are read every ``WORD_SZ`` bytes starting at *start*. Only words
    that lie entirely within ``[start, end)`` and inside the buffer are
    examined, so an unaligned or out-of-range *end* no longer triggers a
    short read.

    Args:
        image_data: bytes-like buffer to scan.
        start: byte offset of the first word to examine.
        end: exclusive byte offset bounding the scan.
        condition: callable ``condition(addr, word) -> bool``.
        limit: stop after this many matches; 0 (or negative) means no limit.

    Returns:
        A list of ``(address, word)`` tuples, or ``None`` when nothing
        matched (callers test for ``None``).
    """
    # Clamp to the buffer, and stop early enough that every read is a
    # complete word.
    stop = min(end, len(image_data)) - WORD_SZ + 1
    results = []
    for addr in range(start, stop, WORD_SZ):
        word = get_word(image_data, addr)
        if condition(addr, word):
            results.append((addr, word))
            if 0 < limit <= len(results):
                break
    return results or None
def find_lc0(image_data, search_limit=0x500):
    """Locate the LC0 table as the first word whose value equals its offset.

    The rest of this script treats LC0's first entry as LC0's own file
    offset, and this heuristic relies on exactly that property.

    Args:
        image_data: the zImage contents.
        search_limit: only the first *search_limit* bytes are searched
            (default 0x500, the previously hard-coded window).

    Returns:
        The byte offset of the first matching word, or ``None``.
    """
    window_end = min(search_limit, len(image_data))
    # NOTE(review): a zero word at offset 0 also satisfies addr == word and
    # would be reported as LC0.
    candidates = find_words(
        image_data,
        0, window_end,
        lambda addr, word: addr == word,
        limit=1)
    if candidates is None:
        return None
    return candidates[0][0]
def get_lc0(image_data):
    """Decode the nine 32-bit words of the LC0 table.

    Returns a tuple of 9 ints. Raises ``Exception`` when ``find_lc0`` cannot
    locate the table.
    """
    offset = find_lc0(image_data)
    if offset is None:
        raise Exception("Couldn't find LC0")
    fmt = ENDI + 'I' * 9
    span = struct.calcsize(fmt)
    return struct.unpack(fmt, image_data[offset:offset + span])
def get_got(image_data, got_start, got_end):
    """Return the GOT entries stored between byte offsets *got_start*/*got_end*.

    The span is decoded as consecutive 32-bit words and returned as a tuple.
    ``struct.error`` is raised if the span is not a whole number of words.
    """
    entry_count = (got_end - got_start) // WORD_SZ
    fmt = ENDI + 'I' * entry_count
    return struct.unpack(fmt, image_data[got_start:got_end])
def find_magic(image_data, lc0_start):
    """Return the offset of the zImage magic word found before LC0.

    Searches ``[0, lc0_start)`` for the first word equal to 0x016F2818 and
    raises ``Exception`` if no such word exists.
    """
    zimage_magic = 0x16F2818
    hits = find_words(image_data, 0, lc0_start,
                      lambda addr, word: word == zimage_magic,
                      limit=1)
    if hits is None:
        raise Exception('zimage magic not found')
    first_addr, _ = hits[0]
    return first_addr
def pack_u32_table(u32_table):
    """Encode a sequence of ints as consecutive 32-bit words (``ENDI`` order)."""
    fmt = ENDI + 'I' * len(u32_table)
    return struct.pack(fmt, *u32_table)
def fixup_ptrs(table, extend_size, piggy_start):
    """Add *extend_size* to every entry of *table* greater than *piggy_start*.

    *table* is modified in place and also returned. Entries equal to
    *piggy_start* are left unchanged.
    """
    for idx, ptr in enumerate(table):
        if ptr > piggy_start:
            table[idx] = ptr + extend_size
    return table
def hexlist(dataz):
    """Print *dataz* one entry per line as ``<index>: <value>`` in hex.

    Indexes are printed as ``0x``-prefixed 2-digit hex, and values as
    ``0x``-prefixed 8-digit hex.
    """
    entries = list(dataz)
    for index, value in enumerate(entries):
        line = f' {index:#04x}: {value:#010x}'
        print(line)
def main():
    """Command-line entry point: extend or replace a zImage's piggy data.

    Reads the zImage named on the command line and locates LC0, the GOT and
    the piggy data. It then either zero-extends the piggy area
    (``--extend N``) or swaps in replacement piggy data (``--replace FILE``).
    LC0/GOT pointers past the piggy start, candidate GOT offsets in the code,
    and the magic end word are shifted by the size increase. The result is
    written to ``outfile``.

    Raises:
        Exception: when LC0 or the zImage magic cannot be found, or when a
            size is misaligned, or when the extend size is negative.
    """
    parser = argparse.ArgumentParser(
        description="""Replace or extend compressed vmlinux in
ARM kernel zImages""")
    parser.add_argument('zimage', type=str,
                        help='path of zImage to modify')
    parser.add_argument('outfile', type=str,
                        help='path of output zImage')
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument(
        '--extend', type=int,
        help='amount to extend piggy area')
    mode_group.add_argument(
        '--replace', type=str, default=None,
        help="""replacement piggy data file.
remember to append the inflated size 32 bit word!""")
    args = parser.parse_args()

    with open(args.zimage, 'rb') as zimage:
        image_data = zimage.read()

    # Find LC0, GOT, and piggy
    lc0_data = get_lc0(image_data)
    # LC0 was located as the word equal to its own offset, so its first
    # entry doubles as its file offset.
    lc0_start = lc0_data[0]
    lc0_end = lc0_start + (WORD_SZ * len(lc0_data))
    print(f'LC0 @ {lc0_start:#06x} - {lc0_end:#06x}')
    hexlist(lc0_data)

    # LC0 entries 5 and 6 bound the GOT.
    got_start = lc0_data[5]
    got_end = lc0_data[6]
    got_table = get_got(image_data, got_start, got_end)
    print(f'GOT @ {got_start:#010x} - {got_end:#010x}')
    hexlist(got_table)

    # GOT entry 3 is the piggy start and entry 1 the piggy end; LC0 entry 4
    # points at the word holding the inflated size.
    piggy_start = got_table[3]
    piggy_end = got_table[1]
    piggy_size = piggy_end - piggy_start
    piggy_infl_sz_ptr = lc0_data[4]
    piggy_inflated_size = get_word(image_data, piggy_infl_sz_ptr)
    print(f'piggy data @ {piggy_start:#010x} - {piggy_end:#010x}')
    print(f'piggy compressed size: {piggy_size:#010x}')
    print(f'piggy inflated size @ {piggy_infl_sz_ptr:#010x}')
    print(f'piggy inflated size: {piggy_inflated_size:#010x}')

    # Determine amount to increase image size
    replace_piggy = None
    if args.extend is not None:
        # Zero-fill extend
        incsize = args.extend
        if incsize % 4 != 0:
            raise Exception('piggy extend size must be multiple of 4')
        if incsize < 0:
            # A negative size would shift pointers backwards while inserting
            # no bytes, silently corrupting the image.
            raise Exception('piggy extend size must not be negative')
    elif args.replace is not None:
        # Insert new piggy data from file
        with open(args.replace, 'rb') as piggyf:
            replace_piggy = piggyf.read()
        # Help keep it 4 byte aligned
        if len(replace_piggy) % 4 != 0:
            raise Exception('replacement piggy size must be multiple of 4')
        print(f'piggy new compressed size: {len(replace_piggy):#010x}')
        # A smaller replacement is zero-padded further down, so the image
        # only grows when the replacement is larger than the original.
        incsize = max(len(replace_piggy) - piggy_size, 0)
    else:
        # Unreachable while the mode group is required; kept as a guard.
        print('must have extend size or replacement piggy data')
        return

    # fixup_ptrs does a simple offset, so figure out where inflate size
    # and piggy end will really be located first
    if replace_piggy is not None:
        # For a replacement piggy, set inflated size ptr accordingly.
        # It should be at the end of the XZ data.
        new_pig_end = piggy_start + len(replace_piggy)
        new_pig_sz_ptr = new_pig_end - 4
    else:
        # Piggy inflated size stays in place for simple zero-fill
        new_pig_end = piggy_end
        new_pig_sz_ptr = piggy_infl_sz_ptr

    # Update LC0 and GOT for extended image size
    lc0_data_ext = fixup_ptrs(list(lc0_data), incsize, piggy_start)
    got_table_ext = fixup_ptrs(list(got_table), incsize, piggy_start)
    # Set real piggy inflated size and end pointers
    lc0_data_ext[4] = new_pig_sz_ptr
    got_table_ext[1] = new_pig_end
    print(f'extending image by {incsize:#010x}')
    print('LC0 extended:')
    hexlist(lc0_data_ext)
    print('GOT extended:')
    hexlist(got_table_ext)

    # Making the updated image now
    new_image = bytearray(image_data)

    # Compiled functions use an offset to the GOT.
    # We need to find and update the base offsets.
    # Rough minimum possible offset is from where the code ends
    # to the GOT start. (got_start - piggy_start)
    # Rough maximum offset is from code after LC0
    # to the GOT start. (got_start - lc0_end)
    # NOTE(review): heuristic - every word in this value range is patched,
    # which may include words that are not GOT offsets.
    print('Searching for GOT offsets...')
    min_got_offset = got_start - piggy_start
    max_got_offset = got_start - lc0_end
    got_offset_candidates = find_words(
        new_image,
        lc0_end, piggy_start,
        lambda addr, word: min_got_offset <= word <= max_got_offset)
    if got_offset_candidates is None:
        # find_words reports "no match" as None, which is not iterable.
        print('No GOT offset candidates found')
        got_offset_candidates = []
    for addr, word in got_offset_candidates:
        print(f'Candidate GOT offset @ {addr:#06x}: {word:#010x}')
        update_word(new_image, addr, word + incsize)

    # Update LC0
    new_image[lc0_start:lc0_end] = pack_u32_table(lc0_data_ext)

    # Update GOT
    new_image[got_start:got_end] = pack_u32_table(got_table_ext)

    # Update _magic_end (the word two words after the magic signature)
    magic_sig_pos = find_magic(new_image, lc0_start)
    magic_start_pos = magic_sig_pos + WORD_SZ
    magic_end_pos = magic_sig_pos + WORD_SZ * 2
    magic_start = get_word(new_image, magic_start_pos)
    magic_end = get_word(new_image, magic_end_pos)
    print(f'magic start: {magic_start:#010x}')
    print(f'magic end: {magic_end:#010x}')
    magic_end += incsize
    print(f'magic end updated: {magic_end:#010x}')
    update_word(new_image, magic_end_pos, magic_end)

    # Extend the image
    if replace_piggy is not None:
        # pad piggy if it's smaller than original
        if len(replace_piggy) < piggy_size:
            # Explicit check instead of assert, which is stripped under -O.
            if piggy_size % 4 != 0:
                raise Exception('original piggy size must be multiple of 4')
            replace_piggy += b'\x00' * (piggy_size - len(replace_piggy))
        new_image = (new_image[:piggy_start]
                     + replace_piggy
                     + new_image[piggy_end:])
    else:
        # Zero bytes are inserted right after the existing piggy data.
        new_image = (new_image[:piggy_end]
                     + (b'\x00' * incsize)
                     + new_image[piggy_end:])

    # Write the extended image to file
    with open(args.outfile, 'wb') as new_file:
        new_file.write(new_image)

    print('wrote new image')
# Run the command-line tool only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment