Last active
May 25, 2020 02:54
-
-
Save nwatab/923c43d521223d74c8b5e055bc34309f to your computer and use it in GitHub Desktop.
U-Net implementation of the MATLAB sample for multispectral semantic segmentation: https://jp.mathworks.com/help/images/multispectral-semantic-segmentation-using-deep-learning.html?lang=en. The outputs were produced with different hyperparameters.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import imageio | |
import numpy as np | |
import tensorflow as tf | |
from keras.callbacks import ModelCheckpoint, Callback | |
from keras import optimizers | |
import keras.backend as K | |
import matplotlib.pyplot as plt | |
from models import Pix2Pix, SegNet, vgg19_unet, UNetMatlab | |
# Let NumPy print arrays of up to 64**4 elements without summarising them,
# on lines up to 300 characters wide.
np.set_printoptions(linewidth=300, threshold=64 ** 4)
def random_crop(image, top, left, crop_size):
    """Cut a rectangular window out of ``image``.

    Nothing random happens in here: the caller chooses the corner and this
    function only performs the slicing.

    Args:
        image: Array indexed as ``[row, column, channel]``.
        top: Index of the first row to keep.
        left: Index of the first column to keep.
        crop_size: Sequence whose first two items are the window height and
            width; any further items are ignored.

    Returns:
        ``image[top:top + height, left:left + width, :]``.
    """
    height, width = crop_size[0], crop_size[1]
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    return image[rows, cols, :]
def get_datagen(img_path, seg_path, img_size=(256, 256), batch_size=16, train=True, sample_weights=None):
    """Endlessly yield batches of random crops from one image / mask pair.

    Both files are read as RGB, restricted to the fixed region
    ``[400:2000, 270:2200]`` and scaled to ``[0, 1]``.  The mask is turned
    into a two-channel target: channel 0 is the first channel of the mask
    file (defect) and channel 1 is its complement (background).

    Args:
        img_path: Path of the input image.
        seg_path: Path of the segmentation mask image.
        img_size: Sequence whose first two items are the crop height and
            width; further items (e.g. a channel count) are ignored.
        batch_size: Number of crops per yielded batch (must be >= 1).
        train: When true, crops are randomly flipped and receive a small
            amount of Gaussian noise.
        sample_weights: Optional object yielded unchanged as the third item
            of every batch.

    Yields:
        ``(images, targets)`` or, when ``sample_weights`` is given,
        ``(images, targets, sample_weights)``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or the crop does not
            fit inside the region of interest.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1, got {}'.format(batch_size))

    img = imageio.imread(img_path, pilmode='RGB')
    seg = imageio.imread(seg_path, pilmode='RGB')

    # `np.float` was only an alias of the builtin float and has been removed
    # from NumPy; float64 is the dtype it used to produce.
    img = img[400:2000, 270:2200, :].astype(np.float64) / 255.
    seg = seg[400:2000, 270:2200, :2].astype(np.float64)
    seg[:, :, 1] = 255. - seg[:, :, 0]  # Index 0: Defect, Index 1: Background
    seg /= 255.

    h, w, _ = img.shape
    max_top = h - img_size[0]
    max_left = w - img_size[1]
    if max_top < 0 or max_left < 0:
        raise ValueError('crop size {} does not fit into region of size {}'.format(
            tuple(img_size[:2]), (h, w)))

    while True:
        imgs = []
        segs = []
        for _ in range(batch_size):
            # Crop (the upper bound is exclusive, hence +1 so that the last
            # valid offset can be drawn and an exact-fit crop is allowed).
            top = np.random.randint(0, max_top + 1)
            left = np.random.randint(0, max_left + 1)
            cropped_img = random_crop(img, top, left, img_size)
            cropped_seg = random_crop(seg, top, left, img_size)
            # Horizontal flip.  The random number is drawn even when
            # `train` is false, exactly as before.
            if np.random.rand() > 0.5 and train:
                cropped_img = cropped_img[:, ::-1, :]
                cropped_seg = cropped_seg[:, ::-1, :]
            # Vertical flip.
            if np.random.rand() > 0.5 and train:
                cropped_img = cropped_img[::-1, :, :]
                cropped_seg = cropped_seg[::-1, :, :]
            # Noise.  The crop is a view into `img`, so the noise must go
            # into a new array; an in-place `+=` would permanently alter the
            # source image and accumulate noise from batch to batch.
            if train:
                noise = 0.001 * np.random.randn(*cropped_img.shape)
                cropped_img = cropped_img + noise
            imgs.append(cropped_img)
            segs.append(cropped_seg)

        batch_imgs = np.array(imgs)
        batch_segs = np.array(segs)
        # Yield exactly one tuple per batch: with weights or without.
        if sample_weights is not None:
            yield (batch_imgs, batch_segs, sample_weights)
        else:
            yield (batch_imgs, batch_segs)
def decode_img(x):
    """Convert an array of values in ``[0, 1]`` into 8-bit pixel values.

    The input is left untouched; the result is a new array.  Values that
    fall outside ``[0, 1]`` are clamped to 0 / 255 instead of wrapping
    around in the unsigned 8-bit conversion.

    Args:
        x: Array-like of numbers nominally in ``[0, 1]``.

    Returns:
        A new ``uint8`` array with the same shape as ``x``.
    """
    # Multiply out of place: an in-place `x *= 255` would rescale the
    # caller's array as a side effect.
    scaled = np.asarray(x) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)
def decode_onehot(y):
    """Turn a batch of per-class maps into displayable 8-bit images.

    The maps are converted with ``decode_img``, one all-zero channel is
    appended at the end and channel 1 is blanked, so only channel 0 of the
    input remains visible in the result.

    Args:
        y: 4-D array whose last axis holds the class channels.

    Returns:
        ``uint8`` array with one more channel than ``y``.
    """
    pixels = decode_img(y)
    blank = np.zeros(pixels.shape[:-1] + (1,), dtype=np.uint8)
    rgb = np.concatenate([pixels, blank], axis=-1)
    # Hide the second class so that it does not show up in the picture.
    rgb[:, :, :, 1] = 0
    return rgb
def convert_prob_into_onehot(x):
    """Replace per-class scores by a two-class one-hot encoding of the arg-max.

    Computed with NumPy so that no TensorFlow session has to be active.

    Args:
        x: Array-like whose last axis holds the class scores.

    Returns:
        ``float32`` array of shape ``x.shape[:-1] + (2,)``.  A position
        whose arg-max index is 2 or larger yields an all-zero row, which is
        what a one-hot encoding of depth 2 produces for such indices.
    """
    winners = np.argmax(np.asarray(x), axis=-1)
    return (winners[..., np.newaxis] == np.arange(2)).astype(np.float32)
def weighted_crossentropy_wrapper(class_weights):
    """Build a class-weighted cross-entropy loss function.

    Args:
        class_weights: Sequence or array with one weight per class, indexed
            by class label.  It must broadcast against the last axis of the
            label tensor.

    Returns:
        A function ``weighted_cross_entropy(onehot_labels, output)`` that
        can be passed as ``loss`` to ``model.compile``.
    """
    def weighted_cross_entropy(onehot_labels, output):
        '''
        Compute the weighted cross entropy between labels and probabilities.
        ------------------
        Technical Details
        ------------------
        class_weights is multiplied with onehot_labels directly: it
        broadcasts across the last (class) dimension, so every label entry
        is scaled by the weight of its class.  The product is multiplied
        with the log of the predicted probabilities and averaged over all
        elements of the tensor.
        A small constant is added inside the log so that a predicted
        probability of exactly 0 does not produce -inf.
        ------------------
        INPUTS:
        - onehot_labels(Tensor): the one-hot encoded labels; the last dimension is the class dimension.
        - output(Tensor): the probabilities produced by the model (not logits), same shape as onehot_labels.
        OUTPUTS:
        - loss(Tensor): a scalar Tensor that is the weighted cross entropy loss output.
        '''
        # Use the weights handed to the wrapper rather than whatever global
        # named `weights` happens to exist, and match the dtype of `output`.
        weights = tf.cast(tf.convert_to_tensor(class_weights), output.dtype)
        loss = -tf.reduce_mean(onehot_labels * weights * tf.log(output + 1e-9))
        return loss
    return weighted_cross_entropy
class ImageWriter(Callback):
    """Callback that redraws a prediction history sheet after every epoch.

    One fixed batch is drawn from ``get_datagen('img91.png', 'seg91.png',
    train=False, ...)`` at construction time.  After each epoch the model
    predicts on that batch and ``history.jpg`` is rewritten with one row per
    sample and the columns: input, ground truth, then the prediction of
    every epoch so far.

    Args:
        img_shape: Crop size forwarded to ``get_datagen`` as ``img_size``.
        batch_size: Number of samples (rows) to visualise.
    """

    def __init__(self, img_shape, batch_size):
        super().__init__()
        self.batch_size = batch_size
        test_gen = get_datagen('img91.png', 'seg91.png', train=False, batch_size=batch_size, img_size=img_shape)
        self.x, self.y = next(test_gen)
        self.y_shape = (self.x.shape[1], self.x.shape[2], 2)
        # Decode copies: `self.x` is fed to the model after every epoch and
        # must keep its original [0, 1] scaling.
        self.img = decode_img(self.x.copy())
        self.gth = decode_onehot(self.y.copy())
        self.preds = []

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the fixed batch and rewrite ``history.jpg``.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dictionary supplied by Keras (unused).
        """
        self.p = self.model.predict_on_batch(self.x)
        # Decode a copy so that the raw probabilities in `self.p` survive.
        self.pre = decode_onehot(np.array(self.p))
        self.preds.append(self.pre)

        figsize = (
            (self.x.shape[2] * (len(self.preds) + 1)) / 100,
            (self.x.shape[1] * (self.batch_size + 1)) / 100
        )
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(self.batch_size, 2 + len(self.preds), figsize=figsize, squeeze=False)
        try:
            # Set title
            axes[0, 0].set_title('X')
            axes[0, 1].set_title('GT')
            for j in range(len(self.preds)):
                axes[0, j + 2].set_title(str(j))
            # Set images
            for i in range(self.batch_size):
                panels = [self.img[i], self.gth[i]] + [pred[i] for pred in self.preds]
                for ax, panel in zip(axes[i], panels):
                    ax.imshow(panel, vmin=0, vmax=255)
                    ax.axis('off')
            fig.savefig('history.jpg')
        finally:
            # A new figure is created every epoch; close it so that figures
            # do not pile up in memory over the course of training.
            plt.close(fig)
if __name__ == '__main__':
    # ---- Settings ---------------------------------------------------------
    img_shape = (256, 256, 3)
    batch_size = 16
    epochs = 50
    steps_per_epoch = 128
    validation_steps = 4
    weight_decay_l2 = 0.01  # only referenced by the commented-out vgg19_unet option

    # ---- Data -------------------------------------------------------------
    # Training and validation crops come from the same image / mask pair;
    # only the augmentation switch differs.
    data_files = ('img91.png', 'seg91.png')
    train_gen = get_datagen(*data_files, img_size=img_shape, batch_size=batch_size,
                            sample_weights=None, train=True)
    test_gen = get_datagen(*data_files, img_size=img_shape, batch_size=batch_size,
                           sample_weights=None, train=False)

    # ---- Class weights ----------------------------------------------------
    # Estimated from 1024 unaugmented crops: per class, the number of
    # non-zero pixels divided by the number of crops containing that class,
    # inverted and normalised so that the weights sum to one.
    weight_sampler = get_datagen(*data_files, img_size=img_shape, batch_size=1024,
                                 sample_weights=None, train=False)
    _, y = next(weight_sampler)
    pixcount = np.count_nonzero(y, axis=(0, 1, 2))
    imgcount = np.count_nonzero(np.count_nonzero(y, axis=(1, 2)), axis=0)
    freq = pixcount / imgcount
    weights = np.reciprocal(freq)
    weights = weights / weights.sum()
    print('weights =', weights)

    # ---- Model ------------------------------------------------------------
    # Alternatives that were tried:
    #   vgg19_unet(input_shape=img_shape, classes=2, weight_decay=weight_decay_l2)
    #   SegNet(input_shape=img_shape, classes=2)
    #   UNetMatlab(input_shape=img_shape, classes=2).build()
    model = Pix2Pix(input_shape=img_shape, classes=2).build()

    # Alternative optimizer: optimizers.SGD(lr=5e-2, momentum=0.9, clipnorm=0.05)
    optimizer = optimizers.Adam(lr=1e-4, clipnorm=0.05)
    model.compile(
        optimizer=optimizer,
        loss=weighted_crossentropy_wrapper(weights),
        metrics=['accuracy']
    )
    model.summary()

    # ---- Training ---------------------------------------------------------
    mc_cb = ModelCheckpoint('model.h5', monitor='val_loss')
    im_cb = ImageWriter(img_shape, 32)
    fit_options = dict(
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[mc_cb, im_cb],
        validation_data=test_gen,
        validation_steps=validation_steps,
        shuffle=True,
        use_multiprocessing=True,
    )
    history = model.fit_generator(generator=train_gen, **fit_options)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import os | |
import skimage.io as io | |
import skimage.transform as trans | |
import numpy as np | |
from keras.engine import InputSpec | |
from keras import initializers, regularizers | |
from keras.layers import Input, Concatenate, BatchNormalization, Activation, MaxPooling2D, Dropout, Conv2DTranspose | |
from keras.layers.advanced_activations import LeakyReLU, ReLU | |
from keras.layers.convolutional import UpSampling2D, Conv2D | |
from keras.models import Model | |
import keras.backend as K | |
import tensorflow as tf | |
class UNetMatlab:
    """U-Net following the MathWorks multispectral segmentation example.

    https://jp.mathworks.com/help/images/multispectral-semantic-segmentation-using-deep-learning.html?lang=en

    Four encoder stages (64, 128, 256, 512 filters), a 1024-filter bridge
    and four decoder stages with skip connections.  Every 3x3 convolution
    and every transposed convolution uses he_normal initialisation, a ReLU
    activation and an L2 kernel penalty.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
        classes: Number of output channels of the final softmax.
        l2reg: Factor of the L2 kernel regulariser.
    """

    def __init__(self, input_shape, classes, l2reg=0.0001):
        self.input_shape = input_shape
        self.classes = classes
        self.l2reg = l2reg

    def _conv3x3(self, filters):
        """Return a 3x3 'same' convolution layer with ReLU and L2 penalty."""
        return Conv2D(filters, kernel_size=3, padding='same', kernel_initializer='he_normal',
                      activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))

    def _upconv(self, filters):
        """Return a 2x2, stride-2 transposed convolution layer with ReLU and L2 penalty."""
        return Conv2DTranspose(filters, 2, strides=2, padding='valid', kernel_initializer='he_normal',
                               activation='relu', kernel_regularizer=regularizers.l2(self.l2reg))

    def build(self):
        """Assemble the network and return it as an uncompiled ``Model``."""
        x = Input(shape=self.input_shape)
        h = x

        # Encoder: two convolutions per stage, remember the result for the
        # skip connection, then halve the resolution.  Only the deepest
        # stage applies dropout, after the skip tensor has been taken.
        skips = []
        for filters, with_dropout in ((64, False), (128, False), (256, False), (512, True)):
            h = self._conv3x3(filters)(h)
            h = self._conv3x3(filters)(h)
            skips.append(h)
            if with_dropout:
                h = Dropout(0.5)(h)
            h = MaxPooling2D(2)(h)

        # Bridge.
        h = self._conv3x3(1024)(h)
        h = self._conv3x3(1024)(h)
        h = Dropout(0.5)(h)

        # Decoder: upsample, concatenate the matching encoder output, then
        # two convolutions.  These convolutions carry the same ReLU and L2
        # penalty as the encoder ones; without an activation two stacked
        # convolutions would be purely linear.
        for filters in (512, 256, 128, 64):
            h = self._upconv(filters)(h)
            h = Concatenate(axis=-1)([h, skips.pop()])
            h = self._conv3x3(filters)(h)
            h = self._conv3x3(filters)(h)

        logit = Conv2D(self.classes, kernel_size=1, padding='valid', kernel_initializer='he_normal')(h)
        prob = Activation('softmax')(logit)
        model = Model(x, prob)
        return model
# Reference: https://github.com/eriklindernoren/Keras-GAN/blob/master/pix2pix/pix2pix.py
class Pix2Pix:
    """Encoder/decoder segmentation network built from ConvSN2D layers.

    Seven downsampling stages feed a two-convolution middle block and seven
    upsampling stages with skip connections; a 1x1 ConvSN2D followed by a
    softmax produces ``classes`` output channels.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
        classes: Number of output channels.
    """

    def __init__(self, input_shape, classes):
        self.classes = classes
        self.input_shape = input_shape

    def build(self):
        """Assemble the network and return it as an uncompiled ``Model``."""

        def sn_stage(tensor, filters, **conv_kwargs):
            """ConvSN2D (3x3, stride 1, 'same') -> BatchNormalization -> LeakyReLU."""
            tensor = ConvSN2D(filters, kernel_size=3, strides=1, padding='same', **conv_kwargs)(tensor)
            tensor = BatchNormalization(momentum=0.9)(tensor)
            return LeakyReLU(alpha=0.2)(tensor)

        def double_stage(tensor, filters, **conv_kwargs):
            """Two consecutive ``sn_stage`` units with identical settings."""
            tensor = sn_stage(tensor, filters, **conv_kwargs)
            return sn_stage(tensor, filters, **conv_kwargs)

        inputs = Input(shape=self.input_shape)

        # Downsampling path: dilated convolutions, keep the pre-pooling
        # tensor of every stage for the skip connections.
        skips = []
        h = inputs
        for filters in (64, 128, 256, 512, 512, 512, 1024):
            full_res = double_stage(h, filters, dilation_rate=2)
            skips.append(full_res)
            h = MaxPooling2D(2)(full_res)

        # Middle block (plain convolutions, no skip connection).
        h = double_stage(h, 1024)

        # Upsampling path: deepest skip tensor first.
        for filters in (512, 512, 512, 256, 128, 64, 64):
            h = UpSampling2D(size=2)(h)
            h = Concatenate(axis=-1)([h, skips.pop()])
            h = double_stage(h, filters)

        logit = ConvSN2D(self.classes, kernel_size=1)(h)
        prob = Activation('softmax')(logit)
        return Model(inputs, prob)
def vgg19_unet(input_shape, weight_decay=0., classes=2):
    """Build a U-Net whose encoder mirrors VGG19 and starts from ImageNet weights.

    The encoder layers are named ``block1_conv1`` ... ``block5_conv4`` like
    the layers of ``keras.applications.vgg19.VGG19`` so that the pretrained
    weights can be loaded by name.  The decoder layers (``block6_*`` to
    ``block9_*``) and the final ``prob`` layer keep their fresh
    initialisation.

    Args:
        input_shape: Shape of one input sample; also handed to ``VGG19``.
        weight_decay: L2 factor applied to the encoder convolution kernels.
        classes: Number of output channels of the final softmax.

    Returns:
        The uncompiled ``Model``.
    """
    import os
    import tempfile

    from keras.applications.vgg19 import VGG19

    def conv(filters, name, regularized):
        """Return a named 3x3 'same' ReLU convolution layer."""
        if regularized:
            return Conv2D(filters, (3, 3), activation='relu', padding='same', name=name,
                          kernel_regularizer=regularizers.l2(weight_decay))
        return Conv2D(filters, (3, 3), activation='relu', padding='same', name=name)

    # (filters, number of convolutions) per block.
    encoder_spec = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))
    decoder_spec = ((512, 4), (256, 4), (128, 2), (64, 2))

    # Image Input
    img = Input(shape=input_shape, name='image')
    x = img

    # Blocks 1-5: convolutions, batch normalisation, then (except for the
    # last block) keep the tensor for the skip connection and pool.
    skips = []
    for block, (filters, n_convs) in enumerate(encoder_spec, start=1):
        for i in range(1, n_convs + 1):
            x = conv(filters, 'block{}_conv{}'.format(block, i), True)(x)
        x = BatchNormalization()(x)
        if block < len(encoder_spec):
            skips.append(x)
            x = MaxPooling2D((2, 2), strides=(2, 2), name='block{}_pool'.format(block))(x)

    # Blocks 6-9: upsample, merge with the matching encoder tensor,
    # convolutions without weight decay, batch normalisation.
    for block, (filters, n_convs) in enumerate(decoder_spec, start=6):
        x = UpSampling2D(2)(x)
        x = Concatenate(axis=-1)([x, skips.pop()])
        for i in range(1, n_convs + 1):
            x = conv(filters, 'block{}_conv{}'.format(block, i), False)(x)
        x = BatchNormalization()(x)

    output = Conv2D(classes, (1, 1), padding='same', activation='softmax', name="prob")(x)
    model = Model(inputs=img, outputs=output)

    # Transfer the pretrained encoder weights through a weights file.  A
    # unique temporary file is used and always removed, so nothing is left
    # behind if loading fails and parallel runs cannot clobber each other.
    fd, weights_path = tempfile.mkstemp(suffix='.h5')
    os.close(fd)
    try:
        VGG19(input_shape=input_shape, weights='imagenet', include_top=False).save_weights(weights_path)
        model.load_weights(weights_path, by_name=True)
    finally:
        os.remove(weights_path)
    return model
def SegNet(input_shape=(360, 480, 3), classes=12):
    """Build a small SegNet-style encoder/decoder without skip connections.

    The encoder applies conv-BN-ReLU units with 64, 128 and 256 filters,
    each followed by 2x2 max pooling, and one 512-filter unit.  The decoder
    mirrors it with 512, 256 and 128-filter units, each followed by 2x2
    upsampling, and one 64-filter unit.  A 1x1 convolution and a softmax
    produce ``classes`` output channels.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
        classes: Number of output channels.

    Returns:
        The uncompiled ``Model``.
    """
    ### @ https://github.com/alexgkendall/SegNet-Tutorial/blob/master/Example_Models/bayesian_segnet_camvid.prototxt

    def conv_bn_relu(tensor, filters):
        """3x3 'same' convolution -> batch normalisation -> ReLU."""
        tensor = Conv2D(filters, (3, 3), padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    img_input = Input(shape=input_shape)
    net = img_input

    # Encoder
    for filters in (64, 128, 256):
        net = conv_bn_relu(net, filters)
        net = MaxPooling2D(pool_size=(2, 2))(net)
    net = conv_bn_relu(net, 512)

    # Decoder
    for filters in (512, 256, 128):
        net = conv_bn_relu(net, filters)
        net = UpSampling2D(size=(2, 2))(net)
    net = conv_bn_relu(net, 64)

    net = Conv2D(classes, (1, 1), padding="valid")(net)
    net = Activation("softmax")(net)
    return Model(img_input, net)
# Reference: https://github.com/IShengFang/SpectralNormalizationKeras/blob/master/SpectralNormalizationKeras.py
class ConvSN2D(Conv2D):
    """Conv2D variant that divides its kernel by a spectral-norm estimate.

    Besides the usual kernel and bias the layer owns a non-trainable vector
    ``u`` (weight name ``'sn'``).  On every call one power-iteration step
    starting from ``u`` yields an estimate ``sigma``, and the convolution is
    performed with ``kernel / sigma``.
    """

    def build(self, input_shape):
        """Create kernel, optional bias and the power-iteration vector ``u``."""
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        input_dim = input_shape[channel_axis]
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')

        # Weights are created in the order kernel, bias, u.
        self.kernel = self.add_weight(shape=self.kernel_size + (input_dim, self.filters),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        # One row with as many entries as the kernel has output filters.
        self.u = self.add_weight(shape=(1, self.kernel.shape.as_list()[-1]),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)

        # Set input spec.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs, training=None):
        """Convolve ``inputs`` with the spectrally normalised kernel.

        Unless ``training`` is ``0`` or ``False``, the refreshed vector is
        assigned to ``self.u`` as a control dependency of the normalised
        kernel.
        """
        def _unit(vec, eps=1e-12):
            # Scale to unit L2 norm; eps guards against division by zero.
            return vec / (K.sum(vec ** 2) ** 0.5 + eps)

        kernel_shape = self.kernel.shape.as_list()
        # Flatten the kernel into a matrix with one column per output filter.
        flat_kernel = K.reshape(self.kernel, [-1, kernel_shape[-1]])

        # A single power-iteration step, starting from the stored vector.
        v_hat = _unit(K.dot(self.u, K.transpose(flat_kernel)))
        u_hat = _unit(K.dot(v_hat, flat_kernel))

        # sigma = v W u^T
        sigma = K.dot(v_hat, flat_kernel)
        sigma = K.dot(sigma, K.transpose(u_hat))

        normalized = flat_kernel / sigma
        if training in {0, False}:
            kernel_bar = K.reshape(normalized, kernel_shape)
        else:
            with tf.control_dependencies([self.u.assign(u_hat)]):
                kernel_bar = K.reshape(normalized, kernel_shape)

        outputs = K.conv2d(
            inputs,
            kernel_bar,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)
        if self.activation is None:
            return outputs
        return self.activation(outputs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment