@amarioncosmo
Created August 23, 2018 08:06
Get a Keras model graph in TensorBoard
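# Written for TensorFlow 1.x with standalone Keras 2.x (TensorFlow backend).
# Run this script once, then launch TensorBoard on the same directory:
# $ tensorboard --logdir="$HOME/tensorboard.log"
import os

import tensorflow as tf
from keras.layers import Dense, Dropout, Conv3D, MaxPooling3D, GlobalAveragePooling3D
from keras.models import Sequential
import keras.backend as K

# Small 3D CNN; input_shape=(None, None, None, 1) accepts single-channel volumes of any size
model = Sequential()
model.add(Conv3D(8, kernel_size=(3, 3, 3), input_shape=(None, None, None, 1), padding='same'))
model.add(MaxPooling3D(pool_size=(2, 2, 2), padding='same'))
model.add(Conv3D(8, kernel_size=(3, 3, 3), padding='same'))
model.add(Conv3D(8, kernel_size=(3, 3, 3), padding='same'))
model.add(GlobalAveragePooling3D())
model.add(Dense(32))
model.add(Dropout(0.1))
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')

# The Keras TensorFlow backend builds the model in the graph attached to this session
session = K.get_session()
# Expand '~' here; TensorFlow would otherwise create a directory literally named '~' in the working directory
logdir = os.path.expanduser('~/tensorboard.log')
writer = tf.summary.FileWriter(logdir, session.graph)
writer.close()  # flush the graph event to disk before the script exits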
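Two details decide whether TensorBoard actually shows the graph: the script and the `tensorboard` command must name the same directory, and an event file must exist there. The stdlib-only sketch below resolves `~` once, builds a matching command line, and lists the `events.out.tfevents.*` files a `FileWriter` produces. The helper names `resolve_logdir`, `tensorboard_command` and `find_event_files` are hypothetical (they are not part of TensorFlow or Keras), and the sketch assumes the default event-file naming and a POSIX shell.

```python
import glob
import os
import shlex


def resolve_logdir(path="~/tensorboard.log"):
    """Expand '~' and return an absolute, normalized log directory path.

    TensorFlow's file APIs do not expand '~', so the Python side has to.
    """
    return os.path.abspath(os.path.expanduser(path))


def tensorboard_command(logdir):
    """Build a shell command that points TensorBoard at the same directory.

    shlex.quote keeps paths containing spaces intact in a POSIX shell.
    """
    return "tensorboard --logdir=" + shlex.quote(logdir)


def find_event_files(logdir):
    """List event files in logdir (assumes the events.out.tfevents.* naming)."""
    pattern = os.path.join(logdir, "events.out.tfevents.*")
    return sorted(glob.glob(pattern))


if __name__ == "__main__":
    logdir = resolve_logdir()
    print(tensorboard_command(logdir))
    print(find_event_files(logdir))
```

Running it after the Keras script should list at least one event file; an empty list usually means the writer targeted a different directory (for example a literal `~` folder in the working directory) or was never flushed.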