import numpy as np

# Note: `env`, `exp_replay`, `epsilon`, `num_actions` and `batch_size` are
# expected to be defined in the surrounding notebook/script.
def train(model, epochs):
    # Train
    # Reset the win counter
    win_cnt = 0
    # We want to keep track of the AI's progress over time, so we save its win count history
    win_hist = []
    # epochs is the number of games we play
    for e in range(epochs):
        loss = 0.
        # Reset the game
        env.reset()
        game_over = False
        # Get the initial input
        input_t = env.observe()

        while not game_over:
            # The learner acts on the last observed game screen;
            # input_t is a vector representing the game screen
            input_tm1 = input_t

            # Take a random action with probability epsilon
            if np.random.rand() <= epsilon:
                # Explore: eat something random from the menu
                action = np.random.randint(0, num_actions, size=1)
            else:
                # Exploit: choose for yourself
                # q contains the expected rewards for the actions
                q = model.predict(input_tm1)
                # We pick the action with the highest expected reward
                action = np.argmax(q[0])

            # Apply the action, get the reward and the new state
            input_t, reward, game_over = env.act(action)
            # If we managed to catch the fruit, we add 1 to our win counter
            if reward == 1:
                win_cnt += 1

            # Uncomment this to render the game here
            # display_screen(action, 3000, inputs[0])

            """
            The experiences <s, a, r, s'> we make during gameplay are our training data.
            Here we first save the last experience, and then load a batch of experiences
            to train our model on (a minimal sketch of such a replay buffer follows below).
            """
            # Store experience
            exp_replay.remember([input_tm1, action, reward, input_t], game_over)
            # Load a batch of experiences
            inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)
            # Train the model on the experiences
            batch_loss = model.train_on_batch(inputs, targets)
            # Sum up the loss over all batches in an epoch
            loss += batch_loss
        win_hist.append(win_cnt)
    return win_hist
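
# ---------------------------------------------------------------------------
# The exp_replay.remember()/get_batch() calls above assume an experience
# replay buffer defined elsewhere in the notebook. The class below is only a
# minimal sketch of such a buffer (class name, memory size and discount value
# are assumptions, not the original implementation): it stores <s, a, r, s'>
# tuples and builds Q-learning targets with the Bellman update
# r + gamma * max_a' Q(s', a').
# ---------------------------------------------------------------------------
class ExperienceReplaySketch(object):
    def __init__(self, max_memory=500, discount=0.9):
        self.max_memory = max_memory  # how many experiences we keep at most
        self.discount = discount      # discount factor gamma for future rewards
        self.memory = []

    def remember(self, experience, game_over):
        # experience is [state_t, action, reward, state_t+1]
        self.memory.append([experience, game_over])
        if len(self.memory) > self.max_memory:
            del self.memory[0]

    def get_batch(self, model, batch_size=10):
        len_memory = len(self.memory)
        num_actions = model.output_shape[-1]
        env_dim = self.memory[0][0][0].shape[1]
        inputs = np.zeros((min(len_memory, batch_size), env_dim))
        targets = np.zeros((inputs.shape[0], num_actions))
        for i, idx in enumerate(np.random.randint(0, len_memory, size=inputs.shape[0])):
            state_t, action_t, reward_t, state_tp1 = self.memory[idx][0]
            terminal = self.memory[idx][1]
            inputs[i] = state_t
            # Start from the model's current Q-value estimates for this state
            targets[i] = model.predict(state_t)[0]
            if terminal:
                # Terminal states only receive the immediate reward
                targets[i, action_t] = reward_t
            else:
                # Non-terminal states add the discounted best future Q-value on top
                q_sa = np.max(model.predict(state_tp1)[0])
                targets[i, action_t] = reward_t + self.discount * q_sa
        return inputs, targets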
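
# ---------------------------------------------------------------------------
# A hedged usage sketch: the network architecture, grid size and
# hyperparameter values below are illustrative assumptions, not the values
# from the original notebook. `Catch` stands in for the game environment
# defined elsewhere in the notebook; any object exposing reset(), observe()
# and act(action) would work in its place.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from keras.models import Sequential
    from keras.layers import Dense

    grid_size = 10       # assumed size of the catch-the-fruit grid
    num_actions = 3      # move left, stay, move right
    epsilon = 0.1        # exploration rate read by train()
    batch_size = 50      # replay batch size read by train()

    # A simple fully connected Q-network mapping the flattened screen to Q-values
    model = Sequential()
    model.add(Dense(100, input_shape=(grid_size**2,), activation='relu'))
    model.add(Dense(100, activation='relu'))
    model.add(Dense(num_actions))
    model.compile(optimizer='sgd', loss='mse')

    env = Catch(grid_size)  # game environment, assumed to be defined elsewhere
    exp_replay = ExperienceReplaySketch(max_memory=500)

    win_hist = train(model, epochs=500)
    # win_hist is cumulative, so the last entry is the total number of wins
    print("Total wins after training:", win_hist[-1])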