Last active: April 24, 2017 16:36
Save ethen8181/5f13940baabd2a975a72a27c7a8ce12c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from hw1_nnet import NeuralNet
from keras.datasets.mnist import load_data

# Learning-rate sweep: train one NeuralNet per candidate learning rate on
# MNIST, holding every other hyperparameter in ``nn_params`` fixed.

(X_train, y_train), (X_test, y_test) = load_data()

# Flatten each image into a single feature row and scale intensities by 1/255
# (presumably raw 0-255 pixel values mapped to [0, 1] -- confirm dtype).
# NOTE(review): X_test / y_test are prepared here but not used by the sweep.
X_train = X_train.reshape((X_train.shape[0], -1)) / 255.0
X_test = X_test.reshape((X_test.shape[0], -1)) / 255.0

# Value passed to NeuralNet as ``filename``; one distinct value per learning
# rate. Presumably used by hw1_nnet to name saved output -- confirm there.
# NOTE(review): if it becomes a real file path, ':' is not a legal filename
# character on Windows.
FILENAME_TEMPLATE = 'learning_rate: {}'

# Shared hyperparameters for every run. ``learning_rate`` and ``filename`` are
# only starting values: each sweep iteration overrides both via set_params.
nn_params = {
    'reg': 0.01,
    'seed': 1234,
    'n_iters': 50,
    'hidden_dims': [256, 128],
    'learning_rate': 0.01,
    'activation': 'relu',
    'filename': FILENAME_TEMPLATE.format(0.01),
}

learning_rate_options = [0.01, 0.05, 0.1]
for learning_rate in learning_rate_options:
    param = {
        'learning_rate': learning_rate,
        'filename': FILENAME_TEMPLATE.format(learning_rate),
    }
    # Construct a new network for every learning rate instead of reusing one
    # instance across runs.
    nn = NeuralNet(**nn_params)
    nn.set_params(**param)
    nn.fit(X_train, y_train)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.