Created
October 30, 2018 12:47
-
-
Save tomat/9727337920dedbbf47f5a2db65a9d696 to your computer and use it in GitHub Desktop.
TF part 1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import tensorflow as tf | |
| import numpy as np | |
# Arm payout thresholds for the multi-armed bandit. A pull wins
# (reward +1) when a uniform draw in [0, 1) falls below the arm's
# value, so bandit 4 (index 3, value 0.8) pays out most often.
bandits = [0.5, 0.1, 0.3, 0.8]
num_bandits = len(bandits)
def pullBandit(bandit):
    """Pull one arm and return its reward: +1 on a win, -1 otherwise.

    *bandit* is the arm's success threshold (the value stored in the
    bandits list), not an index into that list. A single uniform draw
    in [0, 1) below the threshold counts as a win.
    """
    draw = np.random.random(1)
    return 1 if draw < bandit else -1
# ---- Feed-forward part of the network (does the actual choosing) ----
tf.reset_default_graph()

# One trainable weight per bandit, all starting at 1.0 -> [1, 1, 1, 1].
weights = tf.Variable(tf.ones([num_bandits]))
# Greedy action: pull the arm with the largest current weight.
chosen_action = tf.argmax(weights, 0)

# ---- Training procedure ----
# The reward and chosen action are fed into the graph to compute the
# loss, which is used to update the network.
# Reward observed on this step (+1 or -1).
reward_holder = tf.placeholder(shape=[1], dtype=tf.float32)
# Index of the arm that was pulled on this step.
action_holder = tf.placeholder(shape=[1], dtype=tf.int32)
# Pick out the single weight responsible for the chosen action.
responsible_weight = tf.slice(weights, action_holder, [1])
# Policy-gradient-style loss: -log(weight) * reward pushes the chosen
# arm's weight up after a positive reward and down after a negative one.
loss = -(tf.log(responsible_weight) * reward_holder)
# Plain SGD on that loss; `weights` is the only trainable variable.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
update = optimizer.minimize(loss)

# ---- Hyper-parameters and bookkeeping ----
# Total number of pulls to train the agent on.
total_episodes = 1000
# Per-bandit running score, kept only for logging.
total_reward = np.zeros(num_bandits)
# Epsilon: chance of taking a random (exploratory) action at each step.
e = 0.5
# tf.initialize_all_variables() was deprecated in TF 0.12; use the
# behaviorally identical replacement.
init = tf.global_variables_initializer()
# Launch the TensorFlow graph and train the agent.
# NOTE: the original used Python 2 print statements (SyntaxError on
# Python 3); the single-argument function-call form below behaves
# identically on both Python versions.
with tf.Session() as sess:
    # Initialize all graph variables before the first pull.
    sess.run(init)
    for i in range(total_episodes):
        # Epsilon-greedy: with probability e pick a random arm,
        # otherwise pull the arm the network currently favors
        # (the current argmax of `weights`).
        if np.random.rand(1) < e:
            action = np.random.randint(num_bandits)
        else:
            action = sess.run(chosen_action)
        # Observe a +1 / -1 reward for the pulled arm.
        reward = pullBandit(bandits[action])
        # One gradient step: this run's reward and action replace the
        # placeholders. Only `weights` is trainable, so the update
        # nudges the chosen arm's weight, which in turn can change
        # `chosen_action` on the next greedy pull. `ww` is the weight
        # vector after the update.
        _, resp, ww = sess.run(
            [update, responsible_weight, weights],
            feed_dict={
                reward_holder: [reward],
                action_holder: [action],
            },
        )
        # Running per-bandit tally, for logging only.
        total_reward[action] += reward
        if i % 50 == 0:
            print("Running reward for the " + str(num_bandits) + " bandits: " + str(total_reward))
    print("The agent thinks bandit " + str(np.argmax(ww) + 1) + " is the most promising....")
    if np.argmax(ww) == np.argmax(np.array(bandits)):
        print("...and it was right!")
    else:
        print("...and it was wrong!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment