Created
October 24, 2020 13:26
-
-
Save bwesterb/6d7b584c26cc1b716cd3cb18e70b7cde to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cairo | |
import math | |
import sys | |
class State:
    """Render one butterfly network to ``<name>.svg`` while tracking the data layout.

    The picture has one horizontal grid line per slot (256 rows). Steps are
    appended left to right: butterfly layers are vertical lines ending in dots
    and labelled with zeta indices; swap steps are vertical lines ending in
    crosses, followed by the new slot contents written in binary.

    ``self.idxs[slot]`` is the starting index of the value currently held in
    ``slot``: swap steps permute it, and every butterfly layer is checked
    against it.
    """

    def __init__(self, name, groupSize=1024):
        """Open ``<name>.svg`` for drawing.

        :param name: output file base name, also used in error messages.
        :param groupSize: every ``groupSize``-th row is drawn in red
            (presumably the number of coefficients per vector register of the
            implementation being drawn); the default 1024 highlights only row 0.
        """
        self.surface = cairo.SVGSurface(name + ".svg", 1040, 1040)
        self.name = name
        self.ctx = cairo.Context(self.surface)
        # 4 device units per row plus a small margin, so all 256 rows fit.
        self.ctx.scale(4, 4)
        self.ctx.translate(2, 2)
        self.x = 0.0  # left edge of the next step, in user units
        self.groupSize = groupSize
        self.idxs = list(range(256))

    def finish(self):
        """Finish the SVG surface so the output file is completely written."""
        self.surface.finish()

    @staticmethod
    def _bin8(value):
        """Format ``value`` as an 8-digit binary string, e.g. ``5 -> '00000101'``."""
        return format(value, "08b")

    @staticmethod
    def _zetaIndex(coeff, level):
        """Return the zeta label drawn for a level-``level`` butterfly whose top slot holds ``coeff``."""
        return (coeff >> (9 - level)) + (1 << (level - 1))

    @staticmethod
    def _shufflePairs(idx, pairs):
        """Return the slot swaps performed by a shuffle at granularity ``2**idx``.

        Slots are grouped in blocks of 16. For each ``(a, b)`` in ``pairs``,
        the upper half of every ``2 * 2**idx``-slot chunk of block ``a`` is
        exchanged with the lower half of the matching chunk of block ``b``.
        Every returned pair is ordered ``(low, high)``.

        :raises ValueError: if ``idx`` is not 0, 1, 2 or 3.
        """
        if idx not in (0, 1, 2, 3):
            raise ValueError("shuffle idx must be 0, 1, 2 or 3, got %r" % (idx,))
        half = 1 << idx
        swapped = []
        for a, b in pairs:
            for chunk in range(0, 16, 2 * half):
                for i in range(half):
                    low, high = sorted((16 * a + half + chunk + i, 16 * b + chunk + i))
                    swapped.append((low, high))
        return swapped

    @staticmethod
    def _bitflipPairs(idx):
        """Return the slot swaps that exchange bit 4 and bit ``idx`` of the slot number.

        Each swap is listed once as ``(i, j)`` with ``i < j``; ``idx == 4``
        yields no swaps.

        :raises ValueError: if ``idx`` is outside ``0..7``.
        """
        if idx not in range(8):
            raise ValueError("bitflip idx must be in 0..7, got %r" % (idx,))
        mask = (1 << 4) | (1 << idx)
        pairs = []
        for i in range(256):
            j = (i & ~mask) | (((i >> idx) & 1) << 4) | (((i >> 4) & 1) << idx)
            if i < j:
                pairs.append((i, j))
        return pairs

    @staticmethod
    def _stackIntervals(bfs):
        """Greedily give each pair a column so that overlapping pairs never share one.

        :returns: ``(placements, depth)``: ``placements`` lists
            ``(column, low, high)`` in input order, ``depth`` is the number of
            columns used (0 for no pairs).
        """
        occupied = [0] * 256
        placements = []
        for a, b in bfs:
            low, high = min(a, b), max(a, b)
            span = range(low, high + 1)
            column = max(occupied[v] for v in span)
            for v in span:
                occupied[v] = column + 1
            placements.append((column, low, high))
        return placements, max(occupied)

    def _applySwaps(self, bfs):
        """Swap the entries of ``self.idxs`` for every pair, in order (a slot may move twice)."""
        for i, j in bfs:
            self.idxs[i], self.idxs[j] = self.idxs[j], self.idxs[i]

    def _checkButterflies(self, bfs, level):
        """Exit with an error unless every pair in ``bfs`` is a valid level-``level`` butterfly.

        For each ``(a, b)`` the value in slot ``b`` must be the value in slot
        ``a`` with bit ``8 - level`` set.

        :raises SystemExit: with the error message (non-zero exit status) on
            the first invalid pair.
        """
        bit = 1 << (8 - level)
        for a, b in bfs:
            a2 = self.idxs[a]
            b2 = self.idxs[b]
            if b2 != (a2 & (255 ^ bit)) | bit:
                # A string argument makes sys.exit print it to stderr and exit 1,
                # so a wrong layout is not reported as success.
                sys.exit("%s: Wrong butterfly on level %s: CT(%s, %s)" % (
                    self.name, level, self._bin8(a2), self._bin8(b2)))

    def shuffle(self, idx, pairs):
        """Draw and apply the block shuffle described by ``_shufflePairs(idx, pairs)``."""
        self.swaps(self._shufflePairs(idx, pairs))

    def bitflip(self, idx):
        """Draw and apply the swaps that exchange bit 4 and bit ``idx`` of each slot number."""
        self.swaps(self._bitflipPairs(idx))

    def butterflyX4(self, pairs, level):
        """Draw butterflies between every slot of 16-slot block ``a`` and the same slot of block ``b``."""
        self.butterflies(
            [(16 * a + i, 16 * b + i) for a, b in pairs for i in range(16)], level)

    def butterflies(self, bfs, level):
        """Draw the heading ``level <level>`` and one butterfly layer for ``bfs``."""
        self.ctx.set_font_size(1)
        self.ctx.set_source_rgb(0, 0, 0)
        self.ctx.move_to(self.x, -0.5)
        self.ctx.show_text("level %s" % level)
        self.drawLines(bfs, False, level)

    def swaps(self, bfs):
        """Draw the swaps in ``bfs``, apply them to ``idxs``, then print the new layout."""
        self.drawLines(bfs, True)
        self._applySwaps(bfs)
        self.drawGridLines(self.x, self.x + 2.0)
        # Step back so the binary labels sit over the grid just drawn.
        self.x -= 1
        self.ctx.set_font_size(.5)
        self.ctx.set_source_rgb(.5, .5, 1.)
        for i in range(256):
            self.ctx.move_to(self.x, i - .15)
            self.ctx.show_text(self._bin8(self.idxs[i]))
        self.x += 2.5

    def drawGridLines(self, oldX, newX):
        """Draw the 256 row lines from ``oldX`` to ``newX``; every ``groupSize``-th row is red."""
        for i in range(256):
            if i % self.groupSize == 0:
                self.ctx.set_source_rgb(.7, 0, 0)
            else:
                self.ctx.set_source_rgb(.7, .7, .7)
            self.ctx.move_to(oldX, i)
            self.ctx.line_to(newX, i)
            # Stroke per row: the colour is chosen per row.
            self.ctx.stroke()

    def _drawCross(self, x, y):
        """Draw a small X centred on ``(x, y)`` to mark a swap endpoint."""
        for dx in (0.3, -0.3):
            self.ctx.move_to(x - dx, y - 0.3)
            self.ctx.line_to(x + dx, y + 0.3)
            self.ctx.stroke()

    def _drawZetaLabels(self, actions, level):
        """Write the zeta index next to each butterfly, walking top to bottom.

        A label is drawn only when the index differs from the previous one.
        """
        self.ctx.set_font_size(0.5)
        prevZeta = None
        for x, a, b in sorted(actions, key=lambda action: action[1]):
            zeta = self._zetaIndex(self.idxs[a], level)
            if zeta == prevZeta:
                continue
            prevZeta = zeta
            self.ctx.move_to(x - 0.25, (a + b) / 2 + 0.3)
            self.ctx.text_path(str(zeta))
            # White outline under a black fill keeps the number readable over lines.
            self.ctx.set_source_rgb(1, 1, 1)
            self.ctx.set_line_width(0.2)
            self.ctx.stroke_preserve()
            self.ctx.set_source_rgb(0, 0, 0)
            self.ctx.set_line_width(0.1)
            self.ctx.fill()

    def drawLines(self, bfs, swap=False, level=None):
        """Draw one step: a vertical connection for every pair in ``bfs``.

        :param bfs: pairs of slot numbers to connect.
        :param swap: mark endpoints with crosses (a swap step) instead of dots
            (a butterfly layer).
        :param level: butterfly level; required when ``swap`` is false, used
            for the zeta labels and the layout check.
        :raises ValueError: if ``swap`` is false and ``level`` is None.
        """
        if not swap and level is None:
            raise ValueError("drawLines: level is required when swap is False")
        self.ctx.set_line_width(0.1)
        placements, depth = self._stackIntervals(bfs)
        # Columns are half a unit apart, followed by one unit of padding.
        actions = [(self.x + column / 2, a, b) for column, a, b in placements]
        oldX = self.x
        self.x += depth / 2 + 1
        self.drawGridLines(oldX, self.x)
        self.ctx.set_source_rgb(0, 0, 0)
        for x, a, b in actions:
            self.ctx.move_to(x, a)
            self.ctx.line_to(x, b)
            self.ctx.stroke()
            for y in (a, b):
                if swap:
                    self._drawCross(x, y)
                else:
                    self.ctx.arc(x, y, 0.2, 0, 2 * math.pi)
                    self.ctx.fill()
        if not swap:
            self._drawZetaLabels(actions, level)
            self._checkButterflies(bfs, level)
def ref():
    """Render ``ref.svg``: eight butterfly levels with halving pair distance.

    Level ``k`` pairs slot ``j`` with ``j + (256 >> k)`` inside every chunk of
    ``2 * (256 >> k)`` slots, i.e. distances 128, 64, ..., 1. No swap steps
    are drawn.
    """
    diagram = State("ref")
    for level in range(1, 9):
        dist = 256 >> level
        pairs = [
            (j, j + dist)
            for chunk in range(0, 256 - dist, 2 * dist)
            for j in range(chunk, chunk + dist)
        ]
        diagram.butterflies(pairs, level)
    diagram.finish()
def dilavx2():
    """Render ``dilavx2.svg``: eight butterfly levels with 8-row grid groups.

    Levels 1-6 pair slot ``j`` with ``j + (256 >> level)`` inside every chunk
    of twice that size. Before level 7, before level 8 and once more at the
    end, the slots inside each group of eight are permuted; levels 7 and 8
    always pair offset ``k`` with offset ``k + 4`` of each group.
    """
    s = State("dilavx2", 8)
    for level in range(1, 7):
        dist = 256 >> level
        s.butterflies(
            [(j, j + dist)
             for chunk in range(0, 256 - dist, 2 * dist)
             for j in range(chunk, chunk + dist)],
            level,
        )

    def perGroup(offsets):
        # Repeat the (p, q) offset pairs for each of the 32 groups of 8 slots.
        return [(8 * g + p, 8 * g + q) for g in range(32) for p, q in offsets]

    acrossHalves = [(0, 4), (1, 5), (2, 6), (3, 7)]
    s.swaps(perGroup([(2, 4), (3, 5)]))
    s.butterflies(perGroup(acrossHalves), 7)
    s.swaps(perGroup([(1, 4), (3, 6)]))
    s.butterflies(perGroup(acrossHalves), 8)
    # Swaps are applied in sequence, so offsets 3 and 4 deliberately appear twice.
    s.swaps(perGroup([(3, 6), (1, 4), (3, 5), (2, 4)]))
    s.finish()
def kybavx2():
    """Render ``kybavx2.svg``: seven butterfly levels with 16-row grid groups.

    Levels 1-4 pair slot ``k`` with ``k + d`` for ``d`` = 128, 64, 32, 16;
    level 4 runs after a first block shuffle. Levels 5-7 pair whole 16-slot
    blocks (``State.butterflyX4``), each preceded by a finer shuffle, so the
    layout is re-permuted before every one of the last four levels.
    """
    s = State("kybavx2", 16)

    def spread(dist):
        # Pair k with k + dist in every 2*dist-slot chunk, listing every chunk
        # for k = 0 before any chunk for k = 1 (k-major order).
        return [
            (base + k, base + k + dist)
            for k in range(dist)
            for base in range(0, 256, 2 * dist)
        ]

    for level, dist in ((1, 128), (2, 64), (3, 32)):
        s.butterflies(spread(dist), level)

    # 16-slot block pairs: blocks two apart within each group of four, and
    # neighbouring blocks.
    twoApart = [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15)]
    neighbours = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15)]

    s.shuffle(3, twoApart)
    s.butterflies(spread(16), 4)
    s.shuffle(2, neighbours)
    s.butterflyX4(twoApart, 5)
    s.shuffle(1, twoApart)
    s.butterflyX4(neighbours, 6)
    s.shuffle(0, neighbours)
    s.butterflyX4(twoApart, 7)
    s.finish()
# Render the three diagrams only when run as a script, not on import.
if __name__ == "__main__":
    ref()
    dilavx2()
    kybavx2()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment