# IDA (disassembler) and Hex-Rays (decompiler) plugin for Apple AMX # # WIP research. (This was edited to add more info after someone posted it to # Hacker News. Click "Revisions" to see full changes.) # # Copyright (c) 2020 dougallj # Based on Python port of VMX intrinsics plugin: # Copyright (c) 2019 w4kfu - Synacktiv # Based on AArch64 8.3-A Pointer Authentication plugin: # Copyright (c) 2018 Eloi Benoist-Vanderbeken - Synacktiv # Copyright (c) 2018 xerub # TODO: XZR can be an operand, but I don't handle that correctly in # the decompiler yet. # AMX: Apple Matrix coprocessor # # This is an undocumented arm64 ISA extension present on the Apple M1. These # instructions have been reversed from Accelerate (vImage, libBLAS, libBNNS, # libvDSP and libLAPACK all use them), and by experimenting with their # behaviour on the M1. Apple has not published a compiler, assembler, or # disassembler, but by calling into the public Accelerate framework # APIs you can get the performance benefits (fast multiplication of big # matrices). This is separate from the Apple Neural Engine. # # Warning: This is a work in progress, some of this is going to be incorrect. # # This may actually be very similar to Intel Advanced Matrix Extension (AMX), # making the name collision even more confusing, but it's not a bad place to # look for some idea of what's probably going on. # # # WIP simulator/hardware tests are at: # https://gist.github.com/dougallj/7cba721da1a94da725ee37c1e9cd1f21 # # # The coprocessor state consists of two 0x200 byte "registers", amx0 ("x") # and amx1 ("y"), and one 0x1000 byte register amx2 ("z"). (Apple headers # describe x, y, and z as register groups, where each row of 64-bytes is a # "register", and describe only "z" as being "64 registers in an M-by-N # matrix". They also describe a 64-bit AMX_STATE_T_EL1 register, which # presumably records if AMX is enabled or not, but possibly other state # too.) 
# # Each is typically loaded/stored from memory in rows of 0x40 bytes, # although in some operations the registers can be indexed by byte offsets. # # # AMX instructions are of the form: # # 0x00201000 | ((op & 0x1F) << 5) | (operand & 0x1F) # # AMX must be explicitly enabled using op=17, operand=0 and disabled using # op=17, operand=1. In Accelerate, these instructions are always prefixed # by three nops. What could go wrong? # # If instructions other than "enable" are executed when AMX is not enabled, # they are treated as illegal instructions. # # # All other operations (op=0-16 and op=18-22) seem to take a 64-bit register # number (X0-X30 or 31=XZR) as the operand. # # This register is typically a bitfield containing further parameters to the # operation. For example, loads and stores have a 56-bit address in bits 0 # through 55, a 5-bit register offset (in units of 0x40) in bits 56 # through 61, and a 1-bit flag in bit 62 (acting as an 0x40 byte load/store # when zero, or an 0x80 byte (but aligned) load/store when one). # # My best guess at the names is based on: # https://www.realworldtech.com/forum/?threadid=187087&curpostid=187120 # # ops 0 through 7 are loads/stores: # # 0 is load amx0 (amxldx) # 1 is load amx1 (amxldy) # 2 is store amx0 (amxstx) # 3 is store amx1 (amxsty) # 4 is load amx2 (amxldz) # 5 is store amx2 (amxstz) # # 6 and 7 load and store amx2, but in a different order, and # always as 0x40 bytes (bit 62 is ignored) # # 6 also loads amx2 (amxldzi) # 7 also stores amx2 (amxstzi) # but they use halves of two registers in amx2 # row index 0 = amx2[0].low and amx2[1].low interleaved # row index 1 = amx2[0].high and amx2[1].high interleaved # row index 2 = amx2[2].low and amx2[3].low interleaved # row index 3 = amx2[2].high and amx2[3].high interleaved # etc. 
# # Other operations do not touch memory, and usually have their result in # amx2 (z), but 8 and 9 have their result in amx0 and amx1 (x/y), and 22 seems # to have its result in row 0 (bytes 0 through 0x3F) of amx0. # # op8: extract row or move to x, result in amx0 (amxextrx) # # move a 64-byte row from z or y to x # operands: # x offset in bytes = (argument >> 10) & 0x1FF # z offset in rows = (argument >> 20) & 63 # move from y = (argument >> 27) & 1 # if moving from y, the x offset is rounded down to 0x40 bytes (so it can only # store to a row, rather than an arbitrary byte offset in x) # # TODO: other bits # # op9: extract column or move to y, result in amx1/amx0 (amxextry) # # move a 64-byte column from z to x or y, or move a 64-byte row from x to y # # y offset in bytes = argument & 0x1FF # z offset in columns = (argument >> 20) & 63 # move from x = (argument >> 27) & 1 # # TODO: many other bits factor into how the layout and order of columns is # determined, and which register is the destination. i'd like to finish # reversing it before trying to specify it, but my current understanding # is recorded in amx_state_extry at: # # https://gist.github.com/dougallj/7cba721da1a94da725ee37c1e9cd1f21 # # op10: multiply and add 64-bit floats (amxfma64) # # similar to op14, but 8x8 double multiplies for 64 results, added # (in groups of 8) to every 8th row of register "z" (z0, z8, z16). # no "32-bit mode" flag (?) # # op11: multiply and subtract 64-bit floats (amxfms64) # # same as op10, but subtracting # # op12: multiply and add 32-bit floats (amxfma32) # # similar to op14, but 16x16 float multiplies for 256 results, added # (in groups of 16) to every 4th row of register "z" (z0, z4, z8). # no "32-bit mode" flag (?) # # op13: multiply and subtract 32-bit floats (amxfms32) # # same as op12, but subtracting # # op14: multiply and add 16-bit signed integers (amxmac16) # # input two vectors of 32 16-bit values, one from register "x" and the other # from register "y". 
register "z" is the output, but may also be considered an # input for "add" operations. # # each value in the first vector is multiplied with each value in the second # vector (giving 32 * 32 = 1024 results), and each result is added to the value # in register "z". (although a bit in the input register may be set to skip # the addition, and simply store the result, which is typically used on the # first iteration of an accumulating loop.) # # operands: # input offset in x (byte): (argument & 0x1FF) # input offset in y (byte): ((argument >> 10) & 0x1FF) # row offset in z: (argument >> 20) & 63 # clear "z" flag (don't add): (argument >> 27) & 1 # skip y input (don't mul): (argument >> 28) & 1 # skip x input (don't mul): (argument >> 29) & 1 # row disable: (argument >> 32) & 0x7F # col disable: (argument >> 41) & 0x7F # 32-bit mode: (argument >> 62) & 1 # vector (non-matrix) multiply add (16x16->16 in one row): (argument >> 63) & 1 # TODO: there are operands in other bits that still need to be reversed # # bit 62 makes output 32-bit ints, rather than 16-bit ints # # if bit 62 is zero, the output is in every second row, and if bit 27 is also # set, only every second row gets zeroed (old values remain in the other rows) # # row/column disable skips the operation for certain entries in the row/column: # if disable is 0: process all entries # if disable is 1: process only every second entry (starting from the index 1) # if disable is 2: process only every second entry (starting from the index 0) # if (disable & 0x60) is 0x20: process only the entry at index "ignore & 0x1F" # if (disable & 0x60) is 0x40: process only the first "ignore & 0x1F" entries # if (disable & 0x60) is 0x60: process only the last "ignore & 0x1F" entries # # for 32-bit output (sign extend all inputs to 32-bit): # z += [ # [x0, x2, x4, x6, x8, x10, x12, x14, x16, x18, x20, x22, x24, x26, x28, x30] * y0, # [x1, x3, x5, x7, x9, x11, x13, x15, x17, x19, x21, x23, x25, x27, x29, x31] * y0, # [x0, x2, x4, 
x6, x8, x10, x12, x14, x16, x18, x20, x22, x24, x26, x28, x30] * y1, # [x1, x3, x5, x7, x9, x11, x13, x15, x17, x19, x21, x23, x25, x27, x29, x31] * y1, # [x0, x2, x4, x6, x8, x10, x12, x14, x16, x18, x20, x22, x24, x26, x28, x30] * y2, # [x1, x3, x5, x7, x9, x11, x13, x15, x17, x19, x21, x23, x25, x27, x29, x31] * y2, # ... # ] # # note that this works well with the "store z interleaved operation" to get the values out # in order. # # for 16-bit output (although the zeroes aren't really "added" just skipped): # z += [ # [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, x16, x17, x18, x19, x20, x21, x22, x23, x24, x25, x26, x27, x28, x29, x30, x31] * y0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, x16, x17, x18, x19, x20, x21, x22, x23, x24, x25, x26, x27, x28, x29, x30, x31] * y1, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, x16, x17, x18, x19, x20, x21, x22, x23, x24, x25, x26, x27, x28, x29, x30, x31] * y2, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, # ... # ] # # # op15: multiply and add 16-bit floats (amxfma16) # (same as op14, but for 16-bit floats) # bit 62 makes output 32-bit floats, rather than 16-bit floats # # op16: multiply and subtract 16-bit floats (amxfms16) # (same as op15, but subtracting from register "z" instead of adding) # # 17 is enable/disable # 18 does an operation, result in amx2 (vecint) # vector multiply 16-bit integers? (doesn't mac16 have a flag for this?) # z0[i] += x0[i] + y0[i] # # 19 does an operation, result in amx2 (vecfp) # vector multiply 16-bit floats? (doesn't mac16 have a flag for this?) # z0[i] += x0[i] + y0[i] # # 20 does an operation, result in amx2 (matint) # 16-bit integer matrix multiply? 
(doesn't fma16 do this?) # # 21 does an operation, result in amx2 (matfp) # 16-bit float matrix multiply? (doesn't fma16 do this?) # # 22 does an operation, result in amx0[0] (genlut) # # with xzr as input it takes 16 signed 32-bit integers from amx0[0] as input, # generates a 64-bit output: # [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0] -> 0xffffffffffffffff # [0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] -> 0xf0 # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] -> 0xfedcba9876543210 # # # # Performance characteristics: # # AMX functions as a non-speculative coprocessor, with operations posted to it # via the store units of the CPU cores. Non-load/store AMX instructions can be # fused with other non-load/store AMX instructions, to only use one store # port. # # Because it doesn't go through the main processors out-of-order execution, it # can be beneficial to add prefetch instructions (which do) immediately before # AMX stores (and loads). # # The M1 has two AMX coprocessors, one used by the four Firestorm cores, and # another used by the four Icestorm cores. Each coprocessor stores four copies # of the architectural register state, one for each core. These access memory # via the same L2 cache as the cores, and seem to run at around the same clock # speed as the cores. Much like the power and efficiency cores, the two # coprocessors are designed differently, and have different performance # characteristics. # # The performance-variant consists of an array of 4-cycle latency, pipelined # FMAs, giving a throughput of one FMA32 or FMA64 instruction per cycle, but # only one FMA16 instruction every two cycles. # # The efficiency-variant still has 4-cycle FMA32/FMA64 latency, but it can # only perform one FMA32 or FMA64 instruction every four cycles, or one # FMA16 instruction every eight cycles. # # To achieve 1-cycle throughput from a single core, the destinations must be # independent (using a Z offset). 
Operations which use too much of the Z # register will have lower throughput. Throughput can also be improved using # different cores (and therefore entirely different Z registers). # # For example, an expanding 16-bit -> 32-bit FMA uses the full Z register # (for 1024 32-bit results), so has a throughput of one every four cycles, # but this can be doubled to one every two cycles by running on two # performance cores. # # There is some out-of-order execution capability on the co-processor (currently # estimated at a 28 to 32 operation buffer, with very low confidence). # # Loads and stores seem to be the bottleneck. They can generate faults on bad # addresses/alignments as expected, but appear to go via L2, with slight penalties # for data in L1. # # Mixing loads and stores on the main processor with co-processor loads and stores # causes big slowdowns, presumably due to the synchronization needed to ensure a # consistent view of memory. # # At best I've seen loads of 0x80 bytes to x and 0x80 bytes to y in 9 cycles # (running in a loop). 
#
# Because of the out-of-order capabilities, performing four fmas fits within
# this 9-cycle window at essentially no extra cost, so the following can run in
# a loop at 9 cycles per iteration:
#
#   AMX_LDX(load_addr | 0x4000000000000000);
#   AMX_LDY(load_addr | 0x4000000000000080);
#   AMX_FMA32(0x000000);
#   AMX_FMA32(0x110000);
#   AMX_FMA32(0x200040);
#   AMX_FMA32(0x310040);
#
# (this is accumulating a 32x32 tile within a larger matrix multiply)
#
#
# Slowdowns from mixing loads/stores:
#
#   nop/add before AMXLDX: 9 cycles/iter
#   str before AMXLDX (no-aliasing): 47 cycles/iter
#   str before AMXLDX (aliasing): 93 or 103 cycles/iter
#   ldr before AMXLDX (no-"aliasing"): 11 cycles/iter
#   ldr before AMXLDX ("aliasing"): 66 cycles/iter
#
#   nop/add before AMXSTX: 28 cycles/iter
#   str before AMXSTX (no-aliasing): 48 cycles/iter
#   str before AMXSTX (aliasing): 115 cycles/iter
#   ldr before AMXSTX (no-aliasing): 31 cycles/iter
#   ldr before AMXSTX (aliasing): 112 cycles/iter
#
#
# Hardware
#
# I know even less about this, but my guesses at the floorplan
# locations of the AMX coprocessors are in this Twitter thread:
# https://twitter.com/dougallj/status/1446097016166051848

import idaapi
import ida_hexrays

# Internal instruction tags stored in insn.insnpref: AMX_NONE means "not an
# AMX instruction"; AMX_OPn corresponds to raw opcode n (i.e. tag = op + 1,
# offset by one so that 0 stays available as the "none" marker).
AMX_NONE = 0
AMX_OP0 = 1
AMX_OP1 = 2
AMX_OP2 = 3
AMX_OP3 = 4
AMX_OP4 = 5
AMX_OP5 = 6
AMX_OP6 = 7
AMX_OP7 = 8
AMX_OP8 = 9
AMX_OP9 = 10
AMX_OP10 = 11
AMX_OP11 = 12
AMX_OP12 = 13
AMX_OP13 = 14
AMX_OP14 = 15
AMX_OP15 = 16
AMX_OP16 = 17
AMX_OP17 = 18
AMX_OP18 = 19
AMX_OP19 = 20
AMX_OP20 = 21
AMX_OP21 = 22
AMX_OP22 = 23

# Disassembly mnemonics, keyed by the AMX_OPn tag.
OP_NAMES = {
    AMX_OP0: "AMXLDX",
    AMX_OP1: "AMXLDY",
    AMX_OP2: "AMXSTX",
    AMX_OP3: "AMXSTY",
    AMX_OP4: "AMXLDZ",
    AMX_OP5: "AMXSTZ",
    AMX_OP6: "AMXLDZI",
    AMX_OP7: "AMXSTZI",
    AMX_OP8: "AMXEXTRX",  # amxextrx?
    AMX_OP9: "AMXEXTRY",  # amxextry?
    AMX_OP10: "AMXFMA64",
    AMX_OP11: "AMXFMS64",
    AMX_OP12: "AMXFMA32",
    AMX_OP13: "AMXFMS32",
    AMX_OP14: "AMXMAC16",
    AMX_OP15: "AMXFMA16",
    AMX_OP16: "AMXFMS16",
    AMX_OP17: "AMX17",  # amxset / amxclr
    AMX_OP18: "AMXVECINT",
    AMX_OP19: "AMXVECFP",
    AMX_OP20: "AMXMATINT",
    AMX_OP21: "AMXMATFP",
    AMX_OP22: "AMXGENLUT",
}

# Pseudocode helper-call names used by the decompiler, keyed by the raw
# opcode number (note: NOT the AMX_OPn tag — see AMXFilter.apply, which
# passes ord(insnpref) through unchanged).
OP_INTRINSIC_NAMES = {
    AMX_OP0: "__amx_ldx",
    AMX_OP1: "__amx_ldy",
    AMX_OP2: "__amx_stx",
    AMX_OP3: "__amx_sty",
    AMX_OP4: "__amx_ldz",
    AMX_OP5: "__amx_stz",
    AMX_OP6: "__amx_ldzi",
    AMX_OP7: "__amx_stzi",
    AMX_OP8: "__amx_extrx",
    AMX_OP9: "__amx_extry",
    AMX_OP10: "__amx_fma64",
    AMX_OP11: "__amx_fms64",
    AMX_OP12: "__amx_fma32",
    AMX_OP13: "__amx_fms32",
    AMX_OP14: "__amx_mac16",
    AMX_OP15: "__amx_fma16",
    AMX_OP16: "__amx_fms16",
    AMX_OP17: "__amx_op17",  # amxset / amxclr
    AMX_OP18: "__amx_vecint",
    AMX_OP19: "__amx_vecfp",
    AMX_OP20: "__amx_matint",
    AMX_OP21: "__amx_matfp",
    AMX_OP22: "__amx_genlut",
}


def decode_AMX(d, insn):
    """Decode a 32-bit instruction word `d` as an AMX instruction.

    AMX instructions are encoded as 0x00201000 | (op << 5) | operand.
    On a match, fills in `insn` (an insn_t) — itype, one operand, size,
    and the AMX_OPn tag in insnpref — and returns True; otherwise returns
    False and leaves `insn` untouched.
    """
    if (d & 0xfffffC00) == 0x00201000:
        Xr = d & 31           # operand field: GPR number (31 = XZR), or imm for op 17
        m = (d >> 5) & 31     # raw AMX opcode (0..31; only 0..22 are known)
        if m <= AMX_OP22 - AMX_OP0:
            # Reuse an existing ARM itype so IDA treats this as a real
            # instruction; ev_out_mnem substitutes the AMX mnemonic.
            #insn.itype = idaapi.ARM_nop
            insn.itype = idaapi.ARM_hlt
            # NOTE(review): segpref = 14 — purpose not evident from this file;
            # presumably a condition/segment hint expected by the ARM module.
            insn.segpref = 14
            if m == 17:
                # op 17 (amxset/amxclr) takes an immediate operand, not a register.
                insn.Op1.type = idaapi.o_imm
                insn.Op1.value = Xr
                insn.Op1.dtype = idaapi.dt_byte
            else:
                # All other ops take a 64-bit GPR; +129 maps the raw register
                # number onto the ARM module's X-register indices.
                insn.Op1.type = idaapi.o_reg
                insn.Op1.reg = Xr + 129
                insn.Op1.dtype = idaapi.dt_qword
            insn.insnpref = AMX_OP0 + m  # tag for later hooks (0 = not AMX)
            insn.size = 4
            return True
    return False


class Aarch64AMXHook(idaapi.IDP_Hooks):
    """Processor-module hooks that teach IDA to decode and print AMX."""

    # itypes we hijack for AMX (must match decode_AMX's choice).
    CUSTOM_INSTRUCTIONS = {idaapi.ARM_hlt}
    # column at which the custom mnemonic is printed
    INDENT = 16

    def ev_ana_insn(self, outctx):
        # Return the instruction size on a successful AMX decode, 0 to let
        # the normal ARM analyzer handle the word.
        return outctx.size if decode_AMX(idaapi.get_dword(outctx.ea), outctx) else 0

    def ev_emu_insn(self, insn):
        # NOTE(review): this compares against ARM_brk, but decode_AMX emits
        # ARM_hlt, so this always returns False (default emulation). Looks
        # like a leftover; behavior preserved — confirm before changing.
        if insn.itype != idaapi.ARM_brk:
            return False
        return True

    def ev_out_mnem(self, outctx):
        # Print the AMX mnemonic for instructions we tagged in decode_AMX.
        if outctx.insn.itype in self.CUSTOM_INSTRUCTIONS:
            mnem = OP_NAMES.get(ord(outctx.insn.insnpref), None)
            if mnem is not None:
                outctx.out_custom_mnem(mnem, self.INDENT)
                return 1
        return 0


class MicroInstruction(ida_hexrays.minsn_t):
    """A Hex-Rays microinstruction with all three operands zeroed."""

    def __init__(self, opcode, ea):
        ida_hexrays.minsn_t.__init__(self, ea)
        self.opcode = opcode
        self.l.zero()
        self.r.zero()
        self.d.zero()


class CallBuilder():
    """Builds a Hex-Rays helper-call microinstruction (e.g. __amx_ldx(x0)).

    The call is modeled as touching all "global" memory (everything outside
    the stack frame) so the optimizer cannot reorder or eliminate it.
    """

    def __init__(self, cdg, name, return_type=None):
        # Sentinel default: building the tinfo_t here (rather than in the
        # `def` line) avoids sharing one mutable SWIG object across calls.
        if return_type is None:
            return_type = idaapi.tinfo_t(idaapi.BT_VOID)
        self.emitted = False
        self.cdg = cdg
        self.callinfo = ida_hexrays.mcallinfo_t()
        self.callinfo.callee = idaapi.BADADDR
        self.callinfo.solid_args = 0
        self.callinfo.call_spd = 0
        self.callinfo.stkargs_top = 0
        self.callinfo.cc = idaapi.CM_CC_FASTCALL
        self.callinfo.return_type = return_type
        self.callinfo.flags = idaapi.FCI_SPLOK | idaapi.FCI_FINAL | idaapi.FCI_PROP
        self.callinfo.role = idaapi.ROLE_UNK

        glbhigh_off = cdg.mba.get_stack_region().off + cdg.mba.get_stack_region().size
        # what memory is visible to the call : GLBLOW - GLBHIGH
        self.callinfo.visible_memory.add(ida_hexrays.ivl_t(0x00, 0x100000))
        self.callinfo.visible_memory.add(ida_hexrays.ivl_t(glbhigh_off, 0xFFFFFFFFFFFFFFFF - glbhigh_off))
        # spoiled locations : GLBLOW - GLBHIGH
        self.callinfo.spoiled.mem.add(ida_hexrays.ivl_t(0x00, 0x100000))
        self.callinfo.spoiled.mem.add(ida_hexrays.ivl_t(glbhigh_off, 0xFFFFFFFFFFFFFFFF - glbhigh_off))

        self.callins = MicroInstruction(ida_hexrays.m_call, self.cdg.insn.ea)
        self.callins.l.make_helper(name)
        self.callins.d.t = ida_hexrays.mop_f
        self.callins.d.size = 0
        self.callins.d.f = self.callinfo

        if (return_type.is_void()):
            # void call: emit the m_call directly
            self.ins = self.callins
        else:
            # non-void: wrap the call in `mov call_result, dest_reg`
            self.callins.d.size = return_type.get_size()
            self.ins = MicroInstruction(ida_hexrays.m_mov, self.cdg.insn.ea)
            self.ins.l.t = ida_hexrays.mop_d
            self.ins.l.d = self.callins
            self.ins.l.size = self.callins.d.size
            self.ins.d.t = ida_hexrays.mop_r
            self.ins.d.r = 0x00  # destination register; see set_return_register
            self.ins.d.size = self.callins.d.size

    def add_register_argument(self, t, operand):
        """Append a register argument of type `t` (a tinfo_t) to the call."""
        ca = ida_hexrays.mcallarg_t()
        ca.t = idaapi.mop_r
        ca.r = operand
        ca.type = t
        ca.size = t.get_size()
        self.callinfo.args.push_back(ca)
        self.callinfo.solid_args += 1

    def set_return_register(self, reg):
        """Set the microregister receiving the call's return value."""
        self.ins.d.r = reg

    def emit(self):
        """Insert the built instruction at the end of the current block (once)."""
        if self.emitted == False:
            self.cdg.mb.insert_into_block(self.ins, self.cdg.mb.tail)
            self.emitted = True


class AMXFilter(ida_hexrays.microcode_filter_t):
    """Microcode filter translating decoded AMX instructions into helper calls."""

    def __init__(self):
        ida_hexrays.microcode_filter_t.__init__(self)
        # Self-registering: constructing the filter installs it.
        ida_hexrays.install_microcode_filter(self, True)

    def match(self, cdg):
        # Only handle instructions tagged by decode_AMX.
        return cdg.insn.itype == idaapi.ARM_hlt and cdg.insn.insnpref != AMX_NONE

    def apply(self, cdg):
        op = ord(cdg.insn.insnpref)
        intrinsic_name = OP_INTRINSIC_NAMES.get(op, '__amx%d' % op)
        if cdg.insn.Op1.type == idaapi.o_reg:
            # Register-operand ops become __amx_xxx(<u64 register>).
            builder = CallBuilder(cdg, intrinsic_name)
            builder.add_register_argument(idaapi.tinfo_t(idaapi.BT_INT64 | idaapi.BTMT_UNSIGNED), cdg.load_operand(0))
            builder.emit()
        elif cdg.insn.Op1.type == idaapi.o_imm:
            # op 17 immediates 0/1 are the enable/disable pair.
            if op == AMX_OP17 and cdg.insn.Op1.value == 0:
                builder = CallBuilder(cdg, '__amx_begin')
                builder.emit()
            elif op == AMX_OP17 and cdg.insn.Op1.value == 1:
                builder = CallBuilder(cdg, '__amx_end')
                builder.emit()
            else:
                # Unknown immediate: bake it into the helper name.
                builder = CallBuilder(cdg, '%s_%d' % (intrinsic_name, cdg.insn.Op1.value))
                builder.emit()
        return idaapi.MERR_OK


class Aarch64AMXPlugin(idaapi.plugin_t):
    """IDA plugin entry: installs the AMX hooks on 64-bit ARM databases."""

    flags = idaapi.PLUGIN_PROC | idaapi.PLUGIN_HIDE
    comment = "Aarch64 Apple AMX extension"
    wanted_hotkey = ""
    help = "Runs transparently"
    wanted_name = "Aarch64 AMX"
    hook = None
    enabled = 1

    def init(self):
        # Only ARM64: BADADDR <= 0xFFFFFFFF means a 32-bit database.
        if idaapi.ph_get_id() != idaapi.PLFM_ARM or idaapi.BADADDR <= 0xFFFFFFFF:
            return idaapi.PLUGIN_SKIP
        if not ida_hexrays.init_hexrays_plugin():
            print("[-] {0} : no decompiler available, skipping".format(self.wanted_name))
            return idaapi.PLUGIN_SKIP
        # Fixed: was a Python 2 print statement (syntax error on Python 3).
        print("%s init" % self.comment)
        self.hook = Aarch64AMXHook()
        self.hook.hook()
        self.filter = AMXFilter()
        return idaapi.PLUGIN_KEEP

    def run(self, arg):
        # Fixed: was `def run():` — missing self/arg, so IDA's invocation
        # of plugin.run(arg) would raise TypeError. Intentionally a no-op.
        pass

    def term(self):
        if self.hook is not None:
            self.hook.unhook()
        # Fixed: was a Python 2 print statement.
        print("%s unloaded" % self.comment)


def PLUGIN_ENTRY():
    return Aarch64AMXPlugin()