Skip to content

Instantly share code, notes, and snippets.

@prl900
Created July 11, 2019 12:41
Show Gist options
  • Save prl900/42c3a6ccfedfddec99cf3e07fef26c00 to your computer and use it in GitHub Desktop.
from __future__ import print_function, division
from keras.layers import concatenate, Input, Concatenate, BatchNormalization
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam
import xarray as xr
import numpy as np
import pickle
def get_UNet():
    """Build a small U-Net that maps a (240, 360, 3) input to a single-channel
    (240, 360, 1) output.

    Encoder: four strided convolutions (strides 2, 3, 2, 2).
    Decoder: three upsample+conv stages with skip connections back to the
    encoder, followed by a final 2x upsample and a 1x1 convolution.
    """

    def down_block(inputs, n_out, strides, kernel=5):
        """One downsampling stage: batch-norm then a strided convolution."""
        normed = BatchNormalization()(inputs)
        return Conv2D(n_out, kernel_size=kernel, strides=strides,
                      padding='same', activation='relu')(normed)

    def up_block(inputs, skip, n_out, strides, kernel=5):
        """One upsampling stage: batch-norm, upsample, convolve, then
        concatenate with the matching encoder feature map."""
        x = BatchNormalization()(inputs)
        x = UpSampling2D(size=strides)(x)
        x = Conv2D(n_out, kernel_size=kernel, strides=1,
                   padding='same', activation='relu')(x)
        return Concatenate()([x, skip])

    base_filters = 32

    # Image input
    d0 = Input(shape=(240, 360, 3))

    # Encoder (downsampling path)
    d1 = down_block(d0, base_filters, strides=2)
    d2 = down_block(d1, base_filters * 2, strides=3)
    d3 = down_block(d2, base_filters * 4, strides=2)
    d4 = down_block(d3, base_filters * 8, strides=2)

    # Decoder (upsampling path with skip connections)
    u2 = up_block(d4, d3, base_filters * 4, strides=2)
    u3 = up_block(u2, d2, base_filters * 2, strides=2)
    u4 = up_block(u3, d1, base_filters, strides=3)
    u5 = UpSampling2D(size=2)(u4)

    # 1x1 convolution collapses the channels to the single output map.
    output_img = Conv2D(1, kernel_size=1, strides=1,
                        padding='same', activation='relu')(u5)

    return Model(inputs=d0, outputs=output_img)
# ---------------------------------------------------------------------------
# Load ERA5 geopotential (predictor) and total precipitation (target),
# shuffle, split into train/test, and fit the U-Net.
#
# FIX: the original called netCDF4's `Dataset(...)` without importing it
# (NameError); the file already imports xarray as `xr`, so the datasets are
# opened with xr.open_dataset instead.
# ---------------------------------------------------------------------------
ds = xr.open_dataset("ERA5_AU_Z.nc")
# Crop to 240x360 so the U-Net stride factors (2*3*2*2 = 24) divide evenly.
x = ds["z"].values[:, :240, :360, :]
print(x.shape)

ds = xr.open_dataset("ERA5_AU_TP.nc")
# tp is in metres; scale to mm and append a trailing channel axis so the
# target matches the network's (H, W, 1) output shape.
y = 1000 * ds["tp"].values[:, :240, :360][:, :, :, None]
print(y.shape)
print(y.min(), y.max(), y.mean())

# Clip extreme precipitation values to stabilise the MSE loss.
y = np.clip(y, 0, 50)

# Deterministic shuffle before the train/test split (seeded for
# reproducibility).
idxs = np.arange(x.shape[0])
np.random.seed(0)
np.random.shuffle(idxs)
print(idxs)
x = x[idxs, :, :, :]
# FIX: original line was a SyntaxError (`y[idxs,,:,::]`).
y = y[idxs, :, :, :]

# First 5000 samples train, the remainder test.
x_train = x[:5000, :, :, :]
y_train = y[:5000, :, :, :]
print(x_train.shape, y_train.shape)
x_test = x[5000:, :, :, :]
y_test = y[5000:, :, :, :]
print(x_test.shape, y_test.shape)

model = get_UNet()
print(model.summary())
model.compile(loss='mse', optimizer=Adam(0.0001))
history = model.fit(x=x_train, y=y_train, epochs=100, batch_size=10,
                    validation_data=(x_test, y_test))

# Persist the training history and the trained weights.
with open('unet2_era5.pkl', 'wb') as f:
    pickle.dump(history.history, f)
model.save('unet2_era5.h5')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment