Last active
September 25, 2018 20:30
-
-
Save benoitdescamps/5511c6d4e8cc2671b800742789660f0d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _create_slots(self, var_list): | |
# Create slots for allocation and later management of additional | |
# variables associated with the variables to train. | |
# for example: the first and second moments. | |
''' | |
for v in var_list: | |
self._zeros_slot(v, "m", self._name) | |
self._zeros_slot(v, "v", self._name) | |
''' | |
def _apply_dense(self, grad, var):
    """Build the op that applies one dense update step to ``var``.

    The update implemented here is plain gradient descent,
    ``var <- var - self.learning_rate * grad``. It assumes
    ``self.learning_rate`` is set in ``__init__``.

    Args:
        grad: Dense gradient of the loss with respect to ``var``.
        var: The variable to update.

    Returns:
        A single grouped op that runs every update built in this method.
    """
    # Gradient descent: subtract the gradient scaled by the learning rate.
    var_update = var.assign_sub(self.learning_rate * grad)

    # Collect every op that must run as part of this step and group them,
    # so the caller gets one op to execute. If the rule keeps per-variable
    # state (slots created in _create_slots), build those updates here and
    # add them to the list, for example:
    #     m = self.get_slot(var, "m")
    #     m_t = ...  # do something with m and grad
    #     v = self.get_slot(var, "v")
    #     v_t = ...  # do something with v and grad
    #     updates += [m_t, v_t]
    updates = [var_update]
    return control_flow_ops.group(*updates)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment