-
-
Save davexpro/7069b9fe65b38415c4f3cc563acb981b to your computer and use it in GitHub Desktop.
Keras 多 GPU 同步训练
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers.merge import Concatenate | |
from keras.layers.core import Lambda | |
from keras.models import Model | |
import tensorflow as tf | |
def make_parallel(model, gpu_count):
    """Replicate a Keras model across ``gpu_count`` GPUs for synchronous data parallelism.

    Each GPU tower receives one slice of the input batch, runs the shared
    model on it, and the per-tower outputs are concatenated back together
    on the CPU along the batch axis, so the parallel model is a drop-in
    replacement for the original.

    Args:
        model: the Keras ``Model`` to replicate (weights are shared across towers).
        gpu_count: number of GPU towers to create; must be >= 1.

    Returns:
        A new Keras ``Model`` with the same inputs as ``model`` whose outputs
        are the batch-wise concatenation of all tower outputs.
    """
    def get_slice(data, idx, parts):
        """Return the ``idx``-th of ``parts`` batch slices of ``data``.

        The last slice absorbs the remainder when the batch size is not
        divisible by ``parts``, so no samples are silently dropped.
        (The original version sized every slice as ``batch // parts``,
        losing up to ``parts - 1`` samples per step.)
        """
        shape = tf.shape(data)
        batch = shape[:1]
        step = batch // parts
        # Start offset: idx * step along the batch axis, 0 on all other axes.
        start = tf.concat([step * idx, shape[1:] * 0], axis=0)
        # `idx`/`parts` are Python ints (bound via Lambda's `arguments`),
        # so this branch is resolved at graph-construction time.
        if idx == parts - 1:
            # Last tower takes everything remaining (handles the remainder).
            size = tf.concat([batch - step * (parts - 1), shape[1:]], axis=0)
        else:
            size = tf.concat([step, shape[1:]], axis=0)
        return tf.slice(data, start, size)

    # One list per model output; collects the corresponding tensor of each tower.
    outputs_all = [[] for _ in model.outputs]

    # Place a copy of the model on each GPU, each getting a slice of the batch.
    for i in range(gpu_count):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                inputs = []
                # Slice each input into a piece for processing on this GPU.
                for x in model.inputs:
                    # output_shape excludes the batch dimension, so a
                    # variable-size remainder slice is still compatible.
                    input_shape = tuple(x.get_shape().as_list())[1:]
                    slice_n = Lambda(get_slice,
                                     output_shape=input_shape,
                                     arguments={'idx': i, 'parts': gpu_count})(x)
                    inputs.append(slice_n)

                outputs = model(inputs)
                if not isinstance(outputs, list):
                    outputs = [outputs]

                # Save all the outputs for merging back together later.
                for l, out in enumerate(outputs):
                    outputs_all[l].append(out)

    # Merge outputs on CPU along the batch axis.
    with tf.device('/cpu:0'):
        merged = [Concatenate(axis=0)(outs) for outs in outputs_all]

    return Model(model.inputs, merged)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment