Created July 12, 2018 05:35
Save yohm/52abf95c077dc39d8a4cc7b8e2f85b52 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.datasets import boston_housing
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam

# Load the Boston housing regression data. test_split=0.0 reserves no
# fraction of the data as a test set, so every sample lands in
# x_train / y_train and x_test / y_test come back empty.
(x_train, y_train), (x_test, y_test) = boston_housing.load_data(
    test_split=0.0,
)
def scale_input(data, *, return_scaler=False):
    """Standardize every feature column of *data* to zero mean and unit variance.

    A fresh scikit-learn ``StandardScaler`` is fit on *data* and then applied
    to it.

    Args:
        data: 2-D array-like accepted by ``StandardScaler`` (rows are samples,
            columns are features).
        return_scaler: keyword-only. When True, also return the fitted scaler
            so that other data (e.g. a held-out test split) can be transformed
            with the *same* statistics via ``scaler.transform(...)`` instead of
            being re-fit on its own mean/std. Defaults to False, which keeps
            the single-value return.

    Returns:
        The standardized data, or ``(scaled_data, scaler)`` when
        *return_scaler* is True.
    """
    # scikit-learn is imported only when this function actually runs.
    from sklearn.preprocessing import StandardScaler

    scaler = StandardScaler()
    scaled = scaler.fit_transform(data)
    if return_scaler:
        return scaled, scaler
    return scaled
# Replace the raw training features with their standardized values.
x_train = scale_input(data=x_train)
def baseline_model():
    """Build and compile the baseline regressor.

    Layout: 13 input features -> Dense(13, ReLU) -> Dense(1, no activation),
    with 'normal'-initialized kernels. Compiled with mean-squared-error loss
    and the 'adam' optimizer.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    hidden = Dense(13, input_dim=13, kernel_initializer='normal', activation='relu')
    output = Dense(1, kernel_initializer='normal')
    model = Sequential([hidden, output])
    model.compile(optimizer='adam', loss='mean_squared_error')
    return model
def larger_model(learning_rate=0.001):
    """Build and compile a deeper regressor with an explicitly configured Adam.

    Layout: 13 input features -> Dense(13, ReLU) -> Dense(6, ReLU) ->
    Dense(1, no activation), all with 'normal'-initialized kernels. Compiled
    with mean-squared-error loss.

    Args:
        learning_rate: Adam step size. Defaults to 0.001, the value this
            function used before the parameter existed.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    model = Sequential()
    model.add(Dense(13, input_dim=13, kernel_initializer='normal', activation='relu'))
    model.add(Dense(6, kernel_initializer='normal', activation='relu'))
    model.add(Dense(1, kernel_initializer='normal'))
    # The learning rate is passed positionally because its keyword differs
    # between Keras releases (`lr` in early 2.x, `learning_rate` afterwards,
    # and current releases reject `lr`). `epsilon=None` and `decay=0.0` are
    # dropped: they only restated the old defaults, while newer optimizers
    # refuse `decay` and do not accept None for `epsilon`.
    opt = Adam(learning_rate, beta_1=0.9, beta_2=0.999, amsgrad=False)
    model.compile(loss='mean_squared_error', optimizer=opt)
    return model
def wider_model():
    """Build and compile the wide single-hidden-layer regressor.

    Layout: 13 input features -> Dense(20, ReLU) -> Dense(1, no activation),
    with 'normal'-initialized kernels. Compiled with mean-squared-error loss
    and the 'adam' optimizer.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    net = Sequential([
        Dense(20, input_dim=13, kernel_initializer='normal', activation='relu'),
        Dense(1, kernel_initializer='normal'),
    ])
    net.compile(optimizer='adam', loss='mean_squared_error')
    return net
# Train the baseline architecture; substitute larger_model() or
# wider_model() here to compare the alternatives. validation_split=0.2 makes
# fit() hold out part of x_train and record 'val_loss' in the history.
m = baseline_model()
h = m.fit(
    x_train,
    y_train,
    batch_size=20,
    epochs=300,
    verbose=1,
    validation_split=0.2,
)
def plot_history(h, last=100):
    """Plot training and validation loss over the final epochs of a run.

    Args:
        h: mapping holding per-epoch sequences under ``'loss'`` and
            ``'val_loss'`` -- e.g. the ``history`` attribute of the object
            returned by Keras ``Model.fit`` when validation data is used.
        last: number of trailing epochs to draw. Defaults to 100, the window
            this function always used. Runs shorter than *last* are drawn in
            full; ``last <= 0`` draws empty curves.

    Raises:
        KeyError: if *h* has no ``'loss'`` or ``'val_loss'`` entry.
    """
    import matplotlib.pyplot as plt

    for key in ('loss', 'val_loss'):
        values = list(h[key])
        # Start of the plotted window. Computed explicitly instead of slicing
        # with [-last:], which would return the whole run when last == 0.
        start = max(len(values) - last, 0)
        # Plot against real 1-based epoch numbers; the sliced values alone
        # would sit at x = 0..N-1 whatever epochs they came from, which
        # contradicts the 'epoch' axis label.
        plt.plot(range(start + 1, len(values) + 1), values[start:])
    plt.title('model loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['loss', 'val_loss'], loc='best')
    plt.show()
# Test-set evaluation is left disabled. To re-enable it:
#     print(m.evaluate(x_test, y_test, verbose=1))
# NOTE(review): x_test is presumably empty (load_data was called with
# test_split=0.0), and it is never standardized the way x_train is --
# confirm both before turning this on.
plot_history(h.history)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment