from flat_game import carmunk
import numpy as np
import random
import csv
from nn import neural_net, LossHistory
import os.path
import timeit
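
# Note: LossHistory is not shown in this excerpt. It is assumed to be a small
# Keras callback defined in nn.py that records the loss after every batch,
# roughly along the lines of the sketch below (kept as a comment so it does
# not shadow the import above; names and signatures are assumptions):
#
#     class LossHistory(keras.callbacks.Callback):
#         def on_train_begin(self, logs={}):
#             self.losses = []
#
#         def on_batch_end(self, batch, logs={}):
#             self.losses.append(logs.get('loss'))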

def train_net(model, params):

    filename = params_to_filename(params)

    observe = 1000  # Number of frames to observe before training.
    epsilon = 1
    train_frames = 1000000  # Number of frames to play.
    batchSize = params['batchSize']
    buffer = params['buffer']

    # Just stuff used below.
    max_car_distance = 0
    car_distance = 0
    t = 0
    data_collect = []
    replay = []  # stores tuples of (S, A, R, S').

    loss_log = []

    # Create a new game instance.
    game_state = carmunk.GameState()

    # Get initial state by doing nothing and getting the state.
    _, state = game_state.frame_step((2))

    # Let's time it.
    start_time = timeit.default_timer()

    # Run the frames.
    while t < train_frames:

        t += 1
        car_distance += 1

        # Choose an action.
        if random.random() < epsilon or t < observe:
            action = np.random.randint(0, 3)  # random
        else:
            # Get Q values for each action.
            qval = model.predict(state, batch_size=1)
            action = (np.argmax(qval))  # best

        # Take action, observe new state and get our treat.
        reward, new_state = game_state.frame_step(action)

        # Experience replay storage.
        replay.append((state, action, reward, new_state))

        # If we're done observing, start training.
        if t > observe:

            # If we've stored enough in our buffer, pop the oldest.
            if len(replay) > buffer:
                replay.pop(0)

            # Randomly sample our experience replay memory.
            minibatch = random.sample(replay, batchSize)

            # Get training values.
            X_train, y_train = process_minibatch(minibatch, model)

            # Train the model on this batch.
            history = LossHistory()
            model.fit(
                X_train, y_train, batch_size=batchSize,
                nb_epoch=1, verbose=0, callbacks=[history]
            )
            loss_log.append(history.losses)

        # Update the starting state with S'.
        state = new_state
        # Decrement epsilon over time (use a float step so it isn't
        # truncated to zero by integer division on Python 2).
        if epsilon > 0.1 and t > observe:
            epsilon -= (1.0 / train_frames)
        # We died, so update stuff.
        if reward == -500:
            # Log the car's distance at this T.
            data_collect.append([t, car_distance])

            # Update max.
            if car_distance > max_car_distance:
                max_car_distance = car_distance

            # Time it.
            tot_time = timeit.default_timer() - start_time
            fps = car_distance / tot_time

            # Output some stuff so we can watch.
            print("Max: %d at %d\tepsilon %f\t(%d)\t%f fps" %
                  (max_car_distance, t, epsilon, car_distance, fps))

            # Reset.
            car_distance = 0
            start_time = timeit.default_timer()

        # Save the model every 25,000 frames.
        if t % 25000 == 0:
            model.save_weights('saved-models/' + filename + '-' +
                               str(t) + '.h5',
                               overwrite=True)
            print("Saving model %s - %d" % (filename, t))

    # Log results after we're done all frames.
    log_results(filename, data_collect, loss_log)
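
# The helpers params_to_filename(), process_minibatch() and log_results() are
# referenced above but not shown in this excerpt. For reference, below is a
# minimal sketch of what process_minibatch() might look like, assuming the
# standard Q-learning target (target = r + GAMMA * max_a' Q(s', a')), with the
# crash reward of -500 treated as a terminal transition. GAMMA, the function
# name, and the array shapes here are illustrative assumptions, not part of
# the code above.

GAMMA = 0.9  # Discount factor (assumed value).


def process_minibatch_sketch(minibatch, model):
    """Turn sampled (S, A, R, S') tuples into training arrays for the Q-network."""
    X_train, y_train = [], []
    for state, action, reward, new_state in minibatch:
        # Current Q-value estimates for this state (shape (1, 3)).
        old_qval = model.predict(state, batch_size=1)
        # Q-value estimates for the next state, used for the Bellman target.
        new_qval = model.predict(new_state, batch_size=1)
        max_q = np.max(new_qval)

        # Only the taken action's target changes; the others keep their old values.
        target = old_qval[0].copy()
        if reward == -500:  # Crash: terminal state, no future reward.
            target[action] = reward
        else:
            target[action] = reward + GAMMA * max_q

        X_train.append(state.reshape(state.shape[1],))
        y_train.append(target)

    return np.array(X_train), np.array(y_train)

# Usage (sketch): train_net(model, params) would be called with a Keras model
# built by nn.neural_net() and a params dict containing at least 'batchSize'
# and 'buffer'; the exact launcher and parameter values live elsewhere in the
# full gist.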