Created
October 19, 2018 04:11
-
-
Save prl900/a49cffd5c3178ca31b3083eff7a12d18 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import load_model, Sequential | |
from keras import layers | |
from keras.layers import Layer | |
from keras import models | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import sys | |
# Script: dump a few input/target sample images and, optionally, compare the
# per-sample mean absolute error of two saved Keras models against the targets.
#
# The print calls below use a single pre-formatted argument so the output is
# the same under Python 2 and Python 3.


def mean_absolute_error(prediction, target):
    """Return the mean of the element-wise absolute difference."""
    return np.mean(np.absolute(prediction - target))


def report_prediction(label, model, model_input, target):
    """Run ``model`` on one sample, print its statistics and return its MAE.

    label       -- prefix used in the printed lines (e.g. "vgg16").
    model       -- object exposing ``predict``; called with ``model_input``.
    model_input -- batch of size one, as built in the evaluation loop below.
    target      -- 2-D slice of ``y`` the prediction is compared against.

    Only ``prediction[0, :, :, 0]`` enters the MAE; max and sum are taken
    over the whole prediction array.
    """
    prediction = model.predict(model_input)
    print("{} max {}".format(label, prediction.max()))
    print("{} sum {}".format(label, prediction.sum()))
    mae = mean_absolute_error(prediction[0, :, :, 0], target)
    print("{} mae {}".format(label, mae))
    return mae


# --- Data ------------------------------------------------------------------
x = np.load("/datasets/10zlevels.npy")
# The targets get a new axis at position 3 and are multiplied by 1000.
# NOTE(review): the reason for the factor 1000 (unit conversion?) is not
# visible in this file -- confirm against how the models were trained.
y = 1000 * np.expand_dims(
    np.load("/datasets/1980-2016/full_tp_1980_2016.npy"), axis=3)
print("data loaded")
print(x.shape)

# Indices along the last axis of ``x`` that are fed to the models.
levels = [0, 2, 6]

# --- Preview dump ----------------------------------------------------------
# Save channel 0 of the input and of the target for a handful of samples.
for i in range(5000, 5050):
    plt.imsave('in_{:04d}.png'.format(i), x[i, :, :, 0], cmap='jet')
    plt.imsave('out_{:04d}.png'.format(i), y[i, :, :, 0], cmap='Blues')

# The original script called sys.exit(0) unconditionally at this point, which
# made the model comparison below unreachable. The default keeps that
# behaviour; set RUN_EVALUATION to True to run the comparison.
# NOTE(review): the original indentation was lost; this assumes the exit sat
# after the preview loop rather than inside it.
RUN_EVALUATION = False
if not RUN_EVALUATION:
    sys.exit(0)

# --- Model comparison ------------------------------------------------------
vgg16 = load_model('/datasets/vgg16.h5')
unet = load_model('/datasets/unet.h5')

eval_indices = range(40000, 40100)
t_mae_unet = 0
t_mae_vgg16 = 0
for i in eval_indices:
    print(i)

    # Select the configured levels and prepend a batch axis of size one.
    geop = x[i][:, :, levels]
    geop = geop[np.newaxis, ...]
    print(geop.shape)

    target = y[i, :, :, 0]
    print("rain max {}".format(target.max()))
    print("rain sum {}".format(target.sum()))

    plt.imsave('in_{}.png'.format(i), x[i, :, :, 0], cmap='jet')
    plt.imsave('out_{}.png'.format(i), target, cmap='Blues')

    # Predictions are compared to ``y`` as-is, i.e. to the targets already
    # multiplied by 1000 above; no rescaling of the model output is applied.
    t_mae_vgg16 += report_prediction("vgg16", vgg16, geop, target)
    t_mae_unet += report_prediction("unet", unet, geop, target)

# Average over the number of evaluated samples (derived from the range so the
# divisor cannot drift out of sync with the loop bounds).
sample_count = float(len(eval_indices))
print(t_mae_vgg16 / sample_count)
print(t_mae_unet / sample_count)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment