Skip to content

Instantly share code, notes, and snippets.

@pawlos
Created February 13, 2020 15:30
Show Gist options
  • Save pawlos/ebf753484ff62c908bc3df60f50bae35 to your computer and use it in GitHub Desktop.
Save pawlos/ebf753484ff62c908bc3df60f50bae35 to your computer and use it in GitHub Desktop.
Solution for vv_max by emulating AVX operations with z3
from z3 import *
zero = 0

# 32 symbolic 256-bit registers r0..r31.  A register loaded from a byte list
# is packed with element 0 in the most significant 8 bits (see to_num).
# Every slot gets its own variable: previously slots 11..31 all held reg0, so
# any of those registers read before being assigned silently aliased r0.
regs = [BitVec('r%d' % n, 32*8) for n in range(32)]
(reg0, reg1, reg2, reg3, reg4, reg5, reg6, reg7,
 reg8, reg9, reg10, reg11, reg12, reg13, reg14, reg15,
 reg16, reg17, reg18, reg19, reg20, reg21, reg22, reg23,
 reg24, reg25, reg26, reg27, reg28, reg29, reg30, reg31) = regs
import re
def to_num(v):
    """Fold a sequence of byte values into one integer, first element most significant.

    ``to_num([0x12, 0x34])`` is ``0x1234``.  An empty sequence raises IndexError.
    """
    head, tail = v[0], v[1:]
    total = head
    for byte in tail:
        total = (total << 8) + byte
    return total
# vpermd
def perm(op1, op2):
    """Emulate vpermd: gather dwords of regs[op1] using the indices in regs[op2].

    Dword slots are numbered from the most significant end (slot 0 = bits
    255..224).  Result slot k takes the source slot named by the low three bits
    of the most significant byte of index slot k.
    """
    source = regs[op1]
    selector = regs[op2]
    lanes = []
    for slot in range(8):
        top = 256 - 32 * slot
        which = Extract(top - 6, top - 8, selector)
        # Innermost fallback; a 3-bit index always hits one of the 8 cases.
        picked = -1
        for k in range(8):
            picked = If(which == k,
                        Extract(256 - 32 * k - 1, 256 - 32 * (k + 1), source),
                        picked)
        lanes.append(simplify(picked))
    return simplify(Concat(*lanes))
# vpsrld
def shr(op1, const):
    """Emulate vpsrld: logical right shift of every dword of regs[op1] by const.

    Each 32-bit slot is byte-swapped into its numeric value (to_dword), shifted,
    then swapped back (from_dword).  A count above 31 yields zero.
    """
    source = regs[op1]
    shifted = []
    for slot in range(8):
        hi_bit = 256 - slot * 32 - 1
        lo_bit = 256 - (slot + 1) * 32
        lane = simplify(If(const > 31, BitVecVal(0, 32),
                           Extract(hi_bit, lo_bit, source)))
        value = to_dword(lane)
        shifted.append(from_dword(simplify(LShR(value, const))))
    return simplify(Concat(shifted))
# vpslld
def shl(op1, const):
    """Emulate vpslld: left shift of every dword of regs[op1] by const.

    Each 32-bit slot is byte-swapped into its numeric value (to_dword), shifted,
    then swapped back (from_dword).  A count above 31 yields zero.
    """
    source = regs[op1]
    out = []
    for slot in range(8):
        hi_bit = 256 - slot * 32 - 1
        lane = simplify(If(const > 31, BitVecVal(0, 32),
                           Extract(hi_bit, hi_bit - 31, source)))
        out.append(from_dword(simplify(to_dword(lane) << const)))
    return simplify(Concat(out))
# vpxor
def xor(op1, op2):
    """Emulate vpxor: bitwise XOR of regs[op1] and regs[op2], simplified."""
    left, right = regs[op1], regs[op2]
    return simplify(left ^ right)
# vpand
def _and(op1, op2):
    """Emulate vpand: bitwise AND of regs[op1] and regs[op2], simplified."""
    left, right = regs[op1], regs[op2]
    return simplify(left & right)
# vpor
def _or(op1, op2):
    """Emulate vpor: bitwise OR of regs[op1] and regs[op2], simplified."""
    left, right = regs[op1], regs[op2]
    return simplify(left | right)
# vpcmpeqb
def cmp(op1, op2):
    """Emulate vpcmpeqb: per-byte equality of regs[op1] and regs[op2].

    Each result byte is 0xFF where the two corresponding bytes are equal and
    0x00 otherwise; result byte positions match the operand byte positions.
    """
    a = regs[op1]
    b = regs[op2]
    result = []
    # Walk from the most significant byte down so result[0] is concatenated
    # back into the top byte.  The previous version walked up from bit 0 and
    # concatenated without reversing, which mirrored the register (result
    # byte k described input byte 31-k); add_bytes already reverses correctly.
    for j in range(31, -1, -1):
        lo = j * 8
        byte_a = simplify(Extract(lo + 7, lo, a))
        byte_b = simplify(Extract(lo + 7, lo, b))
        result.append(If(simplify(byte_a == byte_b),
                         BitVecVal(0xFF, 8), BitVecVal(0, 8)))
    return simplify(Concat(result))
def to_dword(v):
    """Reverse the four bytes of a 32-bit bit-vector and simplify the result.

    Callers use it to turn a dword as laid out in a register into the value
    they do arithmetic on.
    """
    swapped = [Extract(lo + 7, lo, v) for lo in (0, 8, 16, 24)]
    return simplify(Concat(swapped))
def from_dword(v):
    """Reverse the four bytes of a 32-bit bit-vector (not simplified).

    Inverse of to_dword: callers use it to put a computed dword back into the
    register byte layout.
    """
    b0 = Extract(7, 0, v)
    b1 = Extract(15, 8, v)
    b2 = Extract(23, 16, v)
    b3 = Extract(31, 24, v)
    return Concat(b0, b1, b2, b3)
# vpaddd
def add_dwords(op1, op2):
    """Emulate vpaddd: add the eight dwords of regs[op1] and regs[op2] modulo 2**32."""
    src1 = regs[op1]
    src2 = regs[op2]
    sums = []
    for k in range(8):
        hi, lo = (k + 1) * 32 - 1, k * 32
        left = to_dword(simplify(Extract(hi, lo, src1)))
        right = to_dword(simplify(Extract(hi, lo, src2)))
        sums.append(simplify(from_dword(left + right)))
    # sums[0] came from bits 31..0, so reverse to put it back at the bottom.
    return simplify(Concat(sums[::-1]))
# vpaddb
def add_bytes(op1, op2):
    """Emulate vpaddb: bytewise addition (mod 256) of regs[op1] and regs[op2]."""
    a = regs[op1]
    b = regs[op2]
    # Walk from the top byte down so the list is already in concatenation order.
    sums = [simplify(Extract(lo + 7, lo, a) + Extract(lo + 7, lo, b))
            for lo in range(248, -1, -8)]
    return simplify(Concat(sums))
def _shuffle_lane(a, b, off):
    """Compute the 16 result bytes of one 128-bit vpshufb lane.

    ``off`` is the bit just above the lane (256 for bits 255..128, 128 for
    bits 127..0).  Control byte j of ``b`` selects byte (control & 0x0F) of the
    same lane of ``a``; if bit 7 of the control byte is set the result is 0.
    """
    out = []
    for j in range(16):
        top = off - 8 * j  # control byte j occupies bits top-1 .. top-8
        idx = simplify(Extract(top - 5, top - 8, b))  # low 4 bits
        picked = BitVecVal(0, 8)  # fallback; a 4-bit index always matches
        for k in range(15, -1, -1):
            picked = If(idx == k,
                        Extract(off - 8 * k - 1, off - 8 * (k + 1), a),
                        picked)
        # Zero when bit 7 of the control byte is set.  The previous version
        # tested "control == 0x0F", which zeroed a valid index 15 and let
        # 0x80..0xFF select a byte instead of producing 0.
        zeroing = simplify(Extract(top - 1, top - 1, b)) == 1
        out.append(simplify(If(zeroing, BitVecVal(0, 8), simplify(picked))))
    return out


# vpshufb
def shuff(op1, op2):
    """Emulate vpshufb: in-lane byte shuffle of regs[op1] driven by regs[op2].

    Each 128-bit half is shuffled independently; a byte never moves between
    halves.
    """
    a = regs[op1]
    b = regs[op2]
    res = _shuffle_lane(a, b, 256) + _shuffle_lane(a, b, 128)
    return Concat(res)
# vpmaddubsw
def mul_add8(op1, op2):
    """Emulate vpmaddubsw: unsigned bytes of regs[op1] times signed bytes of regs[op2].

    For every 16-bit word slot, the two byte products are added and the sum is
    saturated to the signed 16-bit range.  The word is stored back low byte
    first (to_16bit), matching the register layout used by the other helpers.

    NOTE(review): assumes the dump's "rD = rA mul_add8 rB" follows Intel operand
    order (rA unsigned, rB signed) -- confirm against the VM if it matters.
    The previous version zero-extended both operands to 16 bits and wrapped
    instead of saturating.
    """
    res = []
    a = regs[op1]
    b = regs[op2]
    for j in range(16):
        base = 256 - (j + 1) * 16  # lowest bit of word slot j
        # Widen to 32 bits: the exact sum needs 18 bits, so nothing wraps here.
        a_first = ZeroExt(24, Extract(base + 15, base + 8, a))
        b_first = SignExt(24, Extract(base + 15, base + 8, b))
        a_second = ZeroExt(24, Extract(base + 7, base, a))
        b_second = SignExt(24, Extract(base + 7, base, b))
        total = simplify(a_first * b_first + a_second * b_second)
        # z3's < and > on bit-vectors are signed comparisons.
        saturated = If(total > 32767, BitVecVal(32767, 32),
                       If(total < -32768, BitVecVal(-32768, 32), total))
        res.append(to_16bit(simplify(Extract(15, 0, saturated))))
    return simplify(Concat(res))
def to_16bit(v):
    """Swap the two low bytes of a bit-vector into a 16-bit value and simplify."""
    low, high = Extract(7, 0, v), Extract(15, 8, v)
    return simplify(Concat(low, high))
def to_32bit(v):
    """Byte-swap a 32-bit value; thin alias for to_dword."""
    swapped = to_dword(v)
    return swapped
def from_16bit(v):
    """Swap the two bytes of a 16-bit bit-vector and simplify (same operation as to_16bit)."""
    swapped = Concat(Extract(7, 0, v), Extract(15, 8, v))
    return simplify(swapped)
# vpmaddwd
def mul_add16(op1, op2):
    """Emulate vpmaddwd: multiply signed 16-bit words pairwise and add adjacent products.

    Every dword slot of the result is word0_a*word0_b + word1_a*word1_b computed
    in 32-bit arithmetic.  The only overflow the instruction can hit (all four
    words 0x8000) wraps to 0x80000000, which is exactly what 32-bit bit-vector
    arithmetic does.  Words are decoded with from_16bit and the dword is stored
    back with to_32bit.
    """
    res = []
    a = regs[op1]
    b = regs[op2]
    for j in range(8):
        base = 256 - (j + 1) * 32  # lowest bit of dword slot j
        # Words are signed: sign-extend (the old ZeroExt treated them as unsigned).
        a_first = SignExt(16, from_16bit(Extract(base + 31, base + 16, a)))
        b_first = SignExt(16, from_16bit(Extract(base + 31, base + 16, b)))
        a_second = SignExt(16, from_16bit(Extract(base + 15, base, a)))
        b_second = SignExt(16, from_16bit(Extract(base + 15, base, b)))
        total = a_first * b_first + a_second * b_second
        res.append(to_32bit(simplify(total)))
    return simplify(Concat(res))
# Global z3 solver.  The script adds its constraints after the program has
# been symbolically executed, then checks them and prints a model.
s = Solver()
def write_before(txt, r1, r2, a):
    """Debug helper: when predicate ``a()`` is true, print a label and two registers."""
    if not a():
        return
    for item in (txt, regs[r1], regs[r2]):
        print (item)
def write_after(txt, r, a):
    """Debug helper: when predicate ``a()`` is true, print a label and one register, then exit(-1).

    NOTE(review): the source's indentation was lost; this assumes sys.exit
    belongs to the predicate branch (print-and-stop).
    """
    if not a():
        return
    print (txt)
    print(regs[r])
    sys.exit(-1)
import sys
import ast

# Debug flag, true right after a line containing 'r2 ='; nothing in this
# script reads it.
printArgs = False

# "rD = rA <op> rB" (or "rD = rA >> N" / "rD = rA << N") instructions, in the
# order they are tried.  Each handler is called as handler(A, B_or_N) and its
# result is stored in register D.
_INSTRUCTIONS = [
    (r'^r(\d{1,2}) = r(\d{1,2}) perm r(\d{1,2})$', perm),
    (r'^r(\d{1,2}) = r(\d{1,2}) >> (\d+)$', shr),
    (r'^r(\d{1,2}) = r(\d{1,2}) \^ r(\d{1,2})$', xor),
    (r'^r(\d{1,2}) = r(\d{1,2}) & r(\d{1,2})$', _and),
    (r'^r(\d{1,2}) = r(\d{1,2}) << (\d+)$', shl),
    (r'^r(\d{1,2}) = r(\d{1,2}) \| r(\d{1,2})$', _or),
    (r'^r(\d{1,2}) = r(\d{1,2}) == r(\d{1,2});.*$', cmp),
    (r'^r(\d{1,2}) = r(\d{1,2}) \+ r(\d{1,2}) ;dwords.*$', add_dwords),
    (r'^r(\d{1,2}) = r(\d{1,2}) \+ r(\d{1,2}) ;bytes.*$', add_bytes),
    (r'^r(\d{1,2}) = r(\d{1,2}) shuff r(\d{1,2})$', shuff),
    (r'^r(\d{1,2}) = r(\d{1,2}) mul_add8 r(\d{1,2})$', mul_add8),
    (r'^r(\d{1,2}) = r(\d{1,2}) mul_add16 r(\d{1,2})$', mul_add16),
]


def _execute_line(line):
    """Apply one non-blank program line to ``regs``; return False if it is not recognised."""
    m = re.match(r'^r(\d{1,2}) = (\[.+\])$', line)
    if m:
        r = int(m.group(1))
        # The program file is external input: parse the byte list with
        # literal_eval (accepts hex ints) instead of eval, which runs any code.
        v = ast.literal_eval(m.group(2).strip())
        if r != 1:
            # r1 keeps its symbolic value; the final constraints are on regs[1].
            regs[r] = BitVecVal(to_num(v), 32*8)
        return True
    m = re.match(r'^r(\d{1,2}) = r(\d{1,2})$', line)
    if m:
        print ('Should never match!')
        regs[int(m.group(1))] = regs[int(m.group(2))]
        return True
    for pattern, handler in _INSTRUCTIONS:
        m = re.match(pattern, line)
        if m:
            dst, src, arg = (int(g) for g in m.groups())
            regs[dst] = handler(src, arg)
            return True
    return False


if len(sys.argv) < 2:
    print('usage: ' + sys.argv[0] + ' <program-file>')
    sys.exit(-1)
fileName = sys.argv[1]
print ('Opening file: '+fileName)
with open(fileName) as f:
    for raw in f:
        line = raw.strip()
        printArgs = 'r2 =' in line
        if line == '':
            # Parsing stops at the first blank line (or at end of file).
            break
        if not _execute_line(line):
            print('Unrecognized line: '+line)
            sys.exit(-1)
        # Every emulated instruction must leave 256-bit registers behind.
        if not all(x.size() == 256 for x in regs):
            print ('something not right!')
            sys.exit(-1)
# Required final contents of r2; the last 8 bytes are zero.
x = [0x70,0x70,0xB2,0xAC,0x01,0xD2,0x5E,0x61,0x0A,0xA7,0x2A,0xA8,0x08,0x1C,0x86,0x1A,0xE8,0x45,0xC8,0x29,0xB2,0xF3,0xA1,0x1E,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00]
x = BitVecVal(to_num(x), 32*8)
s.add(x == regs[2])
print (simplify(regs[2]))
for i in range(32):
    c = Extract(i*8+7, i*8, regs[1])
    # Each input byte in 0x31..0x7a (explicitly unsigned comparisons; z3's
    # < / > on bit-vectors are signed, which gave the same range here).
    s.add(And(UGT(c, 0x30), ULE(c, 0x7a)))
for i in range(32):
    c = Extract(i*8+7, i*8, regs[1]) ^ Extract(i*8+7, i*8, regs[31])
    # r1 ^ r31 must be a digit, letter, '_' or NUL byte.
    s.add(Or(And(UGE(c, 0x30), ULE(c, 0x39)),
             And(UGE(c, 0x41), ULE(c, 0x5a)),
             And(UGE(c, 0x61), ULE(c, 0x7a)),
             c == 0x5f,
             c == 0x0))
result = s.check()
print (result)
if result == sat:
    r = s.model()
    print (r)
else:
    # s.model() raises when the constraints are not satisfiable.
    print('no model: solver returned ' + str(result))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment