Created July 10, 2016 11:52
-
-
Save Piyush3dB/f59e64c1d1e16afc10340c16439daef1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from keras.models import Model | |
from keras.layers import Convolution2D, Activation, Input | |
#%matplotlib inline | |
#plt.rcParams['image.cmap'] = 'gray' | |
def gen_image(size=32, block=16):
    """Generate a square image containing one square block of Gaussian noise.

    The image is ``size x size`` and zero everywhere except a ``block x block``
    patch of standard-normal noise placed at a uniformly random position.
    With the defaults this is a 32x32 image with a 16x16 noisy block.

    Args:
        size: side length of the output image (default 32).
        block: side length of the noise block (default 16).

    Returns:
        (im, mask): two float arrays of shape (size, size). ``im`` holds the
        noise patch and zeros elsewhere; ``mask`` is 1.0 inside the patch and
        0.0 elsewhere.

    Raises:
        ValueError: if ``block`` is not in the range 1..size.
    """
    if not 0 < block <= size:
        raise ValueError("block must satisfy 0 < block <= size")
    im = np.zeros((size, size))
    mask = np.zeros((size, size))
    # Valid top-left offsets are 0..size-block inclusive. randint's upper
    # bound is exclusive, hence the +1; without it the block can never touch
    # the bottom row or the right-most column.
    y, x = np.random.randint(0, size - block + 1, 2)
    im[y:y + block, x:x + block] = np.random.randn(block, block)
    mask[y:y + block, x:x + block] = 1.
    return im, mask
def gen_dataset(n):
    """Generate a dataset of ``n`` (image, mask) examples.

    Each example is produced by ``gen_image()`` and stored as a single
    channel, channels-first.

    Args:
        n: number of examples to generate.

    Returns:
        (im, mask): float arrays of shape (n, 1, 32, 32) holding the noise
        images and their 0/1 masks respectively.
    """
    # Use np.empty rather than the low-level np.ndarray constructor. Every
    # slot is overwritten in the loop below, so zero-initialisation is
    # unnecessary.
    im = np.empty((n, 1, 32, 32))
    mask = np.empty((n, 1, 32, 32))
    for i in range(n):
        im[i, 0], mask[i, 0] = gen_image()
    return im, mask
# Build a 10,000-example training set and a 1,000-example validation set.
# Both are drawn from the same generator; the validation set is used below
# to monitor overfitting during training.
(im_train, mask_train), (im_val, mask_val) = gen_dataset(10000), gen_dataset(1000)
# Plot two training examples, each in its own figure: the noisy input
# image on the left and its target mask on the right.
for idx in (100, 101):
    plt.figure()
    for pos, arr in enumerate((im_train, mask_train), start=1):
        plt.subplot(1, 2, pos)
        plt.imshow(arr[idx, 0])
# plt.show()
# Build network: two 3x3 convolutions with 'same' padding, so the output
# keeps the 1x32x32 input size. A sigmoid gives a per-pixel value in (0, 1)
# to compare against the 0/1 mask.
# The data arrays are channels-first, shape (n, 1, 32, 32). Pin Theano ("th")
# dimension ordering explicitly. Otherwise Keras 1.x falls back to the
# image_dim_ordering in the user's keras.json, and under "tf" ordering the
# trailing 32-pixel axis would be treated as channels. The output shape
# would then no longer match the masks.
inp = Input((1, 32, 32))
out = Convolution2D(32, 3, 3, border_mode='same', dim_ordering='th')(inp)
out = Activation('relu')(out)
out = Convolution2D(1, 3, 3, border_mode='same', dim_ordering='th')(out)
out = Activation('sigmoid')(out)
model = Model(inp, out)
# Per-pixel binary cross-entropy against the 0/1 mask.
model.compile(loss='binary_crossentropy', optimizer='adam')
# Fit the training dataset. The validation loss is reported each epoch to
# make sure we don't overfit.
model.fit(im_train, mask_train, nb_epoch=10,
          validation_data=(im_val, mask_val))
# Run the trained model over the validation images, then plot the first
# three examples, one figure each: input image, ground-truth mask and
# predicted mask.
pred_val = model.predict(im_val)
for idx in range(3):
    plt.figure()
    panels = (im_val, mask_val, pred_val)
    for pos, arr in enumerate(panels, start=1):
        plt.subplot(1, 3, pos)
        plt.imshow(arr[idx, 0])
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.