Created
February 13, 2020 15:30
-
-
Save pawlos/ebf753484ff62c908bc3df60f50bae35 to your computer and use it in GitHub Desktop.
Solution for vv_max: emulating AVX operations with z3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from z3 import *
import re

# Not referenced anywhere else in the visible script; kept so the module-level
# name still exists.
zero = 0

# 32 symbolic 256-bit registers named r0..r31.  Each register gets its own
# BitVec: previously indices 11..31 of `regs` all pointed at reg0, so reading
# an unassigned high register silently aliased r0 and forced those registers
# to be equal to one another.
regs = [BitVec('r%d' % i, 32 * 8) for i in range(32)]

# Keep the individual regN names available, bound to the same objects as
# the initial entries of `regs`.
(reg0, reg1, reg2, reg3, reg4, reg5, reg6, reg7,
 reg8, reg9, reg10, reg11, reg12, reg13, reg14, reg15,
 reg16, reg17, reg18, reg19, reg20, reg21, reg22, reg23,
 reg24, reg25, reg26, reg27, reg28, reg29, reg30, reg31) = regs
def to_num(v):
    """Fold a sequence of byte values into one big-endian integer.

    The first element ends up in the most-significant position.  Raises
    IndexError when `v` is empty.
    """
    total = v[0]
    for byte in v[1:]:
        total <<= 8
        total += byte
    return total
# vpermd
def perm(op1, op2):
    """Model vpermd: permute the 32-bit elements of regs[op1] by regs[op2].

    Elements are counted from the most-significant end of the 256-bit vector.
    Result element j is element `sel` of regs[op1].  `sel` is the low three
    bits of the most-significant byte of element j of regs[op2].
    """
    src = regs[op1]
    ctrl = regs[op2]
    picked = []
    for j in range(8):
        top = 256 - 32 * j
        sel = Extract(top - 6, top - 8, ctrl)
        # Build the selector chain from the innermost case (sel == 0) outward.
        # The -1 fallback can never be reached because sel is only 3 bits wide.
        choice = -1
        for k in range(8):
            hi = 255 - 32 * k
            choice = If(sel == k, Extract(hi, hi - 31, src), choice)
        picked.append(simplify(choice))
    return simplify(Concat(picked))
# vpsrld
def shr(op1, const):
    """Model vpsrld: logically shift each 32-bit lane of regs[op1] right.

    Each lane is byte-swapped into its integer value, shifted right by `const`
    bits with LShR, then swapped back.  A count above 31 clears every lane.
    """
    src = regs[op1]

    def swap(word):
        # Reverse the four bytes of a 32-bit value.
        return Concat(Extract(7, 0, word), Extract(15, 8, word),
                      Extract(23, 16, word), Extract(31, 24, word))

    lanes = []
    for j in range(8):
        msb = 255 - 32 * j
        lane = simplify(If(const > 31, BitVecVal(0, 32),
                           Extract(msb, msb - 31, src)))
        shifted = simplify(LShR(simplify(swap(lane)), const))
        lanes.append(swap(shifted))
    return simplify(Concat(lanes))
# vpslld
def shl(op1, const):
    """Model vpslld: shift each 32-bit lane of regs[op1] left.

    Each lane is byte-swapped into its integer value, shifted left by `const`
    bits, then swapped back.  A count above 31 clears every lane.
    """
    src = regs[op1]

    def swap(word):
        # Reverse the four bytes of a 32-bit value.
        return Concat(Extract(7, 0, word), Extract(15, 8, word),
                      Extract(23, 16, word), Extract(31, 24, word))

    lanes = []
    for j in range(8):
        msb = 255 - 32 * j
        lane = simplify(If(const > 31, BitVecVal(0, 32),
                           Extract(msb, msb - 31, src)))
        shifted = simplify(simplify(swap(lane)) << const)
        lanes.append(swap(shifted))
    return simplify(Concat(lanes))
# vpxor
def xor(op1, op2):
    """Return the simplified bitwise XOR of registers op1 and op2."""
    left, right = regs[op1], regs[op2]
    return simplify(left ^ right)
# vpand
def _and(op1, op2):
    """Return the simplified bitwise AND of registers op1 and op2."""
    left, right = regs[op1], regs[op2]
    return simplify(left & right)
# vpor
def _or(op1, op2):
    """Return the simplified bitwise OR of registers op1 and op2."""
    left, right = regs[op1], regs[op2]
    return simplify(left | right)
# vpcmpeqb
def cmp(op1, op2):
    """Model vpcmpeqb: byte-wise equality mask of regs[op1] and regs[op2].

    Each result byte is 0xFF where the two input bytes at that position are
    equal and 0x00 otherwise.  The mask byte stays at the same position as the
    bytes it compares.
    """
    a = regs[op1]
    b = regs[op2]
    mask = []
    # Bytes are visited from the least-significant end of the bit-vector.
    for j in range(32):
        lo = j * 8
        byte_a = simplify(Extract(lo + 7, lo, a))
        byte_b = simplify(Extract(lo + 7, lo, b))
        mask.append(If(simplify(byte_a == byte_b),
                       BitVecVal(0xFF, 8), BitVecVal(0, 8)))
    # Concat() puts its first argument in the most-significant position, so
    # the LSB-first list must be reversed (as add_bytes does).  Without the
    # reversal the mask comes out byte-mirrored relative to its inputs.
    return simplify(Concat(mask[::-1]))
def to_dword(v):
    """Return the 32-bit value v with its byte order reversed (simplified)."""
    return simplify(Concat(*[Extract(8 * k + 7, 8 * k, v) for k in range(4)]))
def from_dword(v):
    """Reverse the byte order of the 32-bit value v (unsimplified)."""
    return Concat(*[Extract(8 * k + 7, 8 * k, v) for k in range(4)])
# vpaddd
def add_dwords(op1, op2):
    """Model vpaddd: lane-wise 32-bit addition of regs[op1] and regs[op2].

    Each lane is byte-swapped with to_dword before the addition and swapped
    back with from_dword.  Overflow wraps, as with 32-bit bit-vector addition.
    """
    def lanes(vec):
        # Lanes listed from the least-significant end of the vector.
        return [to_dword(simplify(Extract(32 * k + 31, 32 * k, vec)))
                for k in range(8)]

    lhs = lanes(regs[op1])
    rhs = lanes(regs[op2])
    sums = [simplify(from_dword(x + y)) for x, y in zip(lhs, rhs)]
    # Reverse so the most-significant lane is the first Concat argument.
    return simplify(Concat(sums[::-1]))
# vpaddb
def add_bytes(op1, op2):
    """Model vpaddb: byte-wise wrapping addition of regs[op1] and regs[op2]."""
    a = regs[op1]
    b = regs[op2]
    sums = [simplify(Extract(8 * k + 7, 8 * k, a) + Extract(8 * k + 7, 8 * k, b))
            for k in range(32)]
    # sums is LSB-first; Concat wants the most-significant byte first.
    sums.reverse()
    return simplify(Concat(sums))
# vpshufb
def shuff(op1, op2):
    """Model vpshufb: byte shuffle of regs[op1] controlled by regs[op2].

    The 32 bytes (counted from the most-significant end) form two independent
    16-byte halves.  For every control byte in regs[op2]:
    * if its top bit (bit 7) is set, the result byte is 0;
    * otherwise the result byte is the byte of regs[op1] in the same half
      whose position is the low 4 bits of the control byte.

    The previous version zeroed only when the whole control byte equalled
    0x0F.  That wrongly cleared index 15 and let control bytes like 0xFF pick
    a byte instead of producing 0.
    """
    a = regs[op1]
    b = regs[op2]

    def byte_at(vec, pos):
        # Byte `pos` counted from the most-significant end of a 256-bit vector.
        msb = 255 - 8 * pos
        return Extract(msb, msb - 7, vec)

    res = []
    for base in (0, 16):            # the two 16-byte halves
        for j in range(16):
            ctrl_msb = 255 - 8 * (base + j)
            zero_bit = Extract(ctrl_msb, ctrl_msb, b)        # bit 7
            idx = Extract(ctrl_msb - 4, ctrl_msb - 7, b)     # low nibble
            # Build the 16-way selector from the innermost case (idx == 15)
            # outward; the 0 fallback is unreachable since idx is 4 bits.
            picked = BitVecVal(0, 8)
            for k in reversed(range(16)):
                picked = If(idx == k, byte_at(a, base + k), picked)
            res.append(simplify(If(zero_bit == 1, BitVecVal(0, 8), picked)))
    return Concat(res)
def mul_add8(op1, op2):
    """Pairwise byte multiply-add of regs[op1] and regs[op2] into 16 words.

    For each 16-bit chunk (counted from the most-significant end), both bytes
    of each operand are zero-extended to 16 bits.  The two byte products are
    added, and the 16-bit sum is byte-swapped with to_16bit before being
    placed in the result.

    NOTE(review): both operands are treated as unsigned and the sum wraps
    instead of saturating.  If this is meant to model vpmaddubsw (unsigned x
    signed bytes, signed saturation), it only agrees for small non-negative
    inputs — confirm against the trace.
    """
    a = regs[op1]
    b = regs[op2]
    words = []
    for j in range(16):
        top = 255 - 16 * j          # MSB of the j-th 16-bit chunk
        a_first = simplify(ZeroExt(8, Extract(top, top - 7, a)))
        b_first = simplify(ZeroExt(8, Extract(top, top - 7, b)))
        a_second = simplify(ZeroExt(8, Extract(top - 8, top - 15, a)))
        b_second = simplify(ZeroExt(8, Extract(top - 8, top - 15, b)))
        total = a_first * b_first + a_second * b_second
        words.append(to_16bit(simplify(total)))
    return simplify(Concat(words))
def to_16bit(v):
    """Return the 16-bit value v with its two bytes swapped (simplified)."""
    low_byte = Extract(7, 0, v)
    high_byte = Extract(15, 8, v)
    return simplify(Concat(low_byte, high_byte))
def to_32bit(v):
    """Return the 32-bit value v with its byte order reversed (same as to_dword)."""
    return simplify(Concat(*[Extract(8 * k + 7, 8 * k, v) for k in range(4)]))
def from_16bit(v):
    """Swap the two bytes of the 16-bit value v; a swap is its own inverse."""
    return simplify(Concat(*[Extract(8 * k + 7, 8 * k, v) for k in range(2)]))
def mul_add16(op1, op2):
    """Pairwise 16-bit multiply-add of regs[op1] and regs[op2] into 8 dwords.

    For each 32-bit chunk (counted from the most-significant end), both 16-bit
    halves of each operand are byte-swapped with from_16bit and zero-extended
    to 32 bits.  The two products are added, and the sum is byte-swapped with
    to_32bit before being placed in the result.

    NOTE(review): operands are treated as unsigned.  If this is meant to model
    vpmaddwd (signed words), it only agrees for non-negative inputs — confirm.
    """
    a = regs[op1]
    b = regs[op2]
    out = []
    for j in range(8):
        top = 255 - 32 * j          # MSB of the j-th 32-bit chunk
        a_first = simplify(ZeroExt(16, from_16bit(Extract(top, top - 15, a))))
        b_first = simplify(ZeroExt(16, from_16bit(Extract(top, top - 15, b))))
        a_second = simplify(ZeroExt(16, from_16bit(Extract(top - 16, top - 31, a))))
        b_second = simplify(ZeroExt(16, from_16bit(Extract(top - 16, top - 31, b))))
        total = a_first * b_first + a_second * b_second
        out.append(to_32bit(simplify(total)))
    return simplify(Concat(out))
# Solver that receives the final constraints once the trace has been replayed.
s = Solver()


def write_before(txt, r1, r2, a):
    """Debug helper: when the predicate a() is true, print txt and two registers."""
    if not a():
        return
    print(txt)
    print(regs[r1])
    print(regs[r2])


def write_after(txt, r, a):
    """Debug helper: when a() is true, print txt and register r, then exit(-1).

    NOTE(review): the original indentation was lost.  The exit is assumed to
    belong to the debug branch, not to run unconditionally.
    """
    if not a():
        return
    print(txt)
    print(regs[r])
    sys.exit(-1)
import sys

printArgs = False

# Give a usage message instead of an IndexError when no trace file is passed.
if len(sys.argv) < 2:
    print('Usage: ' + sys.argv[0] + ' <trace-file>')
    sys.exit(-1)
fileName = sys.argv[1]
print('Opening file: ' + fileName)

_CONST_RE = re.compile(r'^r(\d{1,2}) = (\[.+\])$')
_COPY_RE = re.compile(r'^r(\d{1,2}) = r(\d{1,2})$')

# Every operation line has the shape "rD = rS <op> X", e.g. "r3 = r1 ^ r2".
# X is a register number, or an immediate bit count for the shifts.
# Each pattern maps to the function that models that instruction.
_OPERATIONS = [
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) perm r(\d{1,2})$'), perm),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) >> (\d+)$'), shr),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) \^ r(\d{1,2})$'), xor),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) & r(\d{1,2})$'), _and),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) << (\d+)$'), shl),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) \| r(\d{1,2})$'), _or),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) == r(\d{1,2});.*$'), cmp),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) \+ r(\d{1,2}) ;dwords.*$'), add_dwords),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) \+ r(\d{1,2}) ;bytes.*$'), add_bytes),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) shuff r(\d{1,2})$'), shuff),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) mul_add8 r(\d{1,2})$'), mul_add8),
    (re.compile(r'^r(\d{1,2}) = r(\d{1,2}) mul_add16 r(\d{1,2})$'), mul_add16),
]

with open(fileName) as f:
    while True:
        line = f.readline().strip()
        # printArgs is toggled by lines mentioning 'r2 =', but nothing in this
        # loop reads it beyond resetting it.
        if printArgs:
            printArgs = False
        if 'r2 =' in line:
            printArgs = True
        # Sanity check: every register must still be a 256-bit vector.
        if not all(x.size() == 256 for x in regs):
            print('something not right!')
            sys.exit(-1)
        # readline() returns '' at end of file.  Note that a blank line inside
        # the trace also ends processing, since it strips to ''.
        if line == '':
            break

        m = _CONST_RE.match(line)
        if m:
            r = int(m.group(1))
            # SECURITY: eval() executes arbitrary Python taken from the trace
            # file.  Only run this on trusted traces; ast.literal_eval would be
            # the safe choice for plain list literals.
            v = eval(m.group(2).strip())
            # r1 is left symbolic: it is the unknown the solver searches for.
            if r != 1:
                p = to_num(v)
                regs[r] = BitVecVal(p, 32 * 8)
            continue

        m = _COPY_RE.match(line)
        if m:
            print('Should never match!')
            r1 = int(m.group(1))
            r2 = int(m.group(2))
            regs[r1] = regs[r2]
            continue

        for pattern, operation in _OPERATIONS:
            m = pattern.match(line)
            if m:
                dest = int(m.group(1))
                regs[dest] = operation(int(m.group(2)), int(m.group(3)))
                break
        else:
            print('Unrecognized line: ' + line)
            sys.exit(-1)
# Required final contents of r2 after the trace has run; the last 8 bytes are zero.
x = [0x70, 0x70, 0xB2, 0xAC, 0x01, 0xD2, 0x5E, 0x61,
     0x0A, 0xA7, 0x2A, 0xA8, 0x08, 0x1C, 0x86, 0x1A,
     0xE8, 0x45, 0xC8, 0x29, 0xB2, 0xF3, 0xA1, 0x1E,
     0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
x = BitVecVal(to_num(x), 32 * 8)
s.add(x == regs[2])
print(simplify(regs[2]))

# Every byte of r1 must lie in 0x31..0x7a.  z3's `>`/`<=` on bit-vectors are
# signed, but bytes >= 0x80 are negative and so already fail `> 0x30`.
for i in range(32):
    c = Extract(i * 8 + 7, i * 8, regs[1])
    s.add(And(c > 0x30, c <= 0x7a))

# Every byte of r1 ^ r31 must be [0-9A-Za-z_] or NUL.
for i in range(32):
    c = Extract(i * 8 + 7, i * 8, regs[1]) ^ Extract(i * 8 + 7, i * 8, regs[31])
    s.add(Or(And(c >= 0x30, c <= 0x39),
             And(c >= 0x41, c <= 0x5a),
             And(c >= 0x61, c <= 0x7a),
             c == 0x5f,
             c == 0x0))

result = s.check()
print(result)
# s.model() raises when the result is unsat/unknown, so stop cleanly instead.
if result != sat:
    sys.exit(-1)
r = s.model()
print(r)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment