Skip to content

Instantly share code, notes, and snippets.

@chupvl
Created November 2, 2021 16:30
Show Gist options
  • Select an option

  • Save chupvl/b44df587aae51cee6c02ca796a28d63c to your computer and use it in GitHub Desktop.

Select an option

Save chupvl/b44df587aae51cee6c02ca796a28d63c to your computer and use it in GitHub Desktop.
def ncv_runs_table(number_of_folds):
    """Generate a table of nested-cross-validation runs (mostly for explanation).

    Folds are labelled ``0 .. number_of_folds - 1``. Every pair of distinct
    folds ``(a, b)`` with ``a < b``, taken in ``itertools.combinations``
    order, produces two consecutive runs:

    * ``a`` as validation and ``b`` as test, then
    * ``b`` as validation and ``a`` as test.

    In both runs every remaining fold goes into the training set.

    Parameters
    ----------
    number_of_folds : int
        Number of folds to split the data into.

    Returns
    -------
    pandas.DataFrame
        One row per run with these columns:

        * ``validation`` and ``test``: single-element lists of fold indices.
        * ``train``: list of the remaining fold indices, in ascending order.
        * ``run``: the labels ``'run1'``, ``'run2'``, ...

        With exactly two folds the ``train`` lists are empty. With fewer than
        two folds the frame has no rows but still has these columns.
    """
    from itertools import combinations

    import pandas as pd

    folds = range(number_of_folds)
    rows = []
    for pair in combinations(folds, 2):
        # Everything not held out trains. Filtering ``folds`` directly keeps
        # the result in ascending order rather than relying on set ordering.
        train = [fold for fold in folds if fold not in pair]
        first, second = pair
        # Emit both orientations of the pair back to back so run numbering
        # matches "(a val, b test), (b val, a test)" for each pair.
        for validation, test in ((first, second), (second, first)):
            rows.append({
                'validation': [validation],
                'test': [test],
                'train': list(train),  # fresh copy: rows must not share a list
                'run': f'run{len(rows) + 1}',
            })
    return pd.DataFrame(rows, columns=['validation', 'test', 'train', 'run'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment