@bzamecnik
Last active June 28, 2018 08:15
Reshaping input data for convolution in Keras
# In Keras the convolution layers require an additional dimension for the channels (the axis that holds the outputs of the various filters).
# When we have e.g. a 2D dataset, the shape is (data_points, rows, cols).
# But Convolution2D requires the shape (data_points, rows, cols, 1).
# Otherwise it fails with e.g. "Exception: Input 0 is incompatible with layer convolution2d_5: expected ndim=4, found ndim=3"
#
# Originally I reshaped the data beforehand but it only complicates things.
#
# An easier and more elegant solution is to add a Reshape layer at the input
# of the network!
#
# Docs: https://keras.io/layers/core/#reshape
from keras.models import Sequential, Model
from keras.layers import Input
from keras.layers.core import Activation, Reshape
from keras.layers.convolutional import Convolution2D
# e.g. 100x100 px images
input_shape = (100, 100)


def create_model_sequential(input_shape):
    """For the classic sequential API..."""
    model = Sequential()
    # add one more dimension for convolution
    model.add(Reshape(input_shape + (1,), input_shape=input_shape))
    model.add(Convolution2D(32, 3, 3))
    model.add(Activation('relu'))
    # ...
    return model


def create_model_functional(input_shape):
    """For the new functional API..."""
    inputs = Input(input_shape)
    # add one more dimension for convolution
    x = Reshape(input_shape + (1,), input_shape=input_shape)(inputs)
    x = Convolution2D(32, 3, 3)(x)
    x = Activation('relu')(x)
    # ...
    return Model(inputs, x)
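
For comparison, here is a minimal NumPy sketch of the "reshape the data beforehand" approach mentioned in the header comments. It shows roughly what the Reshape layer does to each batch. The helper name `add_channel_axis` and the batch of 8 blank images are made up for illustration and are not part of the original gist.

```python
import numpy as np


def add_channel_axis(data):
    """Append a trailing axis of size 1 (hypothetical helper).

    (data_points, rows, cols) -> (data_points, rows, cols, 1)
    """
    return data.reshape(data.shape + (1,))


# Illustrative data only: 8 blank 100x100 px grayscale images.
batch = np.zeros((8, 100, 100), dtype=np.float32)
batch_4d = add_channel_axis(batch)

print(batch.shape)     # (8, 100, 100)
print(batch_4d.shape)  # (8, 100, 100, 1)
```

`np.expand_dims(batch, axis=-1)` gives the same shape. Doing this inside the model with a Reshape layer simply spares every caller from having to remember the step.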
@parth126

@bzamecnik: I tried doing this with the functional API, but the Reshape layer throws an error:
ValueError: total size of new array must be unchanged

It is the same if I use x = Reshape((1,x,y), input_shape=(x,y))(inputs)

Both these variants work well with the sequential model. Any idea what the problem might be?
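
The error message refers to a general rule for reshaping: the target shape must hold exactly as many elements per sample as the input shape. The plain-Python sketch below illustrates that rule only. It is not Keras's actual implementation, and the names `num_elements` and `check_reshape` are hypothetical. One possible cause of the error, which is an assumption because the failing code is not shown, is that the shape passed to Input differs from the shape the Reshape layer was written for.

```python
from functools import reduce
from operator import mul


def num_elements(shape):
    """Number of elements in one sample of the given shape."""
    return reduce(mul, shape, 1)


def check_reshape(input_shape, target_shape):
    """Return target_shape if the element counts match, else raise ValueError."""
    n_in = num_elements(input_shape)
    n_out = num_elements(target_shape)
    if n_in != n_out:
        raise ValueError(
            f"cannot reshape {input_shape} ({n_in} elements) "
            f"into {target_shape} ({n_out} elements)"
        )
    return target_shape


# Adding an axis of size 1 never changes the element count.
print(check_reshape((100, 100), (100, 100, 1)))  # (100, 100, 1)
print(check_reshape((100, 100), (1, 100, 100)))  # (1, 100, 100)

# A mismatch between the Input shape and the Reshape target fails.
try:
    check_reshape((100, 100, 3), (100, 100, 1))
except ValueError as err:
    print("ValueError:", err)
```

Printing the shape passed to Input and the target shape given to Reshape side by side is a quick way to check whether this kind of mismatch is involved.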