Assembly generator script in Python, for making a fast wc -w
import collections
import contextlib
import sys

# Register class, for GPRs and vector registers
_Reg = collections.namedtuple('Reg', 't v')
class Reg(_Reg):
    def __str__(self):
        names = ['ax', 'cx', 'dx', 'bx', 'sp', 'bp', 'si', 'di']
        if self.t == 'r' and self.v < 8:
            return 'r' + names[self.v]
        return '%s%s' % (self.t, self.v)

GPR = lambda i: Reg('r', i)
XMM = lambda i: Reg('xmm', i)
YMM = lambda i: Reg('ymm', i)
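# For illustration, str() on these registers yields the assembler names:
# str(GPR(0)) == 'rax', str(GPR(8)) == 'r8', str(XMM(2)) == 'xmm2', str(YMM(5)) == 'ymm5'.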
INST_ID = 0
INSTS = []

# Instruction creation. This creates a unique ID for hashability/dependency tracking, and
# appends the instruction to a global list
class Inst:
    def __init__(self, mnem, *args, depend=None):
        global INST_ID
        INST_ID += 1
        self.id = INST_ID
        self.mnem = mnem
        self.args = args
        self.deps = {depend} if depend is not None else set()
        INSTS.append(self)
    def __eq__(self, other):
        return self.id == other.id
    def __hash__(self):
        return self.id
    def __str__(self):
        return self.mnem + ' ' + ', '.join(['%#x' % a if isinstance(a, int) else str(a) for a in self.args])
@contextlib.contextmanager
def capture():
    global INSTS
    old_insts = INSTS
    INSTS = []
    yield INSTS
    INSTS = old_insts
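# For illustration: Inst('vpxor', YMM(0), YMM(0), YMM(0)) appends itself to whatever list
# INSTS currently names and renders via str() as "vpxor ymm0, ymm0, ymm0". capture()
# temporarily swaps in a fresh list, so a group of instructions (the scheduled loop body,
# the interleaved prefetches) can be collected and spliced back in later.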
# Create instruction shortcuts
for mnem in ['prefetcht0', 'prefetcht1', 'prefetcht2', 'prefetchnta',
        'vmovdqu', 'vpaddb', 'vpaddq', 'vpalignr', 'vpandn',
        'vpbroadcastb', 'vpcmpeqb', 'vpcmpgtb', 'vperm2i128', 'vpor', 'vpxor', 'vpsadbw', 'vpsubb',
        'vmovq', 'mov', 'lea']:
    # Bind mnem through a helper so each shortcut captures its own mnemonic rather than
    # the loop variable
    def bind(mnem):
        globals()[mnem] = lambda *a, **k: Inst(mnem, *a, **k)
    bind(mnem)

def vzero(reg):
    vpxor(reg, reg, reg)

# Pseudo-nop: this inserts a nop that affects scheduling, by yielding the instruction slot
# in the block-interleaving scheduler we use. This isn't actually emitted as an instruction.
def pseudo_nop():
    INSTS.append(None)
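# With lockstep enabled below, a block appends a pseudo-nop wherever a sibling block
# emits a real instruction (a prefetch, a subtotal update, the phi move), so the
# interleaved blocks stay aligned slot-for-slot when they're scheduled.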
# Knobs
use_sub = 1
use_index = 1
use_next_base = 1
scale = 7
#next_base_offset = 0x1000
prefetch_interleave = [2, 2]
prefetch_32 = [0, 0]
prefetch_offset = [30*32, 6*32]
prefetch = [prefetchnta, prefetcht0]
prefetch_reverse = 0
prefetch_len = 2
next_base_offset = prefetch_offset[0]
pf_params = list(zip(prefetch_interleave, prefetch, prefetch_offset, prefetch_32))
if prefetch_reverse:
    pf_params = pf_params[::-1]
pf_params = pf_params[:prefetch_len]
unroll = 8
lockstep = 1
schedule_break = lambda b, i, t: not t & 1
schedule_break = lambda b, i, t: b & 1 and not t & 1
schedule_break = lambda b, i, t: 0
#schedule_break = lambda b, i, t: b & 1
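# schedule_break controls when the scheduler below cuts a sweep short and restarts from
# block 0; as called in schedule(), b is the block index, i is how many instructions that
# block contributed this sweep, and t is the total emitted so far. Only the last
# assignment above takes effect; the earlier ones are alternatives left in from experiments.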
# Registers
# Constants
zero = YMM(1)
c0 = YMM(2)
c1 = YMM(3)
c2 = YMM(4)
# Total, subtotal
total = YMM(0)
subtotal = YMM(5)
# Last iteration's input as a vector register. This is carried between blocks so the
# kernel can shift in the byte that precedes each 32-byte chunk.
phi_last = last = YMM(6)
last_dep = None
# GPRs to make indexing use less encoding bytes
index = GPR(1) if use_index else None
next_base = GPR(3) if use_next_base else None
blocks = []
# Create a pointer using scale/index/base/displacement, as controlled by various knobs
def ptr(offset):
    base = 'rdi'
    if use_next_base and offset > abs(offset - next_base_offset):
        base = next_base
        offset -= next_base_offset
    d = offset >> scale
    if not use_index or d <= 0:
        return '[%s%+#03x]' % (base, offset)
    # Use the largest power-of-two index multiplier that doesn't overshoot the offset
    d = max(s for s in [1, 2, 4, 8] if s <= d)
    return '[%s+%s*%s%+#03x]' % (base, d, index, offset - d*(1<<scale))
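# Worked example with the knobs above (scale=7, so the index register holds 0x80, and
# next_base_offset = 30*32 = 0x3c0): ptr(0x40) -> '[rdi+0x40]', since 0x40 >> 7 == 0;
# ptr(0x4a0) -> '[rbx+1*rcx+0x60]', switching to next_base and folding one multiple of
# the index register into the address to keep the displacement within a signed byte.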
# Base dependencies--so different blocks that use the same registers are sequenced
base_deps = {}
# Create blocks of instructions, one for each iteration of the (rolled) loop.
# We collect each block of instructions, and schedule them later.
# One loop kernel takes two registers (marginally), and since we have ~7 registers
# worth of constants/counters outside the main loop, we can interleave four copies
# of the loop in the remaining 9 registers we get in AVX2.
for i in range(unroll):
    # Input pointer offset
    offset = 0x20 * i
    # Our two registers for this block, with some aliases
    reg = 8+(i % 4)*2
    d = YMM(reg)
    e = mask = shifted = YMM(reg+1)
    # First sub-iteration: compute directly into the subtotal
    if i == 0 and not use_sub:
        mask = subtotal
    # Insert prefetches for both cache tiers (if they're interleaved for that level)
    for [pf_interleave, pf, pf_offset, pf_32] in pf_params:
        if pf_interleave == 1:
            # XXX inserting extra prefetches can help? This does two per cache line.
            # Maybe helps the scheduler? i.e. the cpu's scheduler. Even the
            # pseudo-nop inserted here otherwise, to keep our scheduler in sync, doesn't help
            if pf_32 or not i & 1:
                pf('BYTE PTR %s' % ptr(offset + pf_offset),
                        depend=base_deps.get(reg) if 1 else None)
            elif lockstep:
                pseudo_nop()
    # Loop kernel
    vmovdqu(d, 'YMMWORD PTR %s' % ptr(offset), depend=base_deps.get(reg))
    vpaddb(e, d, c0)
    vpcmpgtb(e, e, c1)
    vpcmpeqb(d, d, c2)
    vpor(d, e, d)
    last_dep = vperm2i128(shifted, d, last, 0x03, depend=last_dep)
    vpalignr(shifted, d, shifted, 0xf)
    last = d
    base_deps[reg] = vpandn(mask, d, shifted)
    if use_sub:
        base_deps[reg] = vpsubb(subtotal, subtotal, mask)
    elif i > 0:
        base_deps[reg] = vpaddb(subtotal, subtotal, mask)
    elif lockstep:
        pseudo_nop()
    # Last of the group of four blocks: move the compare result into the register
    # the next block is expecting
    if i & 3 == 3:
        base_deps[reg] = vmovdqu(phi_last, last)
        last = phi_last
    elif lockstep:
        pseudo_nop()
    # Grab all current instructions for this block
    blocks.append(INSTS)
    INSTS = []
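# How the kernel above counts words: adding c0 (0x72) and doing a signed compare against
# c1 (0x7a) flags bytes 0x09-0x0d (tab through carriage return), vpcmpeqb against c2
# (0x20) flags spaces, and vpor merges them into a whitespace mask. vperm2i128+vpalignr
# shift that mask forward by one byte (pulling the carry-in byte from `last`), and vpandn
# keeps positions where the previous byte was whitespace and the current one isn't, i.e.
# word starts. The compare results are 0x00/0xff, so vpsubb bumps a per-byte counter by
# one for every word start.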
# Schedule blocks by interleaving their instructions
def schedule():
    while True:
        block_idx = [0] * len(blocks)
        for [b, block] in enumerate(blocks):
            if block:
                inst = block[0]
                if not inst or all(dep in INSTS for dep in inst.deps):
                    if inst:
                        INSTS.append(inst)
                    block.pop(0)
                    block_idx[b] += 1
                    if schedule_break(b, block_idx[b], len(INSTS)):
                        break
        if not any(blocks):
            break
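# Roughly: each pass of the while loop sweeps the blocks round-robin, popping one
# instruction from the front of any block whose dependencies have already been emitted
# (pseudo-nops just consume their slot), so the eight unrolled kernels end up interleaved
# instruction by instruction instead of back to back.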
# Prologue
INSTS = []
vzero(total)
Inst('cmp rdi,rsi')
Inst('jae L2')
vzero(zero)
vzero(last)
if use_index:
    mov(index, 1<<scale)
if use_next_base:
    lea(next_base, '[rdi+0x%x]' % next_base_offset)
# Single-byte broadcasted constants--emitted as .byte data and broadcast from memory
BYTE_CONSTS = {}
for [i, [y, c]] in enumerate([[c0, 0x72], [c1, 0x7a], [c2, 0x20]]):
    name = 'c%s' % i
    BYTE_CONSTS[name] = c
    vpbroadcastb(y, '[rip+%s]' % name)
Inst('.align 4')
Inst('L1:')
# Loop beginning: set up address registers, subtotal, prefetches
if use_sub:
    vzero(subtotal)
# Insert prefetches for each tier if they're not interleaved
for [pf_interleave, pf, pf_offset, pf_32] in pf_params:
    if not pf_interleave:
        for offset in range(0, 32*unroll, 32 if pf_32 else 64):
            pf('BYTE PTR %s' % ptr(offset + pf_offset))
# Schedule the main loop blocks
with capture() as loop_insts:
    schedule()
# Prefetch interleaving
prefetch_groups = []
for [pf_interleave, pf, pf_offset, pf_32] in pf_params:
    with capture() as prefetches:
        if pf_interleave == 2:
            for offset in range(0, 32*unroll, 32 if pf_32 else 64):
                pf('BYTE PTR %s' % ptr(offset + pf_offset))
    prefetch_groups.append(prefetches)
for prefetches in prefetch_groups:
    for [i, pf] in enumerate(prefetches):
        i = (i * len(loop_insts)) // len(prefetches)
        loop_insts[i:i] = [pf]
INSTS.extend(loop_insts)
# Loop end: horizontal sum of the subtotal from bytes to qwords, add into the total
if use_next_base:
    lea(next_base, '[%s+0x%x]' % (next_base, unroll*32))
if not use_sub:
    vpsubb(subtotal, zero, subtotal)
vpsadbw(subtotal, zero, subtotal)
vpaddq(total, subtotal, total)
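# vpsadbw against the zero register sums each group of eight unsigned bytes into a
# 64-bit lane, so the per-byte word-start counters (at most one per block per loop
# iteration, far below overflow) collapse into four qwords that vpaddq can accumulate
# into the running total.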
# Output assembly
with open(sys.argv[1], 'w') as f:
    f.write('''
    .intel_syntax noprefix
    .global _tokenize_zasm
_tokenize_zasm:
    push rbp
    mov rbp,rsp
''')
    for inst in INSTS:
        print(' ', inst, file=f)
    f.write('''
    lea rdi,[rdi+{offset}]
    cmp rdi,rsi
    jb L1
L2:
    vextracti128 xmm1,{total},0x1
    vpaddq {total},{total},ymm1
    vpshufd xmm1,xmm0,0x4e
    vpaddq {total},{total},ymm1
    vmovq rax,xmm0
    pop rbp
    vzeroupper
    ret
'''.format(offset=unroll*32, total=total))
    for [name, value] in BYTE_CONSTS.items():
        f.write('%s: .byte %s\n' % (name, value))
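# Rough usage sketch (assumed, not spelled out here): run the script with the output
# .s path as its only argument, assemble that file with gcc/as, and call _tokenize_zasm
# with the buffer start in rdi and the end pointer in rsi; the word count is returned
# in rax.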