Using Keras to tackle the Inria aerial image labeling dataset
#Using Keras to tackle the Inria aerial image labeling dataset
# https://project.inria.fr/aerialimagelabeling/

import os
#Work around for https://github.com/tensorflow/tensorflow/issues/24496
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# Work around for https://github.com/tensorflow/tensorflow/issues/33024
import tensorflow.compat as compat
compat.v1.disable_eager_execution()

import zipfile
import imageio
import io
from keras_segmentation.models.fcn import fcn_8
from keras_segmentation.models.unet import vgg_unet
import tensorflow.keras as keras
import numpy as np
from matplotlib import pyplot as plt
def AerialListFiles(pth):
    # Index the dataset zip without extracting it
    z = zipfile.ZipFile(pth)
    zinfo = z.infolist()
    zinfoDict = {}
    for zi in zinfo:
        zinfoDict[zi.filename] = zi

    testLocs = ["bellingham", "bloomington", "innsbruck", "sfo", "tyrol-e"]
    trainLocs = ["austin", "chicago", "kitsap", "tyrol-w", "vienna"]

    xData, xTest, yData, yTest = [], [], [], []
    for loc in trainLocs:
        for i in range(1, 37):
            ifile = "AerialImageDataset/train/images/{}{}.tif".format(loc, i)
            gtFile = "AerialImageDataset/train/gt/{}{}.tif".format(loc, i)
            xData.append(zinfoDict[ifile])
            yData.append(zinfoDict[gtFile])

    for loc in testLocs:
        for i in range(1, 37):
            ifile = "AerialImageDataset/test/images/{}{}.tif".format(loc, i)
            xTest.append(zinfoDict[ifile])
            yTest.append(None)

    return z, xData, xTest, yData, yTest
def SplitTrainAndValidation(xData, yData):
    # For each training location, tiles 1-5 go to validation, tiles 6-36 to training
    xTrain, xVal, yTrain, yVal = [], [], [], []
    n = 0
    for locNum in range(5):
        for i in range(1, 37):
            if i > 5:
                xTrain.append(xData[n])
                yTrain.append(yData[n])
            else:
                xVal.append(xData[n])
                yVal.append(yData[n])
            n += 1
    return xTrain, xVal, yTrain, yVal
class AerialDataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, z, dataX, dataY, batchesPerEpoch=50, filesInBatch=10, cropsInFile=10, cropSize=128, cropMargin=-32):
        'Initialization'
        self.z = z
        self.dataX = dataX
        self.dataY = dataY
        self.batchesPerEpoch = batchesPerEpoch
        self.filesInBatch = filesInBatch
        self.cropsInFile = cropsInFile
        self.cropSize = cropSize
        # A negative cropMargin shrinks the ground-truth crop relative to the image
        # crop, matching the reduced output resolution of the segmentation model
        # (64x64 labels for a 128x128 input here).
        self.cropMargin = cropMargin

    def __len__(self):
        'Denotes the number of batches per epoch'
        return self.batchesPerEpoch

    def __getitem__(self, index):
        'Generate one batch of data'
        dataX, dataY = [], []
        sizeWithMargin = self.cropSize + 2 * self.cropMargin
        posMarginOrZero = max(0, self.cropMargin)
        for j in range(self.filesInBatch):
            fileId = np.random.randint(len(self.dataX))
            imgData = io.BytesIO(self.z.open(self.dataX[fileId]).read())
            img = imageio.imread(imgData)
            del imgData
            gtData = io.BytesIO(self.z.open(self.dataY[fileId]).read())
            gt = imageio.imread(gtData)
            del gtData

            for i in range(self.cropsInFile):
                #Get a random crop
                r = np.random.randint(posMarginOrZero, img.shape[0]-self.cropSize-posMarginOrZero)
                c = np.random.randint(posMarginOrZero, img.shape[1]-self.cropSize-posMarginOrZero)
                imgc = img[r:r+self.cropSize,:,:]
                imgc = imgc[:,c:c+self.cropSize,:]
                gtc = gt[r-self.cropMargin:r+self.cropSize+self.cropMargin,:]
                gtc = gtc[:,c-self.cropMargin:c+self.cropSize+self.cropMargin]

                #Rescale image and one-hot encode the flattened ground truth
                imgc = np.array(imgc, dtype=np.float32) / 255.0
                gtc = gtc.reshape((sizeWithMargin*sizeWithMargin,)) > 128
                gtc = keras.utils.to_categorical(gtc, num_classes=2)

                dataX.append(imgc)
                dataY.append(gtc)

        dataX = np.array(dataX)
        dataY = np.array(dataY)
        return dataX, dataY

    def on_epoch_end(self):
        pass
# plot diagnostic learning curves
def summarize_diagnostics(histories):
    for i in range(len(histories)):
        # plot loss
        plt.subplot(2, 1, 1)
        plt.title('Cross Entropy Loss')
        plt.plot(histories[i].history['loss'], color='blue', label='train')
        plt.plot(histories[i].history['val_loss'], color='orange', label='test')
        # plot accuracy
        plt.subplot(2, 1, 2)
        plt.title('Classification Accuracy')
        plt.plot(histories[i].history['accuracy'], color='blue', label='train')
        plt.plot(histories[i].history['val_accuracy'], color='orange', label='test')
    plt.show()
def PredictOnImages(z, xVal, model):
    # Do prediction on specified images with a sliding window of 64x64 output tiles
    for imgInfo in xVal:
        print (imgInfo)
        imgData = io.BytesIO(z.open(imgInfo).read())
        img = imageio.imread(imgData)
        del imgData

        predImg = np.zeros((img.shape[0], img.shape[1]), dtype=np.int8)
        margin = 32
        for r in range(margin, img.shape[0]-64-margin, 64):
            patches = []
            patchPosLi = []
            for c in range(margin, img.shape[1]-64-margin, 64):
                #Take a 128x128 input patch centred on the 64x64 output tile
                imgc = img[r-margin:r+64+margin,:,:]
                imgc = imgc[:,c-margin:c+64+margin,:]
                patches.append(imgc)
                patchPosLi.append((r, c))

            #Scale to match the training preprocessing
            patches = np.array(patches, dtype=np.float32) / 255.0
            result = model.predict(patches)

            for (pr, pc), pred in zip(patchPosLi, result):
                pred = pred.reshape((64, 64, 2))
                pred = pred[:,:,1]
                predImg[pr:pr+64,pc:pc+64] = (pred > 0.5)

        plt.imshow(predImg)
        plt.show()
if __name__=="__main__":

    z, xData, xTest, yData, yTest = AerialListFiles("/home/tim/Downloads/aerialimagelabeling/NEW2-AerialImageDataset.zip")
    xTrain, xVal, yTrain, yVal = SplitTrainAndValidation(xData, yData)

    if True:
        trainGen = AerialDataGenerator(z, xTrain, yTrain)
        valGen = AerialDataGenerator(z, xVal, yVal, filesInBatch=1, cropsInFile=100)

        model = vgg_unet(2, input_height=128, input_width=128)
        print (type(model))

        print ("Compiling model")
        model.compile(optimizer="adadelta", loss='categorical_crossentropy', metrics=['accuracy'])

        print ("Fitting model")
        history = model.fit_generator(
            trainGen,
            validation_data=valGen,
            validation_steps=50,
            epochs=20,
            workers=0
        )

        #keras.models.save_model(model, 'aerial.h5')
        summarize_diagnostics([history])
    else:
        model = keras.models.load_model('aerial.h5')
        PredictOnImages(z, xVal, model)
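To reuse a trained network later, the keras.models.save_model call in the training branch needs to be uncommented; the sliding-window predictor can then be run over the official test tiles as well as the validation split. A minimal sketch, assuming the model was saved as 'aerial.h5' and the zip lives at the same path used above:

#Sketch only: reload a previously trained model and predict on the held-out test tiles
#Assumes aerial.h5 was written by the (commented-out) save_model call in the script
z, xData, xTest, yData, yTest = AerialListFiles("/home/tim/Downloads/aerialimagelabeling/NEW2-AerialImageDataset.zip")
model = keras.models.load_model('aerial.h5')
PredictOnImages(z, xTest, model)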