Last active
February 18, 2020 08:10
-
-
Save kmader/0cc8d40216349dadc3e787adca2bbe77 to your computer and use it in GitHub Desktop.
Convert a preprocessing function into a convolutional layer (in Keras). It takes an input from -127 to 127 and runs it through the preprocess_input function and then returns a convolutional layer with the appropriate weights and biases to reproduce (as closely as possible) the preprocessing step. This allows models to be packaged without worryin…
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable

from keras import layers, models
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
import numpy as np
from sklearn.linear_model import LinearRegression
def prep_to_conv(
    in_prep_func: Callable[[np.ndarray], np.ndarray],
    *,
    min_val: float = -127,
    max_val: float = 127,
    channels: int = 3,
    verbose: bool = False,
) -> layers.Layer:
    """Approximate a preprocessing function with a frozen 1x1 convolution.

    A small ramp image spanning ``[min_val, max_val]`` is pushed through
    ``in_prep_func``; for every channel a straight line is then fitted
    between the input and output values.  The fitted slopes are placed on
    the diagonal of a 1x1 kernel and the intercepts become the bias, so the
    returned layer reproduces the preprocessing as closely as a per-channel
    linear map can.

    Note: every channel of the probe image holds the same ramp and only the
    kernel diagonal is filled in, so channel mixing or reordering performed
    by ``in_prep_func`` is not captured -- only a per-channel scale/offset.

    Args:
        in_prep_func: Preprocessing function applied to a float32 array of
            shape ``(1, 3, 3, channels)``; its result is indexed the same
            way (channels on the last axis).
        min_val: Lowest input value used for the probe ramp.
        max_val: Highest input value used for the probe ramp.
        channels: Number of channels of the probe image and of the layer.
        verbose: If true, plot the given vs. reproduced outputs per channel
            (imports matplotlib lazily).

    Returns:
        A non-trainable ``Conv2D`` layer with a 1x1 kernel, linear
        activation and the fitted weights and biases.

    Raises:
        ValueError: If ``channels`` is smaller than 1.
    """
    if channels < 1:
        raise ValueError(f"channels must be >= 1, got {channels}")

    # Nine evenly spaced samples laid out as a single 3x3 image plane.
    test_channel = (
        np.linspace(min_val, max_val, 9).reshape((1, 3, 3)).astype('float32')
    )
    # Replicate the ramp once per requested channel (was hard-coded to 3).
    test_input = np.stack([test_channel] * channels, axis=-1)
    # Hand over a copy: some preprocessing functions modify their input.
    test_output = in_prep_func(test_input.copy())

    W = np.zeros((1, 1, channels, channels), dtype='float32')
    b = np.zeros((channels,), dtype='float32')
    for i in range(channels):
        x = test_input[:, :, :, i].reshape((-1, 1))
        # A 1-D target makes coef_[0] and intercept_ genuine scalars rather
        # than length-1 arrays (array -> scalar assignment is deprecated).
        y = test_output[:, :, :, i].ravel()
        lin_reg = LinearRegression()
        lin_reg.fit(x, y)
        W[0, 0, i, i] = lin_reg.coef_[0]
        b[i] = lin_reg.intercept_

    conv1 = layers.Conv2D(
        channels,
        kernel_size=(1, 1),
        activation='linear',
        use_bias=True,
        weights=[W, b],
        input_shape=(None, None, channels),
    )
    conv1.trainable = False

    if verbose:
        import matplotlib.pyplot as plt

        # check the inputs and outputs
        s = models.Sequential()
        s.add(conv1)
        pred_output = s.predict(test_input)
        # squeeze=False keeps a 2-D axes array even when channels == 1.
        fig, m_axs = plt.subplots(
            1, channels, figsize=(4 * channels, 4), squeeze=False
        )
        for i, c_ax in enumerate(m_axs[0]):
            x = test_input[:, :, :, i].ravel()
            y = test_output[:, :, :, i].ravel()
            z = pred_output[:, :, :, i].ravel()
            c_ax.plot(x, y, 's', label='Given')
            c_ax.plot(x, z, '-', label='Predicted')
            c_ax.legend()

    return conv1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment