AutoEncoder Example with Keras and Alibi-Detect
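The script below trains an Alibi-Detect OutlierAE (a convolutional autoencoder outlier detector) on a directory of grayscale images, then scores the training, validation and test sets and plots their instance-level outlier scores. As a hedged usage sketch (the script name and directory paths are placeholders, not part of the original gist), it could be invoked like:

python od_autoencoder.py --data ./train_images --test ./test_images --width 32 --height 32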
import os
import logging
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
tf.keras.backend.clear_session()
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, \
    Dense, Layer, Reshape, InputLayer, Flatten
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import argparse
from time import time
print(tf.__version__)
from alibi_detect.od import OutlierAE
from alibi_detect.utils.fetching import fetch_detector
from alibi_detect.utils.perturbation import apply_mask
from alibi_detect.utils.saving import save_detector, load_detector
from alibi_detect.utils.visualize import plot_instance_score, plot_feature_outlier_image
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--data", required=True, help="path to the training data")
ap.add_argument("-t", "--test", required=True, help="path to the test data")
ap.add_argument("--recreate", action="store_true", help="recreate and retrain the detector even if a saved one exists")
ap.add_argument("--threshold", default=0.005, type=float, help="threshold for outliers")
ap.add_argument("-w", "--width", default=32, type=int, help="image width (the autoencoder below expects 32)")
ap.add_argument("-ht", "--height", default=32, type=int, help="image height (the autoencoder below expects 32)")
ap.add_argument("-b", "--batch_size", default=32, type=int, help="batch size")
ap.add_argument("-e", "--epochs", default=500, type=int, help="number of epochs")
ap.add_argument("-v", "--validation_split", default=0.2, type=float, help="validation split")
ap.add_argument("-s", "--seed", default=168369, type=int, help="seed")
args = vars(ap.parse_args())
recreate = args["recreate"]
threshold = args["threshold"]
data_path = args["data"]
test_data_path = args["test"]
img_width, img_height = args["width"], args["height"]  # has to be 32x32 because the autoencoder below reconstructs 32x32 images
image_size = (img_width, img_height)
batch_size = args["batch_size"]
epochs = args["epochs"]
validation_split = args["validation_split"]
seed = args["seed"]
verbosity = 1
print(f'data_path {data_path}')
# TODO: load all images into the numpy arrays by iterating over every batch (see the sketch further below)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_path,
    labels=None,
    color_mode='grayscale',
    validation_split=validation_split,
    subset="training",
    seed=seed,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_path,
    labels=None,
    color_mode='grayscale',
    validation_split=validation_split,
    subset="validation",
    seed=seed,
    image_size=image_size,
    batch_size=batch_size,
)
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    test_data_path,
    labels=None,
    color_mode='grayscale',
    seed=seed,
    image_size=image_size,
    batch_size=batch_size,
)
print(train_ds)
# NOTE: next(...) converts only the *first* batch of each dataset to a numpy array;
# see the hedged sketch below for a way to load every batch (the TODO above).
n = next(train_ds.as_numpy_iterator()).astype('float32') / 255
print(len(n))
v = next(val_ds.as_numpy_iterator()).astype('float32') / 255
print(len(v))
t = next(test_ds.as_numpy_iterator()).astype('float32') / 255
print(len(t))
print('-----')
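# Hedged sketch (not part of the original gist): the TODO above could be addressed by
# stacking every batch into a single array instead of using only the first batch.
def dataset_to_array(ds):
    """Concatenate all batches of an unlabelled image dataset into one float array scaled to [0, 1]."""
    return np.concatenate(list(ds.as_numpy_iterator())).astype('float32') / 255
# Example usage (left commented out so the behaviour of the original script is unchanged):
# n = dataset_to_array(train_ds)
# v = dataset_to_array(val_ds)
# t = dataset_to_array(test_ds)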
encoding_dim = 1024
channels = 1  # 1 grayscale, 3 rgb
encoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(32, 32, channels)),
        Conv2D(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(128, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2D(512, 4, strides=2, padding='same', activation=tf.nn.relu),
        Flatten(),
        Dense(encoding_dim,)
    ])
decoder_net = tf.keras.Sequential(
    [
        InputLayer(input_shape=(encoding_dim,)),
        Dense(4*4*128),
        Reshape(target_shape=(4, 4, 128)),
        Conv2DTranspose(256, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2DTranspose(64, 4, strides=2, padding='same', activation=tf.nn.relu),
        Conv2DTranspose(channels, 4, strides=2, padding='same', activation='sigmoid')
    ])
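# Optional sanity check (not part of the original gist): the encoder downsamples
# 32 -> 16 -> 8 -> 4 and the decoder upsamples 4 -> 8 -> 16 -> 32, so reconstructions
# are always 32x32xchannels. This is why the input images have to be resized to 32x32.
print('encoder output shape:', encoder_net(tf.zeros((1, 32, 32, channels))).shape)  # expected (1, 1024)
print('decoder output shape:', decoder_net(tf.zeros((1, encoding_dim))).shape)      # expected (1, 32, 32, channels)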
filepath = f'checkpoints/od-ae-anomalies-{threshold:.4f}'
od = None
if not recreate:
    try:
        od = load_detector(filepath)
    except Exception:
        print(f'creating a new model because {filepath} could not be loaded')
if od is None:
    # initialize the outlier detector
    od = OutlierAE(threshold=threshold,      # threshold for the outlier score
                   encoder_net=encoder_net,  # can also pass an AE model instead
                   decoder_net=decoder_net,  # of separate encoder and decoder
                   )
    # train
    od.fit(n,
           epochs=epochs,
           verbose=True)
    # save the trained outlier detector
    save_detector(od, filepath)
# score the training set
od_preds = od.predict(n,
                      outlier_type='instance',
                      return_feature_score=True,
                      return_instance_score=True)
print(list(od_preds['data'].keys()))
target = np.zeros(n.shape[0],).astype(int)
labels = ['normal', 'outlier']
plot_instance_score(od_preds, target, labels, od.threshold)
# score the validation set
od_preds = od.predict(v,
                      outlier_type='instance',
                      return_feature_score=True,
                      return_instance_score=True)
print(list(od_preds['data'].keys()))
target = np.zeros(v.shape[0],).astype(int)
labels = ['normal', 'outlier']
plot_instance_score(od_preds, target, labels, od.threshold)
# score the test set
od_preds = od.predict(t,
                      outlier_type='instance',
                      return_feature_score=True,
                      return_instance_score=True)
print(list(od_preds['data'].keys()))
target = np.zeros(t.shape[0],).astype(int)
labels = ['normal', 'outlier']
plot_instance_score(od_preds, target, labels, od.threshold)
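# Hedged follow-up sketch (not part of the original gist), assuming Alibi-Detect's usual
# prediction layout where od_preds['data']['is_outlier'] holds a 0/1 flag per instance:
is_outlier = np.array(od_preds['data']['is_outlier'])
outlier_idx = np.where(is_outlier == 1)[0]
print(f'{len(outlier_idx)} of {len(t)} test images flagged as outliers (threshold={od.threshold})')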