Created July 11, 2019 12:41
-
-
Save prl900/42c3a6ccfedfddec99cf3e07fef26c00 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function, division

import pickle

import numpy as np
import xarray as xr
from keras.layers import concatenate, Input, Concatenate, BatchNormalization
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam
def get_UNet():
    """Build the U-Net mapping a (240, 360, 3) input to a (240, 360, 1) output.

    Encoder: four stages of BatchNormalization -> strided 5x5 ReLU Conv2D
    with strides 2, 3, 2, 2 (spatial size 240x360 -> 10x15) and 32, 64,
    128, 256 filters.

    Decoder: three stages of BatchNormalization -> UpSampling2D -> 5x5
    ReLU Conv2D, each concatenated with the encoder output of matching
    resolution. A final 2x UpSampling2D and a 1x1 ReLU Conv2D then yield
    one output channel at full resolution.

    Returns:
        An uncompiled keras ``Model``.
    """
    def encode(tensor, filters, stride, kernel=5):
        """Normalize, then downsample with a strided ReLU convolution."""
        normed = BatchNormalization()(tensor)
        return Conv2D(filters, kernel_size=kernel, strides=stride,
                      padding='same', activation='relu')(normed)

    def decode(tensor, skip, filters, stride, kernel=5):
        """Normalize, upsample, convolve, then concatenate with ``skip``."""
        normed = BatchNormalization()(tensor)
        upsampled = UpSampling2D(size=stride)(normed)
        conv = Conv2D(filters, kernel_size=kernel, strides=1,
                      padding='same', activation='relu')(upsampled)
        return Concatenate()([conv, skip])

    base_filters = 32
    inputs = Input(shape=(240, 360, 3))

    # Encoder stages as (filter multiplier, stride); layers are created in
    # the same order as the stage list.
    levels = [inputs]
    for multiplier, stride in ((1, 2), (2, 3), (4, 2), (8, 2)):
        levels.append(encode(levels[-1], base_filters * multiplier, stride))
    # levels == [inputs, d1, d2, d3, d4]

    # Decoder stages as (skip level index, filter multiplier, stride); each
    # stride mirrors the encoder stride that produced that resolution.
    features = levels[4]
    for skip_idx, multiplier, stride in ((3, 4, 2), (2, 2, 2), (1, 1, 3)):
        features = decode(features, levels[skip_idx],
                          base_filters * multiplier, stride)

    # Undo the first encoder stride (2) and project to a single channel.
    features = UpSampling2D(size=2)(features)
    outputs = Conv2D(1, kernel_size=1, strides=1, padding='same',
                     activation='relu')(features)
    return Model(inputs=inputs, outputs=outputs)
# Training script: ERA5 'z' field (input) -> ERA5 'tp' field (target),
# both cropped to the 240x360 grid that get_UNet() expects.

# The first N_TRAIN shuffled samples are used for training; the rest for
# validation.
N_TRAIN = 5000

# Input field. Read with xarray (imported as xr); `.values` loads the crop
# into memory before the `with` block closes the file.
with xr.open_dataset("ERA5_AU_Z.nc") as ds:
    x = ds['z'][:, :240, :360, :].values
print(x.shape)

# Target field: total precipitation scaled by 1000 (presumably m -> mm; confirm
# against the file's units) with a trailing channel axis added.
with xr.open_dataset("ERA5_AU_TP.nc") as ds:
    y = 1000 * ds['tp'][:, :240, :360].values[:, :, :, None]
print(y.shape)
print(y.min(), y.max(), y.mean())
# NOTE(review): xarray decodes fill values to NaN by default; this assumes
# the cropped fields have no missing values -- confirm for these files.
y = np.clip(y, 0, 50)

if x.shape[0] != y.shape[0]:
    raise ValueError(
        "input and target sample counts differ: %d vs %d" % (x.shape[0], y.shape[0]))

# Shuffle samples with a fixed seed so the train/test split is reproducible.
idxs = np.arange(x.shape[0])
np.random.seed(0)
np.random.shuffle(idxs)
print(idxs)
x = x[idxs, :, :, :]
y = y[idxs, :, :, :]

x_train = x[:N_TRAIN, :, :, :]
y_train = y[:N_TRAIN, :, :, :]
print(x_train.shape, y_train.shape)
x_test = x[N_TRAIN:, :, :, :]
y_test = y[N_TRAIN:, :, :, :]
print(x_test.shape, y_test.shape)

model = get_UNet()
# summary() prints the table itself and returns None, so it is not wrapped
# in print().
model.summary()
model.compile(loss='mse', optimizer=Adam(0.0001))
history = model.fit(x=x_train, y=y_train, epochs=100, batch_size=10,
                    validation_data=(x_test, y_test))

# Persist the per-epoch training history and the trained model.
with open('unet2_era5.pkl', 'wb') as f:
    pickle.dump(history.history, f)
model.save('unet2_era5.h5')
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.