Skip to content

Instantly share code, notes, and snippets.

@Cospel
Last active June 18, 2020 18:36
Show Gist options
  • Select an option

  • Save Cospel/96b88baa589e244b6bf53d47ccfb7a4c to your computer and use it in GitHub Desktop.

Select an option

Save Cospel/96b88baa589e244b6bf53d47ccfb7a4c to your computer and use it in GitHub Desktop.
Attempt to create an anti-aliasing CNN BlurPool wrapper for existing Keras Applications models (TF2 Keras).
import tensorflow as tf
import numpy as np
class BlurPool(tf.keras.layers.Layer):
    """Base layer holding a fixed binomial blur kernel for anti-aliased downsampling.

    References:
        https://arxiv.org/abs/1904.11486
        https://github.com/adobe/antialiased-cnns
        https://github.com/adobe/antialiased-cnns/issues/10

    The layer stores:

    * ``self.a`` - row ``filt_size - 1`` of Pascal's triangle (e.g. 3 -> [1, 2, 1]);
    * ``self.padding`` - per-side zero-padding amounts (``filt_size - 1`` in total
      per spatial axis) for subclasses to apply before a 'VALID' blur;
    * ``self.kernel`` (created in ``build``) - the normalised outer product of
      ``self.a`` tiled once per input channel, in ``tf.nn.depthwise_conv2d``
      filter layout.

    This class defines no ``call`` of its own; subclasses apply the blur.

    Args:
        filt_size: blur filter size, any integer >= 1 (1 = plain subsampling).
        stride: subsampling factor used for both spatial dimensions.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: if ``filt_size`` is not a positive integer.
    """

    def __init__(self, filt_size=3, stride=2, **kwargs):
        # Initialise the Keras layer machinery before assigning any attributes.
        super(BlurPool, self).__init__(**kwargs)
        if filt_size < 1 or int(filt_size) != filt_size:
            raise ValueError(f"filt_size must be a positive integer, got {filt_size!r}")
        size = int(filt_size)
        self.strides = (stride, stride)
        self.filt_size = filt_size
        # Total padding of size - 1 makes the 'VALID' blur produce
        # ceil(input_size / stride) outputs per spatial axis.
        pad_before, pad_after = (size - 1) // 2, size // 2
        self.padding = ((pad_before, pad_after), (pad_before, pad_after))
        self.a = self._binomial_row(size)

    @staticmethod
    def _binomial_row(size):
        """Return row ``size - 1`` of Pascal's triangle as a float array."""
        row = np.array([1.])
        for _ in range(size - 1):
            row = np.convolve(row, [1., 1.])
        return row

    @staticmethod
    def _downsampled(size, stride):
        """Return ceil(size / stride), propagating an unknown (None) size."""
        return None if size is None else -(-size // stride)

    def build(self, input_shape):
        # The kernel is sized from the last axis, i.e. channels-last input is assumed.
        channels = tf.TensorShape(input_shape).as_list()[-1]
        if channels is None:
            raise ValueError("BlurPool requires a known channel (last) dimension")
        kernel_2d = np.outer(self.a, self.a)
        kernel_2d = kernel_2d / np.sum(kernel_2d)
        # depthwise_conv2d filter layout: (height, width, in_channels, multiplier=1).
        kernel = np.tile(kernel_2d[:, :, None, None], (1, 1, int(channels), 1))
        self.kernel = tf.keras.backend.constant(kernel, dtype=tf.keras.backend.floatx())
        super(BlurPool, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        batch, height, width, channels = tf.TensorShape(input_shape).as_list()
        return (
            batch,
            self._downsampled(height, self.strides[0]),
            self._downsampled(width, self.strides[1]),
            channels,
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(BlurPool, self).get_config()
        config.update({"filt_size": self.filt_size, "stride": self.strides[0]})
        return config
class ReluBlurPool(BlurPool):
    """ReLU, then zero-pad and blur-and-subsample with the ``BlurPool`` kernel.

    Presumably expects channels-last (NHWC) input, since ``BlurPool`` sizes its
    kernel from the last input dimension.
    """

    def call(self, x):
        x = tf.nn.relu(x)
        x = tf.keras.backend.spatial_2d_padding(x, padding=self.padding)
        # tf.nn.depthwise_conv2d needs a 4-element NHWC stride and an upper-case
        # padding name; self.strides only holds the (height, width) pair.
        return tf.nn.depthwise_conv2d(
            x,
            self.kernel,
            strides=(1, self.strides[0], self.strides[1], 1),
            padding="VALID",
        )
class MaxBlurPool(BlurPool):
    """Dense 2x2 max filter (stride 1), then blur-and-subsample with the ``BlurPool`` kernel.

    Presumably expects channels-last (NHWC) input, since ``BlurPool`` sizes its
    kernel from the last input dimension.
    """

    def call(self, x):
        # `padding` is a required argument of tf.nn.max_pool. 'SAME' keeps the
        # spatial size, so downsampling is done entirely by the blur below.
        x = tf.nn.max_pool(x, ksize=(1, 2, 2, 1), strides=(1, 1, 1, 1), padding="SAME")
        x = tf.keras.backend.spatial_2d_padding(x, padding=self.padding)
        # tf.nn.depthwise_conv2d needs a 4-element NHWC stride and an upper-case
        # padding name; self.strides only holds the (height, width) pair.
        return tf.nn.depthwise_conv2d(
            x,
            self.kernel,
            strides=(1, self.strides[0], self.strides[1], 1),
            padding="VALID",
        )
class AvgBlurPool(BlurPool):
    """Zero-pad and blur-and-subsample with the ``BlurPool`` kernel (no nonlinearity).

    Presumably expects channels-last (NHWC) input, since ``BlurPool`` sizes its
    kernel from the last input dimension.
    """

    def call(self, x):
        x = tf.keras.backend.spatial_2d_padding(x, padding=self.padding)
        # tf.nn.depthwise_conv2d needs a 4-element NHWC stride and an upper-case
        # padding name; self.strides only holds the (height, width) pair.
        return tf.nn.depthwise_conv2d(
            x,
            self.kernel,
            strides=(1, self.strides[0], self.strides[1], 1),
            padding="VALID",
        )
class SumLayer(tf.keras.layers.Layer):
    """Element-wise sum of a sequence of tensors captured at construction time.

    NOTE(review): ``call`` does not use the tensor it is invoked with; it always
    returns the sum of the tensors given to ``__init__``.

    Args:
        inputs: non-empty sequence of tensors to add together.
    """

    def __init__(self, inputs):
        super(SumLayer, self).__init__()
        self.inputs = inputs

    def call(self, inputs):
        captured = self.inputs
        total = captured[0]
        # Accumulate the remaining captured tensors onto the first one.
        for term in captured[1:]:
            total += term
        return total
class AntialiasingModel:
    """Rebuild a Keras model as a ``tf.keras.Sequential`` with blurred downsampling.

    Calling an instance walks ``model.layers`` in order and builds a new
    ``tf.keras.Sequential``:

    * a ``Conv2D`` / ``DepthwiseConv2D`` / ``SeparableConv2D`` with stride > 1
      and a kernel larger than 1x1 is cloned from its config with stride 1 and
      its weights copied. The removed stride is re-applied by an ``AvgBlurPool``
      placed after the next ReLU (``ReLU`` layer or ``Activation('relu')``). If
      another layer (other than batch norm) comes first, the blur is placed
      before that layer, or at the very end of the model;
    * ``MaxPooling2D`` / ``AveragePooling2D`` become ``MaxBlurPool`` /
      ``AvgBlurPool`` with the pooling layer's stride;
    * ``Add`` layers become a ``SumLayer`` over the original layer's inputs;
    * every other layer is reused as-is (shared with ``model``).

    NOTE(review): a Sequential cannot express branching graphs. Models with
    residual ``Add`` connections (e.g. the MobileNetV2 demo below) are therefore
    only approximated: ``SumLayer`` sums the *original* model's symbolic tensors.
    Verify the result before relying on it.

    Args:
        model: source Keras model; its layers are read, and reused where not replaced.
        filt_size: binomial blur size passed to the blur layers.
    """

    # Exact lower-cased class names of conv layers whose stride is moved into a
    # blur. Exact matching keeps Conv2DTranspose (whose stride upsamples) out.
    _STRIDED_CONV_TYPES = ("conv2d", "depthwiseconv2d", "separableconv2d")

    def __init__(self, model, filt_size=3):
        self.model = model
        self.filt_size = filt_size

    @classmethod
    def _unstrided_clone(cls, layer):
        """Return ``(stride-1 clone, original stride)`` for a replaceable conv, else None."""
        if type(layer).__name__.lower() not in cls._STRIDED_CONV_TYPES:
            return None
        weights = layer.get_weights()
        stride = layer.strides[0]
        if not weights or stride <= 1 or tuple(weights[0].shape[:2]) == (1, 1):
            return None
        # Cloning the config keeps padding, use_bias, activation, depth_multiplier, ...
        config = layer.get_config()
        config.pop("name", None)  # let Keras assign a fresh layer name
        config["strides"] = (1, 1)
        config["weights"] = weights  # loaded by Keras once the clone is built
        return type(layer).from_config(config), stride

    @staticmethod
    def _is_relu(layer):
        """Return True for a ``ReLU`` layer or an ``Activation('relu')`` layer."""
        name_type = type(layer).__name__.lower()
        if name_type == "relu":
            return True
        return name_type == "activation" and layer.get_config().get("activation") == "relu"

    def __call__(self):
        """Build and return the anti-aliased ``tf.keras.Sequential`` model."""
        new_model = tf.keras.Sequential()
        # Stride removed from the latest rewritten conv whose blur has not been
        # inserted yet; 1 means nothing is pending.
        pending_stride = 1
        for layer in self.model.layers:
            name = layer.name
            name_type = type(layer).__name__.lower()
            is_relu = self._is_relu(layer)

            if pending_stride > 1 and not is_relu and name_type != "batchnormalization":
                # No ReLU directly followed the rewritten conv: blur here so that
                # `layer` sees the same spatial size as in the original model.
                print(f"Adding AvgBlurPool (stride {pending_stride}) before {name}")
                new_model.add(AvgBlurPool(filt_size=self.filt_size, stride=pending_stride))
                pending_stride = 1

            replacement = self._unstrided_clone(layer)
            if replacement is not None:
                new_layer, pending_stride = replacement
                print(f"Replacing {name}, stride {pending_stride}, with new Conv2D layer (with copied weights, stride 1)")
                new_model.add(new_layer)
            elif is_relu and pending_stride > 1:
                # Keep the original activation layer unchanged; the blur performs
                # the downsampling removed from the preceding conv.
                print(f"Adding AvgBlurPool (stride {pending_stride}) after {name}")
                new_model.add(layer)
                new_model.add(AvgBlurPool(filt_size=self.filt_size, stride=pending_stride))
                pending_stride = 1
            elif name_type == "maxpooling2d":
                print(f"Replacing {name} with MaxBlurPool")
                new_model.add(MaxBlurPool(filt_size=self.filt_size, stride=layer.strides[0]))
            elif name_type == "averagepooling2d":
                print(f"Replacing {name} with AvgBlurPool")
                new_model.add(AvgBlurPool(filt_size=self.filt_size, stride=layer.strides[0]))
            elif name_type == "add":
                print("Creating new add layer...")
                predecessor_layers = layer._inbound_nodes[0].inbound_layers
                outputs = [predecessor.output for predecessor in predecessor_layers]
                new_model.add(SumLayer(inputs=outputs))
            else:
                new_model.add(layer)

        if pending_stride > 1:
            # The model ended right after a rewritten conv (and its batch norm).
            print(f"Adding AvgBlurPool (stride {pending_stride}) at the end")
            new_model.add(AvgBlurPool(filt_size=self.filt_size, stride=pending_stride))
        return new_model
def get_result(classifier, image, k=1, labels=None):
    """Print and return the top-``k`` predictions of ``classifier`` for one input.

    Args:
        classifier: object with a Keras-style ``predict`` method; it receives a
            batch of one and its first output row is treated as class scores.
        image: a single input without a batch axis; one is added here.
        k: number of highest-scoring classes to report; ``k <= 0`` reports none.
        labels: indexable class names. Defaults to the module-level
            ``imagenet_labels`` that the ``__main__`` demo defines.

    Returns:
        List of ``(label, score)`` pairs, highest score first. Each pair is also
        printed as ``label score``.

    Raises:
        ValueError: if ``labels`` is omitted and ``imagenet_labels`` is not defined.
    """
    if labels is None:
        labels = globals().get("imagenet_labels")
        if labels is None:
            raise ValueError("labels must be given when imagenet_labels is not defined")
    scores = classifier.predict(image[np.newaxis, ...])[0]
    top = []
    if k > 0:
        # Descending order; slicing with [:k] (not [-k:]) so k=0 cannot select everything.
        for index in np.argsort(scores)[::-1][:k]:
            print(labels[index], scores[index])
            top.append((labels[index], scores[index]))
    return top
if __name__ == "__main__":
    # Demo: compare stock MobileNetV2 with its anti-aliased rebuild on 100
    # randomly cropped/padded copies of one image, saving each augmented copy.
    import os

    import PIL.Image as Image
    import cv2
    import imgaug as ia
    import imgaug.augmenters as iaa

    IMAGE_SHAPE = (224, 224, 3)
    classifier1 = tf.keras.applications.MobileNetV2(input_shape=IMAGE_SHAPE,
                                                    include_top=True,
                                                    weights='imagenet')
    # A separate instance, because AntialiasingModel reuses the source model's layers.
    classifier2 = tf.keras.applications.MobileNetV2(input_shape=IMAGE_SHAPE,
                                                    include_top=True,
                                                    weights='imagenet')
    classifier2 = AntialiasingModel(classifier2, filt_size=5)()
    classifier2.summary()

    labels_path = tf.keras.utils.get_file('ImageNetLabels.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
    # Module-level name read by get_result(); the file is closed after reading.
    with open(labels_path, encoding="utf-8") as labels_file:
        imagenet_labels = np.array(labels_file.read().splitlines())

    image_path = tf.keras.utils.get_file('image2.jpg', 'https://upload.wikimedia.org/wikipedia/commons/6/66/An_up-close_picture_of_a_curious_male_domestic_shorthair_tabby_cat.jpg')
    # convert("RGB") guarantees 3 channels; the array is built once, outside the loop.
    image = np.array(Image.open(image_path).convert("RGB").resize(IMAGE_SHAPE[:2]))

    seq = iaa.Sequential([
        iaa.CropAndPad(
            px=(1, 16),
            pad_mode=ia.ALL,
            pad_cval=(0, 255))
    ])
    os.makedirs("images", exist_ok=True)
    for i in range(100):
        print(i)
        image_s = seq.augment_images([image])[0]
        image_n = tf.keras.applications.mobilenet_v2.preprocess_input(tf.cast(image_s, tf.float32))
        get_result(classifier1, image_n, k=1)
        get_result(classifier2, image_n, k=1)
        # image_s is RGB (from PIL) but cv2.imwrite expects BGR channel order.
        cv2.imwrite(os.path.join("images", f"{i}.jpg"), cv2.cvtColor(image_s, cv2.COLOR_RGB2BGR))
        print('-----')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment