Last active: March 18, 2021 12:34
-
-
Save oscar-defelice/86d5442f41614fd20af5dd8251bca076 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Experiment configuration, built with keyword-argument ``dict`` calls.
# Keys, values and insertion order match a plain dict literal.
config = dict(
    # ``data`` is defined elsewhere in the file (not visible here).
    data=data,
    # NOTE(review): presumably the fraction of ``data`` held out for testing
    # by ``import_data``; confirm against that function, since the name
    # "train_test_ratio" could also be read as a train:test ratio.
    train_test_ratio=0.2,
    model_config=dict(
        model_name='house_pricing_model',
        # Units per layer, in order.
        layers=dict(first_layer=12, second_layer=5, output_layer=1),
        # One entry per layer above. None presumably means "no activation"
        # (linear output), which suits a regression target; TODO confirm
        # in the model-building code.
        activations=['relu', 'relu', None],
        loss_function='mse',
        optimiser='adam',
        metrics=['mae'],
    ),
    # Read by ``train`` as ``config['training_config'][...]``.
    training_config=dict(batch_size=10, epochs=8),
)
def train(model, config):
    """
    Fit a compiled model on the data split produced by ``import_data``.

    The batch size and epoch count come from ``config['training_config']``.
    ``import_data(config)`` supplies the train/test split. The test portion
    is passed to ``fit`` as ``validation_data``.

    Args:
        model: a compiled model exposing a ``fit`` method.
        config: configuration dict. ``config['training_config']`` must
            contain ``'batch_size'`` and ``'epochs'``.

    Returns:
        Whatever ``model.fit`` returns. The original docstring calls this
        ``model.history`` (presumably a Keras ``History`` object; confirm).
    """
    training = config['training_config']
    # Read the hyper-parameters before loading data, so a missing key fails
    # fast without doing any data loading.
    batch_size = training['batch_size']
    epochs = training['epochs']

    x_train, x_test, y_train, y_test = import_data(config)

    return model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(x_test, y_test),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.