ml auxiliary code for (layers, callbacks, visualizations)

# usage: from callbacks import ClearTrainingOutput, PlotTraining
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import IPython
from IPython import display

# class CostFunctionPlot(tf.keras.callbacks.Callback):
#     """
#     plot the cost function of a linear regression
#     c.f. https://towardsdatascience.com/hyperparameter-tuning-with-callbacks-in-keras-5230f51f29b3
#     """
#     def __init__(self):
#         self.weight_history = []
#         self.bias_history = []
#     def on_batch_end(self, batch, logs):
#         weight, bias = self.model.get_weights()
#         B = bias[0]
#         W = weight[0][0]
#         self.weight_history.append(W)
#         self.bias_history.append(B)
#     # TODO plot a contour plot of weights and bias

class PlotTraining(tf.keras.callbacks.Callback):
    """
    keras callback to plot metrics during training
    """
    def __init__(self, sample_rate=1, zoom=1):
        self.sample_rate = sample_rate
        self.step = 0
        self.zoom = zoom
        # BATCH_SIZE is expected to be defined at module level
        # (60000 is the MNIST training-set size)
        self.steps_per_epoch = 60000 // BATCH_SIZE

    def on_train_begin(self, logs=None):
        self.batch_history = {}
        self.batch_step = []
        self.epoch_history = {}
        self.epoch_step = []
        self.fig, self.axes = plt.subplots(1, 2, figsize=(16, 7))
        plt.ioff()

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        if (batch % self.sample_rate) == 0:
            self.batch_step.append(self.step)
            for k, v in logs.items():
                # do not log the "batch" and "size" metrics, which do not change
                # (uncomment 'acc' below to also skip training accuracy)
                if k == 'batch' or k == 'size':  # or k == 'acc':
                    continue
                self.batch_history.setdefault(k, []).append(v)
        self.step += 1

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        plt.close(self.fig)
        self.axes[0].cla()
        self.axes[1].cla()
        self.axes[0].set_ylim(0, 1.2/self.zoom)
        self.axes[1].set_ylim(1-1/self.zoom/2, 1+0.1/self.zoom/2)
        self.epoch_step.append(self.step)
        for k, v in logs.items():
            # only log validation metrics
            if not k.startswith('val_'):
                continue
            self.epoch_history.setdefault(k, []).append(v)
        display.clear_output(wait=True)
        for k, v in self.batch_history.items():
            self.axes[0 if k.endswith('loss') else 1].plot(np.array(self.batch_step) / self.steps_per_epoch, v, label=k)
        for k, v in self.epoch_history.items():
            self.axes[0 if k.endswith('loss') else 1].plot(np.array(self.epoch_step) / self.steps_per_epoch, v, label=k, linewidth=3)
        self.axes[0].legend()
        self.axes[1].legend()
        self.axes[0].set_xlabel('epochs')
        self.axes[1].set_xlabel('epochs')
        self.axes[0].minorticks_on()
        self.axes[0].grid(True, which='major', axis='both', linestyle='-', linewidth=1)
        self.axes[0].grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)
        self.axes[1].minorticks_on()
        self.axes[1].grid(True, which='major', axis='both', linestyle='-', linewidth=1)
        self.axes[1].grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)
        display.display(self.fig)

class ClearTrainingOutput(tf.keras.callbacks.Callback):
    """
    clear the output after training each model, for use in keras hyperparameter tuning
    TODO: summarise top results so far
    TODO: summarise overall progress
    """
    def on_train_end(self, *args, **kwargs):
        IPython.display.clear_output(wait=True)
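
# Usage sketch for the callbacks above. This assumes a compiled Keras `model`
# and MNIST-style arrays (x_train, y_train, x_val, y_val) from the calling
# notebook; BATCH_SIZE is the module-level constant PlotTraining expects.
# BATCH_SIZE = 128
# plot_cb = PlotTraining(sample_rate=10, zoom=2)
# model.fit(x_train, y_train,
#           batch_size=BATCH_SIZE,
#           epochs=10,
#           validation_data=(x_val, y_val),
#           callbacks=[plot_cb, ClearTrainingOutput()])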
# import the necessary packages
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from IPython import display

class LearningRateDecay:
    def plot(self, epochs, title="Learning Rate Schedule"):
        # compute the learning rate for each corresponding epoch
        lrs = [self(i) for i in epochs]
        # plot the learning rate schedule
        plt.style.use("ggplot")
        plt.figure()
        plt.plot(epochs, lrs)
        plt.title(title)
        plt.xlabel("Epoch #")
        plt.ylabel("Learning Rate")

    def plot_steps(self, epochs, lr_func):
        # step plot of an arbitrary schedule function `lr_func` over `epochs` epochs
        xx = np.arange(epochs + 1, dtype=float)
        y = [lr_func(x) for x in xx]
        fig, ax = plt.subplots(figsize=(9, 6))
        ax.set_xlabel('epochs')
        ax.set_title('Learning rate\ndecays from {:0.3g} to {:0.3g}'.format(y[0], y[-2]))
        ax.minorticks_on()
        ax.grid(True, which='major', axis='both', linestyle='-', linewidth=1)
        ax.grid(True, which='minor', axis='both', linestyle=':', linewidth=0.5)
        ax.step(xx, y, linewidth=3, where='post')
        display.display(fig)

class StepDecay(LearningRateDecay):
    def __init__(self, initAlpha=0.01, factor=0.25, dropEvery=10):
        # store the base initial learning rate, drop factor, and
        # epochs to drop every
        self.initAlpha = initAlpha
        self.factor = factor
        self.dropEvery = dropEvery

    def __call__(self, epoch):
        # compute the learning rate for the current epoch
        exp = np.floor((1 + epoch) / self.dropEvery)
        alpha = self.initAlpha * (self.factor ** exp)
        # return the learning rate
        return float(alpha)
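
# Usage sketch: StepDecay is callable on an epoch index, so it can be plotted
# directly and passed to tf.keras.callbacks.LearningRateScheduler
# (the hyperparameter values below are illustrative).
step_decay = StepDecay(initAlpha=0.01, factor=0.25, dropEvery=10)
step_decay.plot(range(100))
# lr_callback = tf.keras.callbacks.LearningRateScheduler(step_decay)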
class CyclicSchedule:
    """
    cyclic schedule: restart from max_lr at epochs 0, 16 and 256
    (i.e. where log2(epoch) == 4 or 8) and decay exponentially in between,
    more slowly after each restart
    """
    def __init__(self, max_lr, min_lr, **kwargs):
        self.max_lr = max_lr
        self.min_lr = min_lr

    def cyclic_scheduler(self, epoch, lr):
        if epoch == 0:
            return self.max_lr
        log_epoch = np.log2(epoch)
        if log_epoch == 4 or log_epoch == 8:
            return self.max_lr
        elif log_epoch > 8:
            return lr * tf.math.exp(-0.001)
        elif log_epoch > 4:
            return lr * tf.math.exp(-0.01)
        else:
            return lr * tf.math.exp(-0.1)

def scheduler(epoch, lr):
    # keep the initial learning rate for the first ten epochs, then decay exponentially
    if epoch < 10:
        return lr
    else:
        return lr * tf.math.exp(-0.1)
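
# Usage sketch: either schedule plugs into model.fit via LearningRateScheduler
# (`model` and the training data are assumed to come from the calling notebook).
# cyclic = CyclicSchedule(max_lr=0.01, min_lr=1e-5)
# model.fit(x_train, y_train, epochs=20,
#           callbacks=[tf.keras.callbacks.LearningRateScheduler(cyclic.cyclic_scheduler)])
# model.fit(x_train, y_train, epochs=20,
#           callbacks=[tf.keras.callbacks.LearningRateScheduler(scheduler)])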
import re
import string
import numpy as np
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import TweetTokenizer

def process_tweet(tweet):
    """Process tweet function.
    Input:
        tweet: a string containing a tweet
    Output:
        tweets_clean: a list of words containing the processed tweet
    """
    stemmer = PorterStemmer()
    stopwords_english = stopwords.words('english')
    # remove stock market tickers like $GE
    tweet = re.sub(r'\$\w*', '', tweet)
    # remove old style retweet text "RT"
    tweet = re.sub(r'^RT[\s]+', '', tweet)
    # remove hyperlinks
    tweet = re.sub(r'https?:\/\/.*[\r\n]*', '', tweet)
    # remove hashtags (only removing the hash # sign from the word)
    tweet = re.sub(r'#', '', tweet)
    # tokenize tweets
    tokenizer = TweetTokenizer(preserve_case=False, strip_handles=True,
                               reduce_len=True)
    tweet_tokens = tokenizer.tokenize(tweet)
    tweets_clean = []
    for word in tweet_tokens:
        if (word not in stopwords_english and     # remove stopwords
                word not in string.punctuation):  # remove punctuation
            stem_word = stemmer.stem(word)  # stemming word
            tweets_clean.append(stem_word)
    return tweets_clean

def build_freqs(tweets, ys):
    """Build frequencies.
    Input:
        tweets: a list of tweets
        ys: an m x 1 array with the sentiment label of each tweet
            (either 0 or 1)
    Output:
        freqs: a dictionary mapping each (word, sentiment) pair to its
            frequency
    """
    # Convert np array to list since zip needs an iterable.
    # The squeeze is necessary or the list ends up with one element.
    # Also note that this is just a NOP if ys is already a list.
    yslist = np.squeeze(ys).tolist()
    # Start with an empty dictionary and populate it by looping over all tweets
    # and over all processed words in each tweet.
    freqs = {}
    for y, tweet in zip(yslist, tweets):
        for word in process_tweet(tweet):
            pair = (word, y)
            if pair in freqs:
                freqs[pair] += 1
            else:
                freqs[pair] = 1
    return freqs
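
# Usage sketch (the tweets and labels below are illustrative; requires the
# NLTK stopwords corpus, e.g. nltk.download('stopwords')):
sample_tweets = ["RT @bob: I love NLP! https://t.co/xyz", "I hate rain :("]
sample_ys = np.array([[1], [0]])
print(process_tweet(sample_tweets[0]))        # -> ['love', 'nlp']
print(build_freqs(sample_tweets, sample_ys))  # -> {('love', 1): 1, ('nlp', 1): 1, ...}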
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K

class RBFLayer(Layer):
    """
    radial basis function layer: output[j] = exp(-gamma * ||x - mu_j||^2)
    """
    def __init__(self, units, gamma, **kwargs):
        super(RBFLayer, self).__init__(**kwargs)
        self.units = units
        self.gamma = K.cast_to_floatx(gamma)

    def build(self, input_shape):
        # one centre mu_j per unit, learned during training
        self.mu = self.add_weight(name='mu',
                                  shape=(int(input_shape[1]), self.units),
                                  initializer='uniform',
                                  trainable=True)
        super(RBFLayer, self).build(input_shape)

    def call(self, inputs):
        # squared L2 distance between each input and each centre
        diff = K.expand_dims(inputs) - self.mu
        l2 = K.sum(K.pow(diff, 2), axis=1)
        res = K.exp(-1 * self.gamma * l2)
        return res

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.units)
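
# Usage sketch: a small RBF network head (layer sizes and gamma are illustrative).
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

rbf_model = Sequential([RBFLayer(10, gamma=0.5, input_shape=(2,)),
                        Dense(1, activation='sigmoid')])
rbf_model.compile(optimizer='rmsprop', loss='binary_crossentropy')
rbf_model.summary()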
from keras.utils.layer_utils import count_params

# ANSI escape codes for terminal colours
colors = dict(underline='\033[4m', bold='\033[1m', end='\033[0m',
              red='\033[91m', green='\033[92m', blue='\033[94m',
              cyan='\033[96m', white='\033[97m', yellow='\033[93m',
              magenta='\033[95m', grey='\033[90m', black='\033[30m',
              default='\033[39m')

def colprint(col, fmt, data):
    """
    print text in color
    c.f. https://en.wikipedia.org/wiki/ANSI_escape_code
    """
    color = colors['default']
    if col in colors:
        color = colors[col]
    print(color + fmt.format(data) + colors['end'])

def print_parameter_count(model):
    """
    print the count of trainable and non-trainable (frozen) parameters in a model
    """
    trainable = count_params(model.trainable_weights)
    untrainable = count_params(model.non_trainable_weights)
    print('Total params: {:,}'.format(trainable + untrainable))
    print('Trainable params: {:,}'.format(trainable))
    print('Non-trainable params: {:,}'.format(untrainable))
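
# Usage sketch (`model` is assumed to be any built Keras model; colour output
# shows in ANSI-capable terminals and notebooks):
colprint('green', 'validation accuracy: {:.3f}', 0.973)
colprint('red', 'validation loss: {:.3f}', 0.142)
# print_parameter_count(model)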
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

def plot_image(i, predictions_array, true_labels, images):
    """
    plot an image with its predicted and true label
    TODO: work with datasets
    parameters:
        i                  index of the image to use
        predictions_array  predicted class probabilities for image i
        true_labels        array of true labels
        images             array of images
    """
    predictions_array, true_label, img = predictions_array, true_labels[i], images[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(tf.squeeze(img), cmap=plt.cm.binary)
    predicted_label = np.argmax(predictions_array)
    classes = gen_data_classes()  # helper assumed to return the list of class names
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'
    plt.xlabel("{} {:2.0f}% ({})".format(classes[predicted_label],
                                         100*np.max(predictions_array),
                                         classes[true_label]),
               color=color)

def plot_value_array(i, predictions_array, true_label):
    predictions_array, true_label = predictions_array, true_label[i]
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)
    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

# Plot the first X test images, their predicted labels, and the true labels.
# Color correct predictions in blue and incorrect predictions in red.
def diagnostic_plot(images, actual, expected, rows=5, cols=3):
    num_rows, num_cols = rows, cols
    x, y, p = images, actual, expected
    num_images = num_rows*num_cols
    plt.figure(figsize=(2*2*num_cols, 2*num_rows))
    for i in range(num_images):
        plt.subplot(num_rows, 2*num_cols, 2*i+1)
        plot_image(i, p[i], y, x)
        plt.subplot(num_rows, 2*num_cols, 2*i+2)
        plot_value_array(i, p[i], y)
    plt.tight_layout()
    # plt.rcParams["figure.figsize"] = (20,10)
    plt.show()

# x_validate / y_validate / p_validate are the validation images, labels,
# and model predictions from the calling notebook
diagnostic_plot(images=x_validate, actual=y_validate, expected=p_validate)
# from https://codelabs.developers.google.com/codelabs/cloud-tensorflow-mnist
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import PIL.Image, PIL.ImageFont, PIL.ImageDraw
from IPython import display

# Matplotlib config
plt.ioff()
plt.rc('image', cmap='gray_r')
plt.rc('grid', linewidth=1)
plt.rc('xtick', top=False, bottom=False, labelsize='large')
plt.rc('ytick', left=False, right=False, labelsize='large')
plt.rc('axes', facecolor='F8F8F8', titlesize="large", edgecolor='white')
plt.rc('text', color='a8151a')
plt.rc('figure', facecolor='F0F0F0', figsize=(16,9))

# Matplotlib fonts
MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), "mpl-data/fonts/ttf")

# pull a batch from the datasets. This code is not very nice, it gets much better in eager mode (TODO)
def dataset_to_numpy_util(training_dataset, validation_dataset, N):
    # get one batch from each: 10000 validation digits, N training digits
    batch_train_ds = training_dataset.unbatch().batch(N)

    # eager execution: loop through datasets normally
    if tf.executing_eagerly():
        for validation_digits, validation_labels in validation_dataset:
            validation_digits = validation_digits.numpy()
            validation_labels = validation_labels.numpy()
            break
        for training_digits, training_labels in batch_train_ds:
            training_digits = training_digits.numpy()
            training_labels = training_labels.numpy()
            break
    else:
        v_images, v_labels = validation_dataset.make_one_shot_iterator().get_next()
        t_images, t_labels = batch_train_ds.make_one_shot_iterator().get_next()
        # Run once, get one batch. Session.run returns numpy results
        with tf.Session() as ses:
            (validation_digits, validation_labels,
             training_digits, training_labels) = ses.run([v_images, v_labels, t_images, t_labels])

    # these were one-hot encoded in the dataset
    validation_labels = np.argmax(validation_labels, axis=1)
    training_labels = np.argmax(training_labels, axis=1)
    return (training_digits, training_labels,
            validation_digits, validation_labels)

# create digits from local fonts for testing
def create_digits_from_local_fonts(n):
    font_labels = []
    img = PIL.Image.new('LA', (28*n, 28), color=(0, 255))  # format 'LA': black in channel 0, alpha in channel 1
    font1 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'DejaVuSansMono-Oblique.ttf'), 25)
    font2 = PIL.ImageFont.truetype(os.path.join(MATPLOTLIB_FONT_DIR, 'STIXGeneral.ttf'), 25)
    d = PIL.ImageDraw.Draw(img)
    for i in range(n):
        font_labels.append(i % 10)
        d.text((7+i*28, 0 if i < 10 else -4), str(i % 10), fill=(255, 255), font=font1 if i < 10 else font2)
    font_digits = np.array(img.getdata(), np.float32)[:, 0] / 255.0  # black in channel 0, alpha in channel 1 (discarded)
    font_digits = np.reshape(np.stack(np.split(np.reshape(font_digits, [28, 28*n]), n, axis=1), axis=0), [n, 28*28])
    return font_digits, font_labels

# utility to display a row of digits with their predictions
def display_digits(digits, predictions, labels, title, n):
    fig = plt.figure(figsize=(13, 3))
    digits = np.reshape(digits, [n, 28, 28])
    digits = np.swapaxes(digits, 0, 1)
    digits = np.reshape(digits, [28, 28*n])
    plt.yticks([])
    plt.xticks([28*x+14 for x in range(n)], predictions)
    plt.grid(False)
    for i, t in enumerate(plt.gca().xaxis.get_ticklabels()):
        if predictions[i] != labels[i]:
            t.set_color('red')  # bad predictions in red
    plt.imshow(digits)
    plt.title(title)
    display.display(fig)

# utility to display multiple rows of digits, sorted by unrecognized/recognized status
def display_top_unrecognized(digits, predictions, labels, n, lines):
    idx = np.argsort(predictions == labels)  # sort order: unrecognized first
    for i in range(lines):
        display_digits(digits[idx][i*n:(i+1)*n], predictions[idx][i*n:(i+1)*n], labels[idx][i*n:(i+1)*n],
                       "{} sample validation digits out of {} with bad predictions in red and sorted first".format(n*lines, len(digits)) if i == 0 else "", n)