@AnderRasoVazquez
Created November 28, 2018 20:12
[K-fold cross validation with Keras] #python #keras #machine_learning
"""Obtained from https://github.com/keras-team/keras/issues/1711#issuecomment-185801662"""
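from sklearn.model_selection import StratifiedKFold


def load_data():
    """Load your data here; return (data, labels, header_info), with data and labels as NumPy arrays."""
    raise NotImplementedError


def create_model():
    """Create your model here; return a fresh, untrained model."""
    raise NotImplementedError


def train_and_evaluate_model(model, train_data, train_labels, test_data, test_labels):
    """Fit the model on the training fold and evaluate it on the test fold."""
    # model.fit...
    # fit and evaluate here.
    raise NotImplementedError


if __name__ == "__main__":
    n_folds = 10
    data, labels, header_info = load_data()
    # StratifiedKFold keeps the class proportions of `labels` roughly equal in every fold.
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True)
    for i, (train, test) in enumerate(skf.split(data, labels)):
        print("Running Fold", i + 1, "/", n_folds)
        model = None  # Clearing the NN.
        model = create_model()
        train_and_evaluate_model(model, data[train], labels[train], data[test], labels[test])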
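
To make the role of the stratified split more concrete, here is a minimal standard-library sketch of what "stratified" means: every fold receives roughly the same share of each class. The function name `stratified_folds`, the `seed` parameter, and the toy labels are assumptions made for illustration. This is not scikit-learn's actual algorithm, only a round-robin approximation of the idea.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_folds, seed=0):
    """Yield (train_indices, test_indices) pairs with per-class balance.

    Illustrative only: this is not how scikit-learn implements it.
    """
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    # Deal each class's shuffled indices round-robin across the folds.
    folds = [[] for _ in range(n_folds)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        for position, index in enumerate(indices):
            folds[position % n_folds].append(index)

    # Each fold serves once as the test set; the rest form the training set.
    for k in range(n_folds):
        test = sorted(folds[k])
        train = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        yield train, test


if __name__ == "__main__":
    # Hypothetical toy labels: six samples of class 0 and three of class 1.
    labels = [0, 0, 0, 0, 0, 0, 1, 1, 1]
    splits = stratified_folds(labels, n_folds=3)
    for fold_number, (train, test) in enumerate(splits, start=1):
        test_labels = sorted(labels[i] for i in test)
        print("Fold", fold_number, "test labels:", test_labels)
```

With these toy labels each test fold contains two samples of class 0 and one of class 1, matching the 2:1 ratio of the full set. In the script above, `StratifiedKFold` plays this role and should be preferred over a hand-rolled split.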