# Last active: December 6, 2016 03:18
# Gist: breeko/aa84e9f461eb6a1ca825be3fe3b75521 (can be saved to your computer or opened in GitHub Desktop).
# Warning from GitHub: this file contains hidden or bidirectional Unicode text that may be interpreted
# or compiled differently than what appears below. To review it, open the file in an editor that
# reveals hidden Unicode characters (see GitHub's notes on bidirectional Unicode characters).
import gzip
import json
import os

import numpy as np
def get_relative_frame(frame, position, new_position=(0, 0)):
    """Cyclically shift ``frame`` so the cell at ``position`` ends up at ``new_position``.

    Rows and columns wrap around (via ``np.roll``): anything pushed past one
    edge reappears on the opposite edge.

    Args:
        frame: 2-dimensional array-like, indexed ``[row, col]``.
        position: ``(row, col)`` of the cell to move.
        new_position: ``(row, col)`` where that cell should land; defaults to ``(0, 0)``.

    Returns:
        np.ndarray: a rolled copy of ``frame``.

    Example:
        >>> a = np.arange(9).reshape(3, 3)
        >>> get_relative_frame(a, (1, 1), new_position=(0, 0))
        array([[4, 5, 3],
               [7, 8, 6],
               [1, 2, 0]])
    """
    src_row, src_col = position
    dst_row, dst_col = new_position
    # One roll over both axes is equivalent to rolling rows first, then columns.
    return np.roll(frame, (dst_row - src_row, dst_col - src_col), axis=(0, 1))
# Load up to 200 gzip-compressed JSON replay files (*.hlt.gzip) recorded for `player_name`.
print("reading games...")
player_name = "djma v3"
game_dir = "games/djma_v3"
# os.listdir returns entries in arbitrary order; sorting makes the "first 200" selection reproducible.
game_ids = sorted(game for game in os.listdir(game_dir) if game.endswith(".hlt.gzip"))
games = []
for game_id in game_ids[:200]:
    # `with` guarantees each replay file handle is closed (the handles were previously leaked).
    with gzip.open(os.path.join(game_dir, game_id), "rb") as f:
        file_content = f.read().decode("utf-8")
    game = json.loads(file_content)
    games.append(game)
# Build the supervised dataset: for sampled cells owned by `player_name`, the feature row is a
# width x height window (strength, ownership, production) centred on that cell, and the label is
# the move the player made from that cell in that frame.
frames_per_game = 25       # max number of distinct frames sampled per game
positions_per_frame = 15   # max number of own cells sampled per frame
# Boards have variable sizes, so every sample is cropped to a fixed width x height window
# centred on the sampled cell; boards smaller than the window are skipped.
width = 10
height = 10
first_n_frames = 50        # only sample from the first 50 frames of each game
X = []
y = []
print("processing games...")
# The sampled cell is always moved to the centre of the window (loop-invariant).
new_y = height // 2
new_x = width // 2
for game in games:
    if player_name not in game["player_names"]:
        continue  # replay does not include this player; .index() would raise ValueError
    # Owner ids in the frames are compared against index + 1, i.e. presumably 1-based player ids
    # with 0 meaning unowned — TODO confirm against the replay format.
    player = game["player_names"].index(player_name) + 1
    # NOTE(review): 20 and 255 look like the max production/strength values used for scaling — confirm.
    productions = np.array(game["productions"]) / 20.
    n_frames = min(first_n_frames, len(game["moves"]))
    if n_frames == 0:
        continue  # np.random.choice(0, ...) would raise on a game with no moves
    # Sample frames WITHOUT replacement: repeated frames yielded duplicate rows that could end up
    # on both sides of the train/test split and inflate the reported test score.
    frame_idxes = np.random.choice(n_frames, size=min(frames_per_game, n_frames), replace=False)
    for idx in frame_idxes:
        frame = np.array(game["frames"][idx])
        moves = np.array(game["moves"][idx])
        raw_owners = frame[:, :, 0]
        # Re-encode ownership from this player's point of view: -1 = own, 1 = enemy, 0 = unowned.
        owners = np.zeros_like(raw_owners)
        owners[raw_owners == player] = -1
        owners[(raw_owners != player) & (raw_owners != 0)] = 1
        strengths = frame[:, :, 1] / 255.
        if not np.any(owners == -1):
            continue  # Player not in this frame
        positions = list(zip(*np.where(owners == -1)))
        position_idxes = np.random.choice(
            len(positions), size=min(len(positions), positions_per_frame), replace=False)
        for position_idx in position_idxes:
            position = positions[position_idx]
            target = (new_y, new_x)
            relative_strengths = get_relative_frame(strengths, position, new_position=target)[:height, :width]
            relative_owners = get_relative_frame(owners, position, new_position=target)[:height, :width]
            relative_productions = get_relative_frame(productions, position, new_position=target)[:height, :width]
            relative_combined = np.hstack([relative_strengths, relative_owners, relative_productions])
            if relative_combined.size == width * height * 3:
                # Board is at least width x height, so the crop is a full window.
                X.append(relative_combined.ravel())
                r, c = position
                y.append(moves[r][c])
X = np.array(X)
y = np.array(y)
# Fit an MLP on the extracted samples and report accuracy on train/test splits.
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
print("learning...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
alpha = 0.01
# NOTE: sklearn only uses `learning_rate` when solver="sgd"; with the default "adam" solver it is ignored.
clf = MLPClassifier(activation="relu", alpha=alpha, hidden_layer_sizes=(512, 512, 512), random_state=0,
                    learning_rate="adaptive", verbose=True, early_stopping=True)
clf.fit(X_train, y_train)
# "{:.2f}" prints two decimals; the previous "{:02f}" meant "zero-pad to width 2" and printed six.
print("train score: {:.2f} test score: {:.2f}".format(clf.score(X_train, y_train), clf.score(X_test, y_test)))
# Accuracy restricted to samples whose label is not 0 (what move 0 means is not shown here).
nonzero = y_test != 0
if np.any(nonzero):
    print("Non-zero score: {:.2f}".format(clf.score(X_test[nonzero], y_test[nonzero])))
else:
    # Scoring an empty selection would raise, so report the absence instead.
    print("Non-zero score: n/a (no non-zero labels in the test split)")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.