-
-
Save markcheno/dbed90c0cfb5f38ca5a6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator | |
import math | |
import random | |
import struct | |
import numpy | |
from deap import algorithms | |
from deap import base | |
from deap import creator | |
from deap import tools | |
from deap import gp | |
from ctypes import * | |
# Raw machine code of the GP primitives.
# Each sequence was obtained by writing a file "test.c" containing, e.g.:
#     double add(double x, double y) { return x + y; }
# and executing:
#     gcc -c -O3 ./test.c
#     objdump -S ./test.o
add_bytecode = (                            # 0000000000000000 <add>:
    b"\xf2\x0f\x58\xc1"                     # addsd %xmm1,%xmm0
    b"\xc3"                                 # retq
)
sub_bytecode = (                            # 0000000000000000 <sub>:
    b"\xf2\x0f\x5c\xc1"                     # subsd %xmm1,%xmm0
    b"\xc3"                                 # retq
)
mul_bytecode = (                            # 0000000000000000 <mul>:
    b"\xf2\x0f\x59\xc1"                     # mulsd %xmm1,%xmm0
    b"\xc3"                                 # retq
)
neg_bytecode = (                            # 0000000000000000 <neg>:
    b"\xf2\x0f\x10\x0d\x05\x00\x00\x00"     # movsd 0x05(%rip),%xmm1
    b"\x66\x0f\x57\xc1"                     # xorpd %xmm1,%xmm0
    b"\xc3"                                 # retq
    # Data read by the movsd above: a double with only the sign bit set.
    b"\x00\x00\x00\x00\x00\x00\x00\x80"
)
# Handle to the C library; it provides the valloc/mprotect/free calls used
# to manage the executable buffers.
libc = CDLL("libc.so.6")

# Without declared prototypes ctypes treats every argument and return value
# as a C int, which truncates the 64-bit pointer returned by valloc().
libc.valloc.restype = c_void_p
libc.valloc.argtypes = [c_size_t]
libc.mprotect.restype = c_int
libc.mprotect.argtypes = [c_void_p, c_size_t, c_int]
libc.free.restype = None
libc.free.argtypes = [c_void_p]

# Protection flag bits OR-ed together and handed to mprotect().
PROT_READ = 1
PROT_WRITE = 2
PROT_EXEC = 4
def executable_code(buff):
    """Copy machine code into a freshly allocated executable buffer.

    Args:
        buff: bytes object holding the machine code to make executable.

    Returns:
        A ``c_void_p`` pointing at a page-aligned buffer that holds a copy
        of ``buff`` and is readable, writable and executable. The caller
        owns the buffer and should release it with ``libc.free()`` when
        finished.

    Raises:
        MemoryError: if the buffer cannot be allocated.
        OSError: if the protection of the buffer cannot be changed (the
            buffer is released before raising).
    """
    size = len(buff)

    # ctypes assumes a C int return value unless told otherwise, which would
    # truncate a 64-bit pointer; declare the real prototype before calling.
    libc.valloc.restype = c_void_p
    libc.valloc.argtypes = [c_size_t]

    # Need to align to a page boundary for mprotect, so use valloc.
    raw_addr = libc.valloc(size)
    # With a c_void_p restype a NULL pointer comes back as None; comparing a
    # c_void_p *instance* with 0 would never be true.
    if not raw_addr:
        raise MemoryError("Failed to allocate memory")

    addr = c_void_p(raw_addr)
    memmove(addr, c_char_p(buff), size)
    if 0 != libc.mprotect(addr, c_size_t(size),
                          PROT_READ | PROT_WRITE | PROT_EXEC):
        # Do not leak the buffer when it cannot be made executable.
        libc.free(addr)
        raise OSError("Failed to set protection on buffer")
    return addr
# Executable copies of the primitives' machine code (released at the end of
# the script with libc.free()).
add_code_ptr = executable_code(add_bytecode)
sub_code_ptr = executable_code(sub_bytecode)
mul_code_ptr = executable_code(mul_bytecode)
neg_code_ptr = executable_code(neg_bytecode)

# Python callables wrapping those buffers: the binary primitives take two
# doubles, neg takes one; all of them return a double.
myAdd = cast(add_code_ptr, CFUNCTYPE(c_double, c_double, c_double))
mySub = cast(sub_code_ptr, CFUNCTYPE(c_double, c_double, c_double))
myMul = cast(mul_code_ptr, CFUNCTYPE(c_double, c_double, c_double))
myNeg = cast(neg_code_ptr, CFUNCTYPE(c_double, c_double))

# The names double as keys into primitives_addr when assembling individuals.
myAdd.__name__, mySub.__name__ = 'add', 'sub'
myMul.__name__, myNeg.__name__ = 'mul', 'neg'

# Order in which the primitives' code is laid out after an individual's code.
primitives_bytecode = [
    ('add', add_bytecode),
    ('sub', sub_bytecode),
    ('mul', mul_bytecode),
    ('neg', neg_bytecode),
]

# Byte offset of each primitive inside that concatenated layout.
primitives_addr = {}
dist = 0
for k, v in primitives_bytecode:
    primitives_addr[k], dist = dist, dist + len(v)
# Primitive set of the evolved programs: a single input (renamed 'x'), the
# four machine-code primitives and an ephemeral constant drawn from {-1, 0, 1}.
pset = gp.PrimitiveSet("MAIN", 1)
for fct, arity in ((myAdd, 2), (mySub, 2), (myMul, 2), (myNeg, 1)):
    pset.addPrimitive(fct, arity)
pset.addEphemeralConstant("rand101", lambda: random.randint(-1,1))
pset.renameArguments(ARG0='x')

# Single-objective minimisation; individuals are expression trees.
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)

toolbox = base.Toolbox()
toolbox.register(
    "expr", gp.genHalfAndHalf, pset=pset, min_=1, max_=2)
toolbox.register(
    "individual", tools.initIterate, creator.Individual, toolbox.expr)
toolbox.register(
    "population", tools.initRepeat, list, toolbox.individual)
toolbox.register(
    "compile", gp.compile, pset=pset)
def buildASM(individual):
    """Assemble machine code that evaluates a GP tree on one double.

    The nodes of ``individual`` are visited in iteration order (assumed to
    be the prefix order of a ``gp.PrimitiveTree``) and one code chunk is
    produced per node. Chunks are collected back to front: the finished
    program starts by saving the input (xmm0) in xmm15, then runs the chunk
    of the last node first and the chunk of the first node last, and ends
    with ``ret``. The code of all primitives is appended after that ``ret``
    so every ``call`` can reach its target with a relative displacement.

    Args:
        individual: iterable of tree nodes. A node without a ``value``
            attribute (or with ``arity`` > 1) is a primitive whose ``name``
            must be a key of ``primitives_addr``; a node whose ``value`` is
            a ``str`` is the input argument; any other ``value`` is packed
            as a double immediate.

    Returns:
        bytes: the individual's code followed by the primitives' code.

    Raises:
        KeyError: if a primitive's name is not in ``primitives_addr``.
    """
    # Instruction encodings. rax is used as scratch to move doubles between
    # the xmm registers, immediates and the stack.
    CALL = b"\xE8"                                  # call rel32
    MOV_XMM0toXMM15 = b"\xF3\x44\x0F\x7E\xF8"
    MOV_XMM15toXMM0 = b"\xF3\x41\x0F\x7E\xC7"
    PUSH_XMM0 = b"\x66\x48\x0F\x7E\xC0\x50"
    PUSH_IMM_PRE = b"\x48\xB8"
    PUSH_IMM_POST = b"\x50"
    PUSH_XMM15 = b"\x66\x4C\x0F\x7E\xF8\x50"
    MOV_IMM_PRE = b"\x48\xB8"
    MOV_IMM_POST = b"\x66\x48\x0F\x6E\xC0"
    POP_XMM1 = b"\x58\x66\x48\x0F\x6E\xC8"
    POP_XMM2 = b"\x58\x66\x48\x0F\x6E\xD0"
    POP_XMM3 = b"\x58\x66\x48\x0F\x6E\xD8"
    RET = b"\xc3"

    # Registers that receive the 2nd, 3rd and 4th argument of a primitive.
    pop_args = (POP_XMM1, POP_XMM2, POP_XMM3)

    # ``chunks`` holds the program in reverse execution order, so whatever is
    # already in it is exactly the code that will follow the next chunk.
    # ``tail_len`` tracks its total size, which avoids re-joining every chunk
    # each time a call displacement is needed.
    chunks = [RET]
    tail_len = len(RET)

    def emit(chunk):
        nonlocal tail_len
        chunks.append(chunk)
        tail_len += len(chunk)

    def emit_call(name):
        # The displacement is relative to the end of the call instruction:
        # skip the rest of the individual's code, then index into the
        # primitives' code appended after it.
        emit(CALL + struct.pack('i', tail_len + primitives_addr[name]))

    # True when the next operand can go straight into xmm0, False when xmm0
    # is already taken and the operand has to go on the stack instead.
    nopush = True
    for node in individual:
        if node.arity > 1 or not hasattr(node, 'value'):
            # Is a function
            if not nopush:
                # Runs right after the call: saves its result on the stack.
                emit(PUSH_XMM0)
            emit_call(node.name)
            # Run right before the call: fetch the extra arguments.
            for pop in pop_args[:max(node.arity - 1, 0)]:
                emit(pop)
            nopush = True
        elif isinstance(node.value, str):
            # Is an argument
            emit(MOV_XMM15toXMM0 if nopush else PUSH_XMM15)
            nopush = False
        else:
            # Is an immediate value
            packed = struct.pack('d', node.value)
            if nopush:
                emit(MOV_IMM_PRE + packed + MOV_IMM_POST)
            else:
                emit(PUSH_IMM_PRE + packed + PUSH_IMM_POST)
            nopush = False

    # First executed instruction: keep a copy of the input in xmm15.
    emit(MOV_XMM0toXMM15)
    chunks.reverse()
    chunks.extend(code for _, code in primitives_bytecode)
    return b"".join(chunks)
def evalSymbReg(individual, points):
    """Return the fitness of ``individual`` as a one-element tuple.

    The individual is assembled to machine code, called on every sample
    point, and scored by its mean squared error against the target function
    x**4 + x**3 + x**2 + x.

    Args:
        individual: expression tree accepted by ``buildASM``.
        points: non-empty sequence of sample values for x.

    Returns:
        ``(mean_squared_error,)``.
    """
    # Transform the tree expression in a callable function
    ind_code = buildASM(individual)
    ind_code_ptr = executable_code(ind_code)
    try:
        this_ind_fct = cast(ind_code_ptr, CFUNCTYPE(c_double, c_double))
        # Evaluate the mean squared error between the expression
        # and the real function : x**4 + x**3 + x**2 + x
        sqerrors = ((this_ind_fct(x) - x**4 - x**3 - x**2 - x)**2 for x in points)
        return math.fsum(sqerrors) / len(points),
    finally:
        # The buffer is only needed for this evaluation; without this every
        # fitness evaluation would leak one executable buffer.
        libc.free(ind_code_ptr)
# Fitness is measured on the 20 sample points x/10 for x in -10..9.
eval_points = [x/10. for x in range(-10,10)]
toolbox.register("evaluate", evalSymbReg, points=eval_points)

# Variation and selection operators used by the evolutionary loop.
toolbox.register("select", tools.selTournament, tournsize=3)
toolbox.register("mate", gp.cxOnePoint)
toolbox.register("expr_mut", gp.genFull, min_=0, max_=2)
toolbox.register(
    "mutate", gp.mutUniform, expr=toolbox.expr_mut, pset=pset)
def main():
    """Run the evolution with a fixed seed.

    Evolves a population of 300 individuals for 40 generations with
    ``algorithms.eaSimple`` (crossover probability 0.5, mutation probability
    0.1), tracking fitness and size statistics and the single best
    individual.

    Returns:
        ``(population, logbook, hall_of_fame)``.
    """
    random.seed(318)

    population = toolbox.population(n=300)
    hall_of_fame = tools.HallOfFame(1)

    mstats = tools.MultiStatistics(
        fitness=tools.Statistics(lambda ind: ind.fitness.values),
        size=tools.Statistics(len),
    )
    for label, reducer in (("avg", numpy.mean), ("std", numpy.std),
                           ("min", numpy.min), ("max", numpy.max)):
        mstats.register(label, reducer)

    population, logbook = algorithms.eaSimple(
        population, toolbox, cxpb=0.5, mutpb=0.1, ngen=40,
        stats=mstats, halloffame=hall_of_fame, verbose=True)
    return population, logbook, hall_of_fame
if __name__ == "__main__":
    main()
    # Release the executable buffers that hold the primitives' machine code.
    for code_ptr in (add_code_ptr, sub_code_ptr, mul_code_ptr, neg_code_ptr):
        libc.free(code_ptr)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment