
@durka
Last active May 17, 2018 05:18
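import numpy as np
from sklearn.model_selection import StratifiedKFold

# `cats` (not defined in this snippet) maps each category to a list of (row, props, rates) surface records
N = sum(map(len, cats.values()))
# one class label per surface, in the same order the loop below visits them
labels = [c for c, s in cats.items() for _ in s]
# keep only the test indices of each of the 5 stratified folds (a list, so it can be indexed)
splits = [te for _tr, te in StratifiedKFold(n_splits=5, shuffle=True).split(np.zeros(N), labels)]
stats = {'train': dict(), 'test': dict()}
idx = 0  # global position of the current surface in `labels`; the fold indices refer to this, not the per-category index
for cat, surfs in cats.items():
    print(cat)
    for row, props, rates in surfs:
        # folds 0-3 are training folds; remember which one the surface fell into
        for s in range(4):
            if idx in splits[s]:
                mode = 0  # train
                fold = s
                break
        else:
            # no break: the surface must be in fold 4, the held-out test fold
            if idx in splits[4]:
                mode = 1  # test
                fold = None
            else:
                raise ValueError('index not in any split: %d' % idx)
        idx += 1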
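The loop above tests fold membership with `in` against each fold's index array for every surface, which rescans the arrays each time. A minimal sketch of an alternative, assuming `splits` is a list of K test-index sequences that together partition `range(n)` (as KFold-style splitters produce): build an index-to-fold lookup once, then derive `(mode, fold)` from it. `fold_lookup` and `assign` are hypothetical helper names, not part of the gist or of scikit-learn, and the toy `splits` below are hand-written rather than real `StratifiedKFold` output.

```python
def fold_lookup(splits, n):
    """Map each global index in range(n) to the fold whose test set contains it."""
    fold_of = [None] * n
    for k, test_idx in enumerate(splits):
        for j in test_idx:
            if fold_of[j] is not None:
                raise ValueError('index %d appears in folds %d and %d' % (j, fold_of[j], k))
            fold_of[j] = k
    missing = [j for j, k in enumerate(fold_of) if k is None]
    if missing:
        raise ValueError('indices not in any split: %r' % missing)
    return fold_of


def assign(fold, test_fold=4):
    """Return (mode, fold) with the gist's convention: mode 0 = train, 1 = test."""
    if fold == test_fold:
        return 1, None
    return 0, fold


# toy example: 10 items spread over 5 folds
splits = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
fold_of = fold_lookup(splits, 10)
print(fold_of)             # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
print(assign(fold_of[4]))  # (1, None)
print(assign(fold_of[7]))  # (0, 2)
```

Each lookup is then a constant-time list access, and the partition check (no duplicates, no gaps) happens once up front instead of surfacing as an error inside the main loop.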