Skip to content

Instantly share code, notes, and snippets.

@manzke
Created December 26, 2021 20:14
Show Gist options
  • Save manzke/8fed9a1887b53ac4458fa218d4197757 to your computer and use it in GitHub Desktop.
Save manzke/8fed9a1887b53ac4458fa218d4197757 to your computer and use it in GitHub Desktop.
Variational AutoEncoder with Alibi-Detect and Keras
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import tensorflow_datasets as tfds
tf.keras.backend.clear_session()
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, \
Dense, Layer, Reshape, InputLayer, Flatten
from tqdm import tqdm
from tensorflow import keras
from keras import layers
from keras.callbacks import EarlyStopping, ModelCheckpoint
import argparse
from time import time
print(tf.__version__)
from alibi_detect.od import OutlierVAE
from alibi_detect.utils.fetching import fetch_detector
from alibi_detect.utils.perturbation import apply_mask
from alibi_detect.utils.saving import save_detector, load_detector
from alibi_detect.utils.visualize import plot_instance_score, plot_feature_outlier_image
from alibi_detect.models.tensorflow.losses import elbo
# TensorFlow's logger; uncomment the setLevel line to reduce log output.
logger = tf.get_logger()
#logger.setLevel(logging.ERROR)


def _str2bool(value):
    """Convert a command-line token such as "true", "0" or "no" to a bool.

    argparse's ``type=bool`` is unusable for flags: ``bool("False")`` is
    ``True``, so every non-empty string would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, as it was with the former ``type=bool``.
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# ---- command-line interface -------------------------------------------------
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--data", required=True, help="path to the data used for classification")
ap.add_argument("-t", "--test", required=True, help="path to the test data used for classification")
# Accepts both the bare flag ("--recreate") and an explicit value
# ("--recreate true" / "--recreate false").
ap.add_argument("--recreate", nargs="?", const=True, default=False, type=_str2bool,
                help="ignore a previously saved detector and train a new one")
ap.add_argument("--threshold", default=0.005, type=float, help="threshold for outliers")
ap.add_argument("-w", "--width", default=456, type=int, help="width for the image")
ap.add_argument("-ht", "--height", default=456, type=int, help="height for the image")
ap.add_argument("-b", "--batch_size", default=32, type=int, help="batch size")
ap.add_argument("-e", "--epochs", default=500, type=int, help="number of epochs")
ap.add_argument("-v", "--validation_split", default=0.2, type=float, help="validation split")
ap.add_argument("-s", "--seed", default=168369, type=int, help="seed")
args = vars(ap.parse_args())

# ---- settings used by the rest of the script --------------------------------
recreate = args["recreate"]
threshold = args["threshold"]
data_path = args["data"]
test_data_path = args["test"]
# NOTE(review): the original comment here said the images have to be 32x32;
# make sure the network's input size agrees with the values passed here.
img_width, img_height = args["width"], args["height"]
# Keras' image_dataset_from_directory expects image_size as (height, width).
image_size = (img_height, img_width)
batch_size = args["batch_size"]
epochs = args["epochs"]
validation_split = args["validation_split"]
seed = args["seed"]
verbosity = 1
print(f'data_path {data_path}')

# Keyword arguments common to all three datasets: unlabeled grayscale images,
# resized to `image_size` and served in batches of `batch_size`.
# The datasets yield batches; code that needs every image has to iterate over
# all of them.
_shared_ds_options = dict(
    labels=None,
    color_mode='grayscale',
    seed=seed,
    image_size=image_size,
    batch_size=batch_size,
)
_load_images = tf.keras.preprocessing.image_dataset_from_directory

# Training and validation subsets are both read from `data_path`, with the
# same seed and validation_split passed to each call.
train_ds = _load_images(
    data_path,
    validation_split=validation_split,
    subset="training",
    **_shared_ds_options,
)
val_ds = _load_images(
    data_path,
    validation_split=validation_split,
    subset="validation",
    **_shared_ds_options,
)

# The test images live in their own directory and are not split.
test_ds = _load_images(test_data_path, **_shared_ds_options)
def _dataset_to_array(dataset):
    """Collect every batch of *dataset* into one float32 array divided by 255.

    Previously only the first batch (``next(...)``) was used, so the detector
    was trained and evaluated on at most one batch of images.

    NOTE(review): dividing by 255 presumably maps 8-bit pixel values into
    [0, 1] — confirm against the image source. All images are held in memory
    at once.
    """
    batches = list(dataset.as_numpy_iterator())
    return np.concatenate(batches, axis=0).astype('float32') / 255


print(train_ds)
# n: training images, v: validation images, t: test images.
n = _dataset_to_array(train_ds)
print(len(n))
v = _dataset_to_array(val_ds)
print(len(v))
t = _dataset_to_array(test_ds)
print(len(t))
print('-----')
channels = 1  # 1 grayscale, 3 rgb
latent_dim = 1024

# Spatial size of the images exactly as the input pipeline delivers them, so
# the network always matches the data instead of a hard-coded 32x32.
_rows, _cols = image_size

# The encoder halves each spatial dimension three times (stride-2 convolutions)
# and the decoder doubles it three times; the reconstruction only has the same
# shape as the input when both dimensions are multiples of 2**3.
_SCALE = 8
if min(_rows, _cols) < _SCALE or _rows % _SCALE or _cols % _SCALE:
    raise ValueError(
        f'image size {image_size} must be a positive multiple of {_SCALE} '
        'in both dimensions')
_dec_rows, _dec_cols = _rows // _SCALE, _cols // _SCALE

# NOTE(review): the dense layers grow with the image area, so large image
# sizes may need a lot of memory — confirm the intended resolution.
encoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(_rows, _cols, channels)),
        Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
    ])

# For 32x32 input this is the original Dense(4*4*128) -> Reshape((4, 4, 128)).
decoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(latent_dim,)),
        Dense(_dec_rows * _dec_cols * 128),
        Reshape(target_shape=(_dec_rows, _dec_cols, 128)),
        Conv2DTranspose(256, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2DTranspose(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        # Sigmoid output, one value per input channel.
        Conv2DTranspose(channels, 4, strides=2, padding='same', activation='sigmoid'),
    ])
# Detectors are stored per threshold value.
filepath = f'checkpoints/od-vae-anomalies-{threshold:.4f}'

# Try to reuse a saved detector unless the user asked for a fresh one.
od = None
if not recreate:
    try:
        od = load_detector(filepath)
    except Exception as exc:
        # Best effort: any loading problem falls back to training a new model.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit propagate.
        print(f'creating new model, cause {filepath} could not be loaded: {exc}')

# NOTE(review): assumes training and saving belong to the "no detector loaded"
# branch (the source indentation was lost) — confirm.
if od is None:
    od = OutlierVAE(
        threshold=threshold,      # threshold for outlier score
        score_type='mse',         # use MSE of reconstruction error for outlier detection
        encoder_net=encoder_net,  # can also pass VAE model instead
        decoder_net=decoder_net,  # of separate encoder and decoder
        latent_dim=latent_dim,
        samples=2,
    )
    # train
    od.fit(
        n,
        loss_fn=elbo,
        cov_elbo=dict(sim=.05),
        epochs=epochs,
        verbose=True,
    )
    # save the trained outlier detector
    save_detector(od, filepath)
# Score the training, validation and test arrays in turn and plot the
# instance-level outlier scores of each against the detector's threshold.
for _images in (n, v, t):
    od_preds = od.predict(
        _images,
        outlier_type='instance',     # use 'feature' or 'instance' level
        return_feature_score=True,   # scores used to determine outliers
        return_instance_score=True,
    )
    print(list(od_preds['data'].keys()))
    # Every instance is given the label index 0 ('normal') for the plot.
    target = np.zeros(_images.shape[0], dtype=int)
    labels = ['normal', 'outlier']
    plot_instance_score(od_preds, target, labels, od.threshold)
# Show one training image next to its reconstruction by the VAE.
idx = 8
# Add a leading batch axis by indexing instead of reshaping to a hard-coded
# (1, 32, 32, 1), so this works for whatever image size the data has.
X = n[idx][np.newaxis, ...]
X_recon = od.vae(X)

# Original first, reconstruction second; each in its own figure without axes.
for _image in (X[0], X_recon.numpy()[0]):
    plt.imshow(_image)
    plt.axis('off')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment