Created February 23, 2017 21:40
Save akiross/2f980976486cf3d7068ae908ba2c753d to your computer and use it in GitHub Desktop.
Google Hash Code 2017 - Youtube video caching solution using Genetic Algorithms
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
'''So, I entered the Google Hash Code 2017 competition with a team named PiedPiper,
with a guy from the UK and a girl from Palestine, but in the end each of us worked
on her/his own. I came up with a solution using Genetic Algorithms (via DEAP).
I used the score function to build an algorithm that uses a 1-bit encoding to represent
the presence of a video on a certain cache server.
I wasted some time during the competition, so I only managed to briefly run the algorithm
on the smallest dataset: I think the result was decent, but then I made an error and let
the algorithm run on larger datasets for too long, without printing intermediate results.
Anyway, it was fun. I hope someone can enjoy the code (sorry, it's a bit messy).
~Aki'''
import sys | |
from collections import defaultdict | |
from operator import itemgetter | |
from deap import base, creator, tools, algorithms | |
import random | |
def _next_ints(lines):
    """Read the next line from the iterator `lines` and return its ints."""
    try:
        line = next(lines)
    except StopIteration:
        # A bare StopIteration escaping a normal function is confusing;
        # report truncated input explicitly instead.
        raise ValueError("unexpected end of input") from None
    return [int(v) for v in line.split()]


def read_data(lines):
    """Parse a Hash Code 2017 "streaming videos" problem input.

    Args:
        lines: any iterable of text lines, e.g. an open file, a list of
            strings, or ``iter(text.splitlines())``.

    Returns:
        A dict with keys:
          'V', 'E', 'C', 'X': number of videos, endpoints, cache servers and
              the capacity of each cache server (from the first line).
          'Vs': list of video sizes, indexed by video id.
          'Ld': {endpoint: latency from the endpoint to the data center}.
          'K':  {endpoint: number of connected cache servers}.
          'Lc': {endpoint: {cache server: latency}}.
          'Req': list of (video, endpoint, n_requests) tuples in input order.
          'R':  defaultdict {video: {endpoint: total requests}}.
          'R_': defaultdict {endpoint: {video: total requests}}.
        'R' briefly holds the request-descriptor count from the header and is
        then replaced by the request map (kept for compatibility).

    Raises:
        ValueError: if the input is truncated or a line is malformed.
    """
    lines = iter(lines)  # no-op for files/iterators; makes lists work too
    # Videos, Endpoints, Request descriptors, Cache servers, maX cache size
    V, E, R, C, X = _next_ints(lines)
    params = {'V': V, 'E': E, 'R': R, 'C': C, 'X': X}
    params['Vs'] = _next_ints(lines)
    params['Lc'] = {}
    params['Ld'] = {}
    params['K'] = {}
    for e in range(E):
        dc_latency, n_caches = _next_ints(lines)
        params['Ld'][e] = dc_latency  # Latency from endpoint e to main server
        params['K'][e] = n_caches  # Number of cache servers connected to e
        params['Lc'][e] = {}
        for _ in range(n_caches):
            c, lat = _next_ints(lines)
            params['Lc'][e][c] = lat  # Latency from cache server c to endpoint e
    params['R'] = defaultdict(dict)
    params['R_'] = defaultdict(dict)  # Same data, indexed by endpoint first
    params['Req'] = []
    for _ in range(R):
        Rv, Re, Rn = _next_ints(lines)
        params['Req'].append((Rv, Re, Rn))  # Original tuple
        # The same (video, endpoint) pair may appear in several descriptors:
        # accumulate rather than overwrite so the maps agree with 'Req'.
        params['R'][Rv][Re] = params['R'][Rv].get(Re, 0) + Rn
        params['R_'][Re][Rv] = params['R_'][Re].get(Rv, 0) + Rn
    return params
def write_data(cached_videos):
    """Print a solution in the Hash Code submission format.

    `cached_videos` maps each cache server id to the list of video ids it
    stores. The output is the number of entries, then one line per server:
    its id followed by its video ids, separated by spaces. A server with no
    videos is still printed, as its id followed by a single space.
    """
    print(len(cached_videos))
    for server, videos in cached_videos.items():
        listing = ' '.join(map(str, videos))
        print(server, listing)
def score(params, cached_videos):
    """Return the score of a caching solution.

    The score is the average latency saved per request, multiplied by 1000.
    An exact float is returned (no rounding), so small improvements stay
    visible to the optimizer.

    Args:
        params: parsed input as returned by `read_data`; only the 'Ld', 'Lc'
            and 'Req' entries are used.
        cached_videos: mapping {cache server id: iterable of video ids}.
            Servers absent from the mapping are treated as empty.

    Returns:
        The score as a float, or 0.0 when there are no requests at all.
    """
    # Build membership sets once; `video in some_list` would cost O(len)
    # for every request/server pair.
    stored = {c: set(videos) for c, videos in cached_videos.items()}
    empty = frozenset()
    tot_save, tot_req = 0, 0
    for Rv, Re, Rn in params['Req']:
        Ld = params['Ld'][Re]  # Latency of serving the video from the data center
        # Lowest latency among connected cache servers that hold video Rv;
        # a cache slower than the data center is never chosen.
        L = min(
            (lat for c, lat in params['Lc'][Re].items()
             if Rv in stored.get(c, empty)),
            default=Ld,
        )
        L = min(L, Ld)
        tot_save += (Ld - L) * Rn
        tot_req += Rn
    if tot_req == 0:
        return 0.0  # No requests: nothing to save (and avoid dividing by 0)
    return 1000 * tot_save / tot_req
# Sample problem input: 5 videos, 2 endpoints, 4 request descriptors,
# 3 cache servers of capacity 100. Lines are joined without a trailing newline.
example_input = '\n'.join([
    '5 2 4 3 100',
    '50 50 80 30 110',
    '1000 3',
    '0 100',
    '2 200',
    '1 300',
    '500 0',
    '3 0 1500',
    '0 1 1000',
    '4 0 500',
    '1 0 1000',
])
def test():
    """Manual smoke test for the parser, scorer and writer.

    Parses `example_input`, then parses and prints 'Data/me_at_the_zoo.in'
    (resolved relative to the current working directory). Finally it scores
    and prints a fixed solution against the example input.
    """
    sample = read_data(iter(example_input.splitlines()))
    with open('Data/me_at_the_zoo.in') as handle:
        parsed = read_data(handle)
        print(parsed)
    # Servers not listed here read back as empty lists.
    solution = defaultdict(list, {0: [2], 1: [3, 1], 2: [0, 1]})
    print("Score", score(sample, solution))
    write_data(solution)
def main(path=None, num_gen=10, pop_size=1000):
    """Evolve a video placement with a genetic algorithm and print it.

    Args:
        path: input file to solve; defaults to ``sys.argv[1]``.
        num_gen: number of generations passed to DEAP's `eaSimple`.
        pop_size: population size.

    Behavior:
        Writes the best individual found, in submission format, to stdout.
        Echoes the same placement and its score to stderr. Exits with a
        usage message when no path is given and there is no CLI argument.
    """
    if path is None:
        if len(sys.argv) < 2:
            sys.exit('usage: {} INPUT_FILE'.format(sys.argv[0]))
        path = sys.argv[1]
    with open(path) as fd:
        params = read_data(fd)

    # Setup the GA
    creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    creator.create("Individual", list, fitness=creator.FitnessMax)

    # Each solution is a single flat list of V*C bits. It is split evenly
    # into C slices of V bits; slice i marks which videos go on cache
    # server i. A solution exceeding any server's capacity gets fitness -inf.
    nv = params['V']
    num_bits = nv * params['C']  # One bit per (cache server, video) pair
    toolbox = base.Toolbox()
    # Roughly 1% of bits start set; presumably this keeps most initial
    # individuals within cache capacity (TODO confirm on large datasets).
    toolbox.register("attribute", lambda: int(random.random() < 0.01))
    toolbox.register("individual", tools.initRepeat, creator.Individual,
                     toolbox.attribute, n=num_bits)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)

    def decode(ind):
        """Turn an individual into ({server: [video ids]}, fullest server's load)."""
        cached_videos = {}
        max_tot = 0  # Highest sum of video sizes on any one server
        for i in range(params['C']):
            bits = ind[i * nv:(i + 1) * nv]
            cached_videos[i] = [j for j, b in enumerate(bits) if b == 1]
            load = sum(params['Vs'][j] for j in cached_videos[i])
            max_tot = max(max_tot, load)
        return cached_videos, max_tot

    def fitness(ind):
        """DEAP fitness tuple: the score, or -inf if any server overflows."""
        cv, tot = decode(ind)
        if tot > params['X']:
            return float('-inf'),
        return score(params, cv),

    toolbox.register("mate", tools.cxTwoPoint)
    toolbox.register("mutate", tools.mutFlipBit, indpb=0.01)
    toolbox.register("select", tools.selTournament, tournsize=3)
    toolbox.register("evaluate", fitness)

    pop = toolbox.population(n=pop_size)
    hof = tools.HallOfFame(1)
    pop, log = algorithms.eaSimple(pop, toolbox, cxpb=0.5, mutpb=0.2,
                                   ngen=num_gen, halloffame=hof, verbose=True)

    best = hof[0]
    cv, _ = decode(best)
    write_data(cv)
    # The hall of fame keeps the evaluated fitness; no need to recompute it.
    print(cv, best.fitness.values[0], file=sys.stderr)
# Only run the GA when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.