import numpy as np
import tensorflow as tf

from utils import *  # assumed to provide get_available_gpu_names

params = {}


def create_variable(scope, name, shape, trainable=True, on_cpu=True, **kwargs) -> tf.Variable:
    def _create_variable():
        with tf.variable_scope(scope):
            _w = tf.get_variable(name, shape, trainable=trainable, **kwargs)
            params[_w.name] = _w
            return _w

    # Keep parameters on the CPU so every GPU tower shares a single copy.
    if on_cpu:
        with tf.device("/cpu:0"):
            w = _create_variable()
    else:
        w = _create_variable()
    return w


def get_variable(scope, name, trainable=True) -> tf.Variable:
    with tf.variable_scope(scope, reuse=True):
        w = tf.get_variable(name, trainable=trainable)
        params[w.name] = w
        return w


def get_toy_data(n, xd):
    # Two linearly separable clusters: class 0 in [0, 0.5)^xd, class 1 in [0.5, 1)^xd.
    xs = np.concatenate([np.random.random((n, xd)) / 2, np.random.random((n, xd)) / 2 + 0.5])
    ys = np.concatenate([np.zeros((n,), dtype=np.int64), np.ones((n,), dtype=np.int64)])
    permut = np.random.permutation(len(xs))
    xs = xs[permut]
    ys = ys[permut]
    return xs, np.eye(2)[ys]  # features of shape (2n, xd), one-hot labels of shape (2n, 2)


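# An added sanity check (not in the original gist; nothing calls it): get_toy_data
# returns 2 * n shuffled points with one-hot labels, which this helper verifies.
def _check_toy_data_shapes(n=3, xd=2):
    _xs, _ys = get_toy_data(n, xd)
    assert _xs.shape == (2 * n, xd)      # two clusters of n points each
    assert _ys.shape == (2 * n, 2)       # one-hot labels over 2 classes
    assert np.all(_ys.sum(axis=1) == 1)  # every label row is one-hot

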
def average_gradients(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer list
            is over individual towers; the inner list is over the (gradient,
            variable) pairs computed by that tower.

    Returns:
        List of (gradient, variable) pairs where the gradient has been averaged
        across all towers.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
            # Add a 0th dimension to the gradient to represent the tower.
            expanded_g = tf.expand_dims(g, 0)
            # Append on a 'tower' dimension which we will average over below.
            grads.append(expanded_g)

        # Average over the 'tower' dimension.
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)

        # The Variables are redundant because they are shared across towers,
        # so we just return the first tower's pointer to the Variable.
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads


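# An added illustration (not part of the original gist) of what average_gradients
# computes: for two towers that share one variable, the returned gradient is the
# element-wise mean of the per-tower gradients, paired with the first tower's
# variable handle. The name _demo_average_gradients is ours; nothing calls it.
def _demo_average_gradients():
    v = tf.Variable([0.0, 0.0], name="demo_var")
    g0 = tf.constant([1.0, 3.0])  # gradient from tower 0
    g1 = tf.constant([3.0, 5.0])  # gradient from tower 1
    (avg_g, avg_v), = average_gradients([[(g0, v)], [(g1, v)]])
    # avg_g evaluates to [2.0, 4.0]; avg_v is the shared variable v.
    return avg_g, avg_v

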
def main():
    n = 6000
    xd = 14 * 14
    hd = 100
    xs, ys = get_toy_data(n, xd)

    X = tf.placeholder(tf.float32, [None, xd], name="X")
    Y = tf.placeholder(tf.float32, [None, 2], name="Y")

    # A three-layer MLP; its weights live on the CPU and are shared by all GPU towers.
    w1 = create_variable("layer1", "weight", (xd, hd))
    h = tf.nn.relu(tf.matmul(X, w1))
    w2 = create_variable("layer2", "weight", (hd, hd))
    h = tf.nn.relu(tf.matmul(h, w2))
    w3 = create_variable("layer3", "weight", (hd, 2))
    h = tf.matmul(h, w3)
    hhat = tf.nn.softmax(h)  # predicted class probabilities (not used during training)
    opt = tf.train.AdamOptimizer(learning_rate=0.001, name="opt")

    gpu_names = get_available_gpu_names([1])
    batch_size = 300
    batch_size_per_gpu = batch_size // len(gpu_names)

    grad_list = []
    loss_list = []
    with tf.variable_scope(tf.get_variable_scope()):
        for i, gpu_name in enumerate(gpu_names):
            with tf.device(gpu_name):
                # Each tower consumes its own slice of the feed batch.
                idx_start = i * batch_size_per_gpu
                idx_end = (i + 1) * batch_size_per_gpu
                # Y holds one-hot labels over two exclusive classes and the model
                # ends in a softmax, so use softmax (not sigmoid) cross-entropy.
                loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                    logits=h[idx_start:idx_end], labels=Y[idx_start:idx_end],
                ))
                tf.get_variable_scope().reuse_variables()
                grad = opt.compute_gradients(loss)
                loss_list.append(loss)
                grad_list.append(grad)

    # One synchronous update: average the per-tower gradients, then apply them once.
    grads = average_gradients(grad_list)
    train_op = opt.apply_gradients(grads)

    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True))
    sess.run(tf.global_variables_initializer())

    num_batch = len(xs) // batch_size  # get_toy_data returns 2 * n examples
    for epoch in range(100):
        total_loss = 0
        for batch_idx in range(num_batch):
            idx_start = batch_idx * batch_size
            idx_end = (batch_idx + 1) * batch_size
            xs_b = xs[idx_start:idx_end]
            ys_b = ys[idx_start:idx_end]
            _, loss_value = sess.run([train_op, loss_list], feed_dict={
                X: xs_b,
                Y: ys_b,
            })
            total_loss += np.mean(loss_value)
        print("epoch {}: total loss {:.4f}".format(epoch, total_loss))

if __name__ == '__main__':
    main()
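
# Usage note (an added remark; the original gist gives no run instructions): the
# script assumes `utils` supplies get_available_gpu_names and that the GPU it
# selects is visible on the machine. Run the file directly with Python; each of
# the 100 epochs prints the accumulated mean tower loss, which should shrink
# because the two clusters from get_toy_data are linearly separable.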