@ASvyatkovskiy
Created September 29, 2017 00:49
TensorBoard AUC, Keras, example LSTM on IMDB dataset
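Once training is running, the logged scalars (including the extra val_auc) can be viewed by pointing TensorBoard at the log directory used below, e.g. tensorboard --logdir ./logs.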
from __future__ import print_function
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Embedding
from keras.layers import LSTM
from keras.datasets import imdb
from keras.callbacks import TensorBoard, Callback
import keras.backend as K
import tensorflow as tf
from sklearn.metrics import roc_auc_score
class TensorBoardAUC(TensorBoard):
    """Keras TensorBoard callback extended to log the validation ROC AUC
    as an extra scalar summary ('val_auc') at the end of every epoch."""

    def __init__(self, log_dir='./logs',
                 histogram_freq=0,
                 batch_size=32,
                 write_graph=True,
                 write_grads=False,
                 write_images=False,
                 embeddings_freq=0,
                 embeddings_layer_names=None,
                 embeddings_metadata=None):
        super(TensorBoardAUC, self).__init__(log_dir=log_dir,
                                             histogram_freq=histogram_freq,
                                             batch_size=batch_size,
                                             write_graph=write_graph,
                                             write_grads=write_grads,
                                             write_images=write_images,
                                             embeddings_freq=embeddings_freq,
                                             embeddings_layer_names=embeddings_layer_names,
                                             embeddings_metadata=embeddings_metadata)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        if not self.validation_data and self.histogram_freq:
            raise ValueError('If printing histograms, validation_data must be '
                             'provided, and cannot be a generator.')

        if self.validation_data and self.histogram_freq:
            if epoch % self.histogram_freq == 0:
                val_data = self.validation_data
                tensors = (self.model.inputs +
                           self.model.targets +
                           self.model.sample_weights)
                if self.model.uses_learning_phase:
                    tensors += [K.learning_phase()]

                assert len(val_data) == len(tensors)
                val_size = val_data[0].shape[0]
                i = 0
                while i < val_size:
                    step = min(self.batch_size, val_size - i)
                    if self.model.uses_learning_phase:
                        # do not slice the learning phase
                        batch_val = [x[i:i + step] for x in val_data[:-1]]
                        batch_val.append(val_data[-1])
                    else:
                        batch_val = [x[i:i + step] for x in val_data]
                    assert len(batch_val) == len(tensors)
                    feed_dict = dict(zip(tensors, batch_val))
                    result = self.sess.run([self.merged], feed_dict=feed_dict)
                    summary_str = result[0]
                    self.writer.add_summary(summary_str, epoch)
                    i += self.batch_size

        if self.embeddings_freq and self.embeddings_ckpt_path:
            if epoch % self.embeddings_freq == 0:
                self.saver.save(self.sess,
                                self.embeddings_ckpt_path,
                                epoch)

        # Inject a placeholder 'val_auc' entry so the loop below computes the
        # ROC AUC on the validation set and writes it alongside the other metrics.
        logs = dict(logs, val_auc=0.0)

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            if name == 'val_auc':
                X_val, y_val = self.validation_data[0], self.validation_data[1]
                y_pred = self.model.predict_proba(X_val, verbose=0)
                value = roc_auc_score(y_val, y_pred)
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()
max_features = 20000
maxlen = 80 # cut texts after this number of words (among top max_features most common words)
batch_size = 256
print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')
print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)
print('Build model...')
model = Sequential()
model.add(Embedding(max_features, 128))
model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))
# try using different optimizers and different optimizer configs
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
TB = TensorBoardAUC(log_dir='./logs', histogram_freq=1, batch_size=batch_size, write_graph=True)
print('Train...')
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=20,
          validation_data=(x_test, y_test),
          callbacks=[TB])

score, acc = model.evaluate(x_test, y_test, batch_size=batch_size)
print('Test score:', score)
print('Test accuracy:', acc)
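
# A minimal, optional cross-check of the callback's 'val_auc' scalar: the same
# ROC AUC can be computed directly on the held-out data after training, reusing
# the roc_auc_score import from above (the sigmoid output already yields
# probabilities, which is what roc_auc_score expects).
y_pred_test = model.predict(x_test, batch_size=batch_size, verbose=0)
print('Test ROC AUC:', roc_auc_score(y_test, y_pred_test))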