@fabsta
Last active November 26, 2017 17:23
[Reduce Overfitting] #deeplearning

About data augmentation

Keras comes with very convenient features for automating data augmentation. You simply define what types and maximum amounts of augmentation you want, and Keras randomly alters every item of every batch according to these settings. Here's how to define a generator that includes data augmentation:

from keras.preprocessing import image

# dim_ordering='tf' uses TensorFlow dimension ordering (height, width, channels), which is the
# same order matplotlib uses for display. This is more convenient when we only want to
# display the augmented images.
gen = image.ImageDataGenerator(rotation_range=10, width_shift_range=0.1, 
       height_shift_range=0.1, shear_range=0.15, zoom_range=0.1, 
       channel_shift_range=10., horizontal_flip=True, dim_ordering='tf')
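To make the idea concrete, here is a minimal NumPy sketch of two of the transforms configured above: a random horizontal flip and a random horizontal shift. It is not Keras's implementation (Keras builds one combined affine transform per image); the function name `random_flip_and_shift` is hypothetical, and filling the uncovered pixels by repeating the edge column only approximates Keras's default `fill_mode='nearest'`.

```python
import numpy as np


def random_flip_and_shift(img, max_shift_frac=0.1, rng=None):
    """Randomly flip a (height, width, channels) image left-right and shift it horizontally.

    Hypothetical stand-in for horizontal_flip=True plus width_shift_range=max_shift_frac.
    Pixels uncovered by the shift are filled by repeating the nearest edge column.
    """
    rng = np.random.default_rng() if rng is None else rng
    out = img
    if rng.random() < 0.5:
        out = out[:, ::-1]  # horizontal flip: reverse the width axis
    width = out.shape[1]
    max_shift = int(round(max_shift_frac * width))
    shift = int(rng.integers(-max_shift, max_shift + 1))  # inclusive range of shifts
    if shift > 0:
        # Move content right; repeat the first column to fill the left edge
        pad = np.repeat(out[:, :1], shift, axis=1)
        out = np.concatenate([pad, out[:, :-shift]], axis=1)
    elif shift < 0:
        # Move content left; repeat the last column to fill the right edge
        pad = np.repeat(out[:, -1:], -shift, axis=1)
        out = np.concatenate([out[:, -shift:], pad], axis=1)
    return out
```

Each call draws new random numbers, so the same input image yields a different output almost every time, which is exactly what makes augmentation useful.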

Let's take a look at how this generator changes a single image (the details of this code don't matter much, but feel free to read the comments and the Keras docs if you're interested).

import numpy as np
from scipy import ndimage  # note: ndimage.imread was deprecated and removed in later SciPy releases

# Create a 'batch' of a single image
img = np.expand_dims(ndimage.imread('data/dogscats/test/7.jpg'), 0)

# Request the generator to create batches from this image
aug_iter = gen.flow(img)

# Get eight examples of these augmented images
aug_imgs = [next(aug_iter)[0].astype(np.uint8) for i in range(8)]
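`gen.flow()` returns an iterator that never runs out: each `next()` call produces a new batch with fresh random transforms, which is why the cell above can simply call `next(aug_iter)` eight times. The generator below is a hypothetical, simplified sketch of that behaviour (reshuffle on every pass, augment every item, yield forever). `flow_sketch`, `random_hflip` and the `augment` argument are illustrative names, not Keras API; the defaults `batch_size=32` and `shuffle=True` match those of Keras's own `flow()`.

```python
import itertools

import numpy as np


def flow_sketch(images, augment, batch_size=32, shuffle=True, rng=None):
    """Yield batches of augmented images forever (hypothetical sketch).

    images:  array of shape (n, height, width, channels)
    augment: any function (image, rng) -> image
    """
    rng = np.random.default_rng() if rng is None else rng
    n = len(images)
    while True:  # infinite generator: it never raises StopIteration
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            # Every image gets a new random transform each time it is drawn
            yield np.stack([augment(images[i], rng) for i in idx])


def random_hflip(img, rng):
    """Flip left-right with probability 0.5."""
    return img[:, ::-1] if rng.random() < 0.5 else img


# Mirror the cell above: a batch holding one image, eight augmented copies
single = np.zeros((1, 4, 4, 3), dtype=np.uint8)
aug_iter = flow_sketch(single, random_hflip)
aug_imgs = [batch[0] for batch in itertools.islice(aug_iter, 8)]
```

Because the iterator is infinite, `itertools.islice` (or a fixed number of `next()` calls) is needed to take a finite number of batches from it.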

import matplotlib.pyplot as plt

# The original
plt.imshow(img[0])

As you can see below, there's no magic to data augmentation: it's a very intuitive approach to generating richer input data. Generally speaking, your intuition should be a good guide to appropriate data augmentation, although it's a good idea to test your intuition by checking the results of different augmentation approaches.

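# Augmented data (plots() is a plotting helper defined outside this snippet, not part of Keras)
plots(aug_imgs, (20,7), 2)

# Ensure that we return to Theano dimension ordering
from keras import backend as K
K.set_image_dim_ordering('th')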
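The two orderings differ only in where the channel axis sits: 'tf' is (batch, height, width, channels) and 'th' is (batch, channels, height, width). Converting between them is just an axis permutation, as this NumPy sketch shows for a hypothetical batch holding one 224x224 RGB image. matplotlib's `imshow` expects a single image as (height, width, channels), which is why the 'tf' ordering was the convenient one for display above.

```python
import numpy as np

# Hypothetical batch of one 224x224 RGB image in 'tf' (channels-last) order
batch_tf = np.zeros((1, 224, 224, 3), dtype=np.uint8)

# 'tf' -> 'th': move the channel axis to position 1
batch_th = np.transpose(batch_tf, (0, 3, 1, 2))
print(batch_th.shape)  # (1, 3, 224, 224)

# 'th' -> 'tf': move it back, e.g. before displaying an image with plt.imshow
batch_back = np.transpose(batch_th, (0, 2, 3, 1))
print(batch_back.shape)  # (1, 224, 224, 3)
```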

Adding data augmentation

Let's try adding a small amount of data augmentation and see whether it reduces overfitting. The approach is identical to the method we used to fine-tune the dense layers in lesson 2, except that we use a generator with augmentation configured. Here's how we set up the generator and create batches from it:

gen = image.ImageDataGenerator(rotation_range=15, width_shift_range=0.1, 
                               height_shift_range=0.1, zoom_range=0.1, horizontal_flip=True)
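`rotation_range=15` means each image is rotated by a random angle of up to 15 degrees in either direction. A minimal NumPy sketch of that idea follows: for every output pixel it computes which source pixel it came from (inverse mapping), rounds to the nearest pixel, and clamps coordinates that fall outside the image to the border, roughly like the default `fill_mode='nearest'`. The function name `random_rotate` is hypothetical and Keras's own implementation differs in detail.

```python
import numpy as np


def random_rotate(img, max_degrees=15.0, rng=None):
    """Rotate a (height, width, channels) image by a random angle in [-max_degrees, max_degrees].

    Hypothetical sketch: inverse mapping + nearest-neighbour sampling, with source
    coordinates clamped to the image border.
    """
    rng = np.random.default_rng() if rng is None else rng
    theta = np.deg2rad(rng.uniform(-max_degrees, max_degrees))
    h, w = img.shape[:2]
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0  # rotate about the image centre
    ys, xs = np.mgrid[0:h, 0:w]
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    # For every output pixel, find the source pixel it came from (inverse rotation)
    src_x = cos_t * (xs - cx) + sin_t * (ys - cy) + cx
    src_y = -sin_t * (xs - cx) + cos_t * (ys - cy) + cy
    # Round to the nearest pixel and clamp to the border
    src_x = np.clip(np.rint(src_x), 0, w - 1).astype(int)
    src_y = np.clip(np.rint(src_y), 0, h - 1).astype(int)
    return img[src_y, src_x]
```

Inverse mapping is used so that every output pixel gets exactly one value; mapping source pixels forward would leave holes in the rotated image.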

# get_batches() is a helper defined outside this snippet (it is not part of Keras)
batches = get_batches(path+'train', gen, batch_size=batch_size)

# NB: We don't want to augment or shuffle the validation set
val_batches = get_batches(path+'valid', shuffle=False, batch_size=batch_size)

Found 23000 images belonging to 2 classes.
Found 2000 images belonging to 2 classes.

When using data augmentation, we can't pre-compute our convolutional layer features, since randomized changes are made to every input image. That is, even if the training process sees the same image multiple times, each time it will have undergone different data augmentation, so the outputs of the convolutional layers will differ. Therefore, to let data flow through all the conv layers and our new dense layers, we attach our fully connected model to the convolutional model, after ensuring that the convolutional layers are not trainable:
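The reasoning can be checked with a toy example. In the hypothetical NumPy sketch below, a fixed random linear map followed by ReLU stands in for the frozen convolutional layers, and small additive noise stands in for augmentation. Without augmentation the features are identical on every epoch, so computing them once and caching them is safe; with augmentation every epoch produces different features, so the frozen layers have to be run on every batch.

```python
import numpy as np

rng = np.random.default_rng(0)
W_frozen = rng.normal(size=(8, 4))  # hypothetical frozen "conv" weights (never updated)


def frozen_features(x):
    """Stand-in for the frozen conv layers: a fixed linear map followed by ReLU."""
    return np.maximum(x @ W_frozen, 0.0)


def augment(x, rng):
    """Stand-in for random augmentation: small additive noise."""
    return x + rng.normal(scale=0.1, size=x.shape)


x = rng.normal(size=(16, 8))  # tiny hypothetical dataset: 16 inputs, 8 values each

# No augmentation: every epoch sees identical inputs, so cached features stay valid
cached = frozen_features(x)
same_without_aug = np.array_equal(cached, frozen_features(x))

# With augmentation: each epoch sees a different version of every input
epoch_1 = frozen_features(augment(x, rng))
epoch_2 = frozen_features(augment(x, rng))
same_with_aug = np.array_equal(epoch_1, epoch_2)

print(same_without_aug, same_with_aug)
```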

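# get_fc_model() and conv_model are defined outside this snippet
fc_model = get_fc_model()

# Freeze the convolutional layers so that only the new dense layers are trained
for layer in conv_model.layers:
    layer.trainable = False

# Look how easy it is to connect two models together!
conv_model.add(fc_model)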
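Setting `layer.trainable = False` tells the optimizer to leave those weights alone. In Keras, changes to `trainable` only take effect when the model is compiled afterwards, which is why the layers are frozen here before `compile()` is called. The NumPy sketch below shows the idea in isolation; the `Layer` class and `sgd_step` function are hypothetical illustrations, not Keras API.

```python
import numpy as np


class Layer:
    """Hypothetical minimal layer: a weight matrix plus a trainable flag."""

    def __init__(self, weights, trainable=True):
        self.weights = weights
        self.trainable = trainable


def sgd_step(layers, grads, lr=0.1):
    """One plain SGD update that skips layers marked trainable=False."""
    for layer, grad in zip(layers, grads):
        if layer.trainable:
            layer.weights -= lr * grad


conv_block = Layer(np.ones((3, 3)))  # stand-in for the pretrained conv layers
fc_head = Layer(np.ones((3, 2)))     # stand-in for the new dense layers
conv_block.trainable = False         # freeze the "conv" part, as in the loop above

sgd_step([conv_block, fc_head], [np.ones((3, 3)), np.ones((3, 2))])
print(conv_block.weights[0, 0], fc_head.weights[0, 0])  # 1.0 0.9
```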

Now we can compile, train, and save our model as usual. Note that we use fit_generator() since we want to pull random images from the directories on every batch.

conv_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])


conv_model.fit_generator(batches, samples_per_epoch=batches.nb_sample, nb_epoch=8, 
                        validation_data=val_batches, nb_val_samples=val_batches.nb_sample)
Epoch 1/8
23000/23000 [==============================] - 273s - loss: 0.3374 - acc: 0.9761 - val_loss: 0.2294 - val_acc: 0.9835
Epoch 2/8
23000/23000 [==============================] - 273s - loss: 0.2879 - acc: 0.9800 - val_loss: 0.2837 - val_acc: 0.9815
Epoch 3/8
23000/23000 [==============================] - 273s - loss: 0.2650 - acc: 0.9817 - val_loss: 0.2569 - val_acc: 0.9830
Epoch 4/8
23000/23000 [==============================] - 273s - loss: 0.2449 - acc: 0.9833 - val_loss: 0.3230 - val_acc: 0.9785
Epoch 5/8
23000/23000 [==============================] - 273s - loss: 0.2248 - acc: 0.9847 - val_loss: 0.2759 - val_acc: 0.9825
Epoch 6/8
23000/23000 [==============================] - 273s - loss: 0.2098 - acc: 0.9857 - val_loss: 0.2304 - val_acc: 0.9850
Epoch 7/8
23000/23000 [==============================] - 273s - loss: 0.2131 - acc: 0.9855 - val_loss: 0.2385 - val_acc: 0.9840
Epoch 8/8
23000/23000 [==============================] - 273s - loss: 0.2017 - acc: 0.9859 - val_loss: 0.2397 - val_acc: 0.9845
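These calls use the Keras 1 argument names. In Keras 2, `fit_generator` counts batches rather than samples: `samples_per_epoch`, `nb_epoch` and `nb_val_samples` became `steps_per_epoch`, `epochs` and `validation_steps`, and the directory iterator exposes its sample count as `samples` instead of `nb_sample`. The sketch below shows the conversion arithmetic; `steps_for` is a hypothetical helper, and the batch size of 64 is an assumption, since the value of `batch_size` used above isn't shown.

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to see every sample once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


batch_size = 64  # assumption: the batch size used in this gist isn't shown
train_steps = steps_for(23000, batch_size)  # 23000 training images found above
val_steps = steps_for(2000, batch_size)     # 2000 validation images found above
print(train_steps, val_steps)  # 360 32

# Keras 2 form of the call above (a comment only; it needs the model and iterators):
# conv_model.fit_generator(batches, steps_per_epoch=train_steps, epochs=8,
#                          validation_data=val_batches, validation_steps=val_steps)
```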


conv_model.fit_generator(batches, samples_per_epoch=batches.nb_sample, nb_epoch=3, 
                        validation_data=val_batches, nb_val_samples=val_batches.nb_sample)
Epoch 1/3
23000/23000 [==============================] - 273s - loss: 0.2023 - acc: 0.9859 - val_loss: 0.2563 - val_acc: 0.9840
Epoch 2/3
23000/23000 [==============================] - 273s - loss: 0.1851 - acc: 0.9870 - val_loss: 0.2777 - val_acc: 0.9820
Epoch 3/3
23000/23000 [==============================] - 273s - loss: 0.1737 - acc: 0.9878 - val_loss: 0.2252 - val_acc: 0.9845


conv_model.save_weights(model_path + 'aug1.h5')
conv_model.load_weights(model_path + 'aug1.h5')

Approaches to reducing overfitting

We do not necessarily need to rely on dropout or other regularization approaches to reduce overfitting. There are other techniques we should try first, since regularization, by definition, biases our model towards simplicity, which we only want to do if we know that's necessary. This is the order we recommend for reducing overfitting (more details about each in a moment):

  • Add more data
  • Use data augmentation
  • Use architectures that generalize well
  • Add regularization (e.g. dropout, l1/l2 regularization)
  • Reduce architecture complexity (e.g. use fewer filters, though this is hard to do).
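For reference, the two regularizers named in the list can be sketched in a few lines of NumPy. This is a simplified illustration rather than Keras's code: `dropout` uses the common "inverted" formulation (during training, zero each activation with probability `rate` and scale the survivors by `1 / (1 - rate)` so the expected activation is unchanged), and `l2_penalty` is the term `lam * sum(w**2)` that gets added to the loss. Both function names are hypothetical.

```python
import numpy as np


def dropout(x, rate, rng, training=True):
    """Inverted dropout: randomly zero activations in training, identity at inference."""
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate  # each unit survives with probability 1 - rate
    return x * keep / (1.0 - rate)


def l2_penalty(weight_arrays, lam):
    """L2 regularization term: lam times the sum of squared weights."""
    return lam * sum(float(np.sum(w ** 2)) for w in weight_arrays)


rng = np.random.default_rng(0)
activations = np.ones(10_000)
dropped = dropout(activations, rate=0.5, rng=rng)
print(round(float(dropped.mean()), 2))  # close to 1.0: rescaling preserves the expected value

weights = [np.array([1.0, 2.0]), np.array([[3.0]])]
print(l2_penalty(weights, lam=0.01))  # 0.01 * (1 + 4 + 9)
```

Both techniques push the model towards simpler solutions, which is why the list places them after the options that leave the model's capacity untouched.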

We'll assume that you've already collected as much data as you can, so the first step, adding more data, isn't relevant (this is true for most Kaggle competitions, for instance). The next step is therefore data augmentation. This refers to creating additional synthetic data, based on reasonable modifications of your input data. For images, this is likely to involve one or more of: flipping, rotation, zooming, cropping, panning, and minor color changes. Which types of augmentation are appropriate depends on your data. For regular photos, for instance, you'll want to use horizontal flipping but not vertical flipping, since an upside-down car is much less common than a car the right way up! We recommend always using at least some light data augmentation, unless you have so much data that your model will never see the same input twice.
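Cropping and minor color changes can be sketched the same way. In this hypothetical NumPy sketch, `random_crop` cuts a random window of a given size out of an image, and `random_channel_shift` adds one random offset per color channel, a simplified take on the `channel_shift_range=10.` setting used earlier. Both names are illustrative, and the image is random placeholder data. There is deliberately no vertical flip here, for the reason given above.

```python
import numpy as np


def random_crop(img, crop_h, crop_w, rng):
    """Cut a random crop_h x crop_w window out of a (height, width, channels) image."""
    h, w = img.shape[:2]
    top = int(rng.integers(0, h - crop_h + 1))
    left = int(rng.integers(0, w - crop_w + 1))
    return img[top:top + crop_h, left:left + crop_w]


def random_channel_shift(img, intensity, rng):
    """Add one random offset in [-intensity, intensity] to each color channel."""
    offsets = rng.uniform(-intensity, intensity, size=img.shape[-1])
    shifted = img.astype(np.float32) + offsets  # broadcasts over height and width
    return np.clip(shifted, 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
photo = rng.integers(0, 256, size=(240, 320, 3)).astype(np.uint8)  # placeholder RGB "photo"
aug = random_channel_shift(random_crop(photo, 224, 224, rng), intensity=10.0, rng=rng)
print(aug.shape)  # (224, 224, 3)
```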
