# Created December 26, 2021 20:14
# Gist: manzke/8fed9a1887b53ac4458fa218d4197757 (can be saved locally or opened in GitHub Desktop)
# Variational AutoEncoder with Alibi-Detect and Keras
# Note (from GitHub): this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open
# the file in an editor that reveals hidden Unicode characters.
import argparse
import logging
import os
from time import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from keras import layers
from keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow import keras
from tensorflow.keras.layers import (
    Conv2D,
    Conv2DTranspose,
    Dense,
    Flatten,
    InputLayer,
    Layer,
    Reshape,
)
from tqdm import tqdm

from alibi_detect.models.tensorflow.losses import elbo
from alibi_detect.od import OutlierVAE
from alibi_detect.utils.fetching import fetch_detector
from alibi_detect.utils.perturbation import apply_mask
from alibi_detect.utils.saving import load_detector, save_detector
from alibi_detect.utils.visualize import plot_feature_outlier_image, plot_instance_score

# NOTE(review): os, time, tfds, tqdm, keras, layers, EarlyStopping,
# ModelCheckpoint, Layer, Flatten, fetch_detector, apply_mask and
# plot_feature_outlier_image appear unused in the code shown. `layers` and the
# callbacks come from standalone Keras while the model code uses tf.keras —
# confirm before removing or mixing them.

# Reset any global Keras state left over from earlier models.
tf.keras.backend.clear_session()
print(tf.__version__)

logger = tf.get_logger()
# Uncomment to silence TensorFlow INFO/WARNING messages:
# logger.setLevel(logging.ERROR)
def _str2bool(value):
    """Convert a command-line string such as ``"True"`` or ``"no"`` to a bool.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    word = value.strip().lower()
    if word in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if word in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface.
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--data", required=True,
                help="directory with the training images (split into training/validation subsets)")
ap.add_argument("-t", "--test", required=True, help="directory with the test images")
# Accepts a bare `--recreate` as well as `--recreate True` / `--recreate False`.
ap.add_argument("--recreate", default=False, type=_str2bool, nargs="?", const=True,
                help="ignore any saved detector and train a new one")
ap.add_argument("--threshold", default=0.005, type=float, help="threshold for outliers")
ap.add_argument("-w", "--width", default=456, type=int, help="width for the image")
ap.add_argument("-ht", "--height", default=456, type=int, help="height for the image")
ap.add_argument("-b", "--batch_size", default=32, type=int, help="batch size")
ap.add_argument("-e", "--epochs", default=500, type=int, help="number of epochs")
ap.add_argument("-v", "--validation_split", default=0.2, type=float, help="validation split")
ap.add_argument("-s", "--seed", default=168369, type=int, help="seed")
args = vars(ap.parse_args())

recreate = args["recreate"]
threshold = args["threshold"]
data_path = args["data"]
test_data_path = args["test"]
img_width, img_height = args["width"], args["height"]
# Keras image loaders take image_size as (height, width). The encoder/decoder
# networks must accept images of this size.
image_size = (img_height, img_width)
batch_size = args["batch_size"]
epochs = args["epochs"]
validation_split = args["validation_split"]
seed = args["seed"]
verbosity = 1  # NOTE(review): appears unused in the code shown
print(f'data_path {data_path}')
def _image_dataset(directory, **split_kwargs):
    """Return an unlabelled, grayscale, batched image dataset for *directory*.

    Extra keyword arguments (e.g. ``validation_split`` / ``subset``) are
    forwarded to ``image_dataset_from_directory``.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        directory,
        labels=None,
        color_mode='grayscale',
        seed=seed,
        image_size=image_size,
        batch_size=batch_size,
        **split_kwargs,
    )


def _dataset_to_array(dataset):
    """Concatenate every batch of *dataset* into one float32 array scaled to [0, 1].

    Raises:
        ValueError: if the dataset yields no images.
    """
    batches = [batch.astype('float32') / 255 for batch in dataset.as_numpy_iterator()]
    if not batches:
        raise ValueError('dataset contains no images')
    return np.concatenate(batches, axis=0)


# Training/validation split of `data_path`, plus the separate test directory.
train_ds = _image_dataset(data_path, validation_split=validation_split, subset="training")
val_ds = _image_dataset(data_path, validation_split=validation_split, subset="validation")
test_ds = _image_dataset(test_data_path)
print(train_ds)

# Iterate each dataset completely (not just its first batch) so that the
# detector is trained and evaluated on every image.
n = _dataset_to_array(train_ds)
print(len(n))
v = _dataset_to_array(val_ds)
print(len(v))
t = _dataset_to_array(test_ds)
print(len(t))
print('-----')
# Encoder/decoder networks for the OutlierVAE. The encoder halves the spatial
# size three times (stride-2 convolutions). The decoder starts from an
# (H/8, W/8) feature map and doubles it three times, so reconstructions have
# the same (H, W, channels) shape as the input images.
channels = 1  # 1 = grayscale (datasets are loaded with color_mode='grayscale'), 3 = rgb
latent_dim = 1024

# Images are loaded with image_size=image_size, so each array has shape
# (image_size[0], image_size[1], channels).
input_height, input_width = image_size
# Three stride-2 stages require both spatial dimensions to be divisible by 8.
if input_height % 8 or input_width % 8:
    raise ValueError(
        f'image size must be divisible by 8 in both dimensions, got {image_size}')
coarse_height, coarse_width = input_height // 8, input_width // 8
# NOTE(review): the encoder output holds (H/8) * (W/8) * 512 values. OutlierVAE
# presumably flattens this into dense latent layers, so large images (e.g. the
# 456x456 CLI default) produce a very large model — confirm memory suffices.

encoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(input_height, input_width, channels)),
        Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
    ])

decoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(latent_dim,)),
        Dense(coarse_height * coarse_width * 128),
        Reshape(target_shape=(coarse_height, coarse_width, 128)),
        Conv2DTranspose(256, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2DTranspose(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        # sigmoid keeps reconstructions in [0, 1], like the /255-scaled inputs
        Conv2DTranspose(channels, 4, strides=2, padding='same', activation='sigmoid'),
    ])
# Reuse the detector saved for this threshold unless --recreate was given.
# Otherwise, or if nothing usable was loaded, train a new OutlierVAE and save it.
filepath = f'checkpoints/od-vae-anomalies-{threshold:.4f}'
od = None
if not recreate:
    try:
        od = load_detector(filepath)
    except Exception as err:  # best effort: e.g. no checkpoint yet -> retrain below
        print(f'creating new model, cause {filepath} could not be loaded ({err})')

if od is None:
    od = OutlierVAE(threshold=threshold,      # threshold for outlier score
                    score_type='mse',         # use MSE of reconstruction error for outlier detection
                    encoder_net=encoder_net,  # can also pass VAE model instead
                    decoder_net=decoder_net,  # of separate encoder and decoder
                    latent_dim=latent_dim,
                    samples=2)
    # train on the full training split
    od.fit(n,
           loss_fn=elbo,
           cov_elbo=dict(sim=.05),
           epochs=epochs,
           verbose=True)
    # save the trained outlier detector so later runs can reuse it
    save_detector(od, filepath)
# Score the training, validation and test images in turn and plot each set's
# instance-level outlier scores against the detector threshold.
labels = ['normal', 'outlier']
for X_split in (n, v, t):
    od_preds = od.predict(X_split,
                          outlier_type='instance',     # use 'feature' or 'instance' level
                          return_feature_score=True,   # also return per-feature scores
                          return_instance_score=True)  # per-image scores, plotted below
    print(list(od_preds['data'].keys()))
    # No ground-truth labels are loaded (labels=None), so every image is
    # marked 'normal' (0) for the plot.
    target = np.zeros(X_split.shape[0], dtype=int)
    plot_instance_score(od_preds, target, labels, od.threshold)
# Show one training image followed by its VAE reconstruction.
idx = min(8, len(n) - 1)  # fall back to the last image when there are fewer than 9
X = n[idx:idx + 1]  # slice (not index) to keep the leading batch axis
X_recon = od.vae(X)
for image in (X, X_recon.numpy()):
    # Drop the batch and single-channel axes. Pin the grey scale to [0, 1]
    # (inputs are /255-scaled, the decoder ends in a sigmoid) so that both
    # images are displayed on the same intensity scale.
    plt.imshow(np.squeeze(image[0]), cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
    plt.show()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.