Last active
January 25, 2021 15:33
drscotthawley/fb69c3dac447c9c31db79a6bc8d7fd22
Swaps the validation set with a section of the training set, given a fold index k
# In case you didn't think to add k-fold cross-validation until late in your
# ML project...
# This is built for a situation where datasets are arrays of, say, images.
def kfold_swap(train_X, train_Y, val_X, val_Y, k):
    """
    Swaps val with a section of train, given a value for k.

    "Duct tape" approach used to "retro-fit" k-fold cross-validation while minimally
    disturbing the rest of the code, avoiding reloading data from disk, and
    keeping RAM use manageable (e.g. np.append() is bad b/c it would copy all of train).
    "Not-quite in-place" swapping means only a val-sized section of train gets
    duplicated in storage. Note that train_X and train_Y are modified in place.

    For 80-20 train/val split, k can run from 0 to 4 (= 5-fold cross-val)
    For 80-10-10 train/val/test split, k can run from 0 to 8 (= 9-fold cross-val)
    For 70-15-15 train/val/test split, exceeding k=4 will give you a failed assertion
    """
    if k > 0:  # k=0 means do nothing
        vl = val_X.shape[0]  # number of samples in the validation set
        # sanity checks: make sure sizes are ok
        assert train_X.shape[0] == train_Y.shape[0]
        assert val_X.shape[0] == val_Y.shape[0]
        assert k*vl <= train_X.shape[0]
        bgn, end = (k-1)*vl, k*vl  # minus sign is from choice that k=0 is no-op
        # Slice only the first axis so that arrays of any rank (images, 1-D labels) work.
        val_X, train_X[bgn:end] = train_X[bgn:end].copy(), val_X
        val_Y, train_Y[bgn:end] = train_Y[bgn:end].copy(), val_Y
    return train_X, train_Y, val_X, val_Y
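
One possible way to drive the function across folds is sketched below. It is an illustration under assumptions, not the author's actual training loop: the toy arrays, the 80-20 split, and the name `val_ids_per_fold` are all hypothetical, and `kfold_swap` is restated in compact form only so the snippet runs on its own. The sketch assumes the arrays returned by one call are passed into the next call, so that each value of k pulls a block of train that has not yet served as the validation set.

```python
import numpy as np

def kfold_swap(train_X, train_Y, val_X, val_Y, k):
    # Compact restatement of the function above, repeated so this sketch is self-contained.
    if k > 0:
        vl = val_X.shape[0]
        assert k * vl <= train_X.shape[0]
        bgn, end = (k - 1) * vl, k * vl
        val_X, train_X[bgn:end] = train_X[bgn:end].copy(), val_X
        val_Y, train_Y[bgn:end] = train_Y[bgn:end].copy(), val_Y
    return train_X, train_Y, val_X, val_Y

# Hypothetical toy data: 10 "images" of shape (2, 2); image i is filled with the value i,
# and its label is i, so it is easy to see which samples end up where.
N = 10
X = np.arange(N, dtype=np.float32).reshape(N, 1, 1) * np.ones((1, 2, 2), dtype=np.float32)
Y = np.arange(N)

# 80-20 train/val split, so k can run from 0 to 4.
train_X, val_X = X[:8].copy(), X[8:].copy()
train_Y, val_Y = Y[:8].copy(), Y[8:].copy()

val_ids_per_fold = []
for k in range(5):
    # Feed the returned arrays back in on the next iteration.
    train_X, train_Y, val_X, val_Y = kfold_swap(train_X, train_Y, val_X, val_Y, k)
    val_ids_per_fold.append(sorted(val_Y.tolist()))
    # ... train on (train_X, train_Y) and evaluate on (val_X, val_Y) here ...

print(val_ids_per_fold)
```

With this chaining, every sample lands in the validation set exactly once over the five folds, and the inputs and labels stay aligned because both are swapped with the same bounds.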