Created June 11, 2018 16:42 by benoitdescamps (gist e15be87dc11a577a0d3c8a346fab5b71)
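import tensorflow as tf


def get_cost(target, Q, action_indices):
    """
    Cost function for a Q-matrix that approximates the reward function.

    :param tf.placeholder target: placeholder for the values of the registered rewards, shape [batch]
    :param tf.Tensor Q: output of the Q-matrix for the registered states, shape [batch, n_actions]
    :param tf.placeholder action_indices: placeholder for the indices of the registered actions, shape [batch]
    :return: tf.Tensor holding the mean squared error between the rewards and the Q-values of the registered actions
    """
    # Cast so the indices share tf.range's int32 dtype; tf.stack fails on mixed dtypes.
    action_indices = tf.cast(action_indices, tf.int32)
    # Pair each row i with the action a_i taken in it: [[0, a_0], [1, a_1], ...]
    row_indices = tf.range(tf.shape(action_indices)[0])
    full_indices = tf.stack([row_indices, action_indices], axis=1)
    # Select Q[i, a_i] for every row of the batch.
    q_values = tf.gather_nd(Q, full_indices)
    # tf.losses is the TF 1.x name; tf.compat.v1.losses keeps this working on TF 2.x.
    return tf.compat.v1.losses.mean_squared_error(labels=target, predictions=q_values)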
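To make the indexing concrete, here is a minimal NumPy sketch that mirrors what get_cost computes, without TensorFlow. It is an illustration under stated assumptions, not part of the original gist: the function name `q_value_mse` and the example values are hypothetical. It also assumes the default weights of the TensorFlow loss, where the result is the plain mean of the squared errors. For each row i, the pair (i, a_i) selects the Q-value of the action actually taken, and only those entries are compared with the rewards.

```python
import numpy as np


def q_value_mse(target, Q, action_indices):
    """Hypothetical NumPy analogue of get_cost: mean of (target_i - Q[i, a_i])**2."""
    Q = np.asarray(Q, dtype=float)
    target = np.asarray(target, dtype=float)
    action_indices = np.asarray(action_indices, dtype=int)
    # One row index per sample, like tf.range(tf.shape(action_indices)[0]).
    row_indices = np.arange(action_indices.shape[0])
    # Advanced indexing picks Q[i, a_i] for every i, like tf.gather_nd on stacked pairs.
    q_values = Q[row_indices, action_indices]
    # Plain mean of the squared errors (assumes unit weights in the TF loss).
    return float(np.mean((target - q_values) ** 2))


Q = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
actions = [2, 0]      # hypothetical action taken in each of the two states
rewards = [2.0, 4.0]  # hypothetical registered rewards (the targets)
print(q_value_mse(rewards, Q, actions))  # selects [3.0, 4.0]; prints 0.5
```

Only the Q-values of the registered actions enter the loss, so the gradient updates just those entries of each row; the other actions' estimates are left untouched by this sample.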