Last active
February 9, 2019 15:57
-
-
Save vfdev-5/05648eaf8ffb71a6f2dd0ee932e22dea to your computer and use it in GitHub Desktop.
k-folds training
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # coding: utf-8 | |
| # Example cross-validation training convnet on CIFAR10 | |
| from tensorpack import * | |
| import tensorflow as tf | |
| import argparse | |
| import numpy as np | |
| import os | |
| import tensorpack.tfutils.symbolic_functions as symbf | |
| from tensorpack.tfutils.summary import * | |
| from tensorpack.dataflow import dataset | |
class Model(ModelDesc):
    """
    Convolutional classifier used for the CIFAR cross-validation example.

    The network is three pairs of 3x3 convolutions (the first two pairs
    followed by max-pooling) and three fully-connected layers, the last of
    which emits ``cifar_classnum`` logits.
    """

    def __init__(self, cifar_classnum):
        """
        Args:
            cifar_classnum: number of output classes of the final layer.
        """
        super(Model, self).__init__()
        self.cifar_classnum = cifar_classnum

    def _get_inputs(self):
        """Declare the two graph inputs: 'input' (images) and 'label'."""
        image_desc = InputDesc(tf.float32, (None, 30, 30, 3), 'input')
        label_desc = InputDesc(tf.int32, (None,), 'label')
        return [image_desc, label_desc]

    def _build_graph(self, inputs):
        """
        Build the network, the summaries and the total cost.

        Args:
            inputs: the (image, label) tensors matching ``_get_inputs``.

        Side effects:
            Stores the sum of the cross-entropy and the weight-decay terms
            in ``self.cost``.
        """
        image, label = inputs
        training = get_current_tower_context().is_training

        # Dropout keeps half of the units while training and all of them
        # otherwise; input images are only summarised while training.
        if training:
            keep_prob = tf.constant(0.5)
            tf.summary.image("train_image", image, 10)
        else:
            keep_prob = tf.constant(1.0)

        # Channels-first layout when a GPU is available, channels-last otherwise.
        data_format = 'NCHW' if tf.test.is_gpu_available() else 'NHWC'
        if data_format == 'NCHW':
            image = tf.transpose(image, [0, 3, 1, 2])

        image = image / 4.0  # just to make range smaller

        with argscope(Conv2D, nl=BNReLU, use_bias=False, kernel_shape=3):
            with argscope([Conv2D, MaxPooling, BatchNorm], data_format=data_format):
                logits = (
                    LinearWrap(image)
                    .Conv2D('conv1.1', out_channel=64)
                    .Conv2D('conv1.2', out_channel=64)
                    .MaxPooling('pool1', 3, stride=2, padding='SAME')
                    .Conv2D('conv2.1', out_channel=128)
                    .Conv2D('conv2.2', out_channel=128)
                    .MaxPooling('pool2', 3, stride=2, padding='SAME')
                    .Conv2D('conv3.1', out_channel=128, padding='VALID')
                    .Conv2D('conv3.2', out_channel=128, padding='VALID')
                    .FullyConnected('fc0', 1024 + 512, nl=tf.nn.relu)
                    .tf.nn.dropout(keep_prob)
                    .FullyConnected('fc1', 512, nl=tf.nn.relu)
                    .FullyConnected('linear', out_dim=self.cifar_classnum, nl=tf.identity)()
                )

        per_sample_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label)
        xent_loss = tf.reduce_mean(per_sample_xent, name='cross_entropy_loss')

        # Track the training error as a moving summary.
        incorrect = symbf.prediction_incorrect(logits, label)
        add_moving_summary(tf.reduce_mean(incorrect, name='train_error'))

        # L2 weight decay on every variable matching 'fc.*/W'.
        decay_loss = regularize_cost('fc.*/W', l2_regularizer(4e-4), name='regularize_loss')
        add_moving_summary(xent_loss, decay_loss)

        # Histogram summaries for every variable matching '.*/W'.
        add_param_summary(('.*/W', ['histogram']))
        self.cost = tf.add_n([xent_loss, decay_loss], name='cost')

    def _get_optimizer(self):
        """Return an Adam optimizer driven by the 'learning_rate' scalar variable."""
        learning_rate = symbf.get_scalar_var('learning_rate', 1e-2, summary=True)
        return tf.train.AdamOptimizer(learning_rate, epsilon=1e-3)
| from sklearn.model_selection import KFold | |
class KFoldsDataset(ProxyDataFlow):
    """
    Expose one side (train or val) of a single K-Fold split of a DataFlow.

    Datapoints are selected by their position in the stream produced by
    ``ds.get_data()``; the selection therefore only maps to a fixed set of
    samples when the wrapped DataFlow yields them in a fixed order.
    """

    def __init__(self, ds, data_type, fold_index, n_folds):
        """
        Args:
            ds (DataFlow): input DataFlow; its ``size()`` is used to build the split
            data_type: 'train' or 'val'
            fold_index: index of the fold to extract, in ``[0, n_folds)``
            n_folds: number of folds for K-Folds train/val dataset extraction

        Raises:
            AssertionError: if ``data_type`` is not 'train'/'val' or
                ``fold_index`` is outside ``[0, n_folds)``.
        """
        # The original check used `or`, which is true for every integer and
        # let out-of-range folds through, leaving `self.indices` undefined.
        assert 0 <= fold_index < n_folds, \
            "fold_index must be in [0, n_folds), got %r with n_folds=%r" % (fold_index, n_folds)
        assert data_type in ['train', 'val'], \
            "data_type must be 'train' or 'val', got %r" % (data_type,)
        super(KFoldsDataset, self).__init__(ds)

        kfs = KFold(n_splits=n_folds)
        for index, (train_indices, val_indices) in enumerate(kfs.split(range(ds.size()))):
            if index == fold_index:
                self.indices = train_indices if data_type == 'train' else val_indices
                break

        # Plain-int set: O(1) membership per datapoint in get_data() instead of
        # a linear scan over the index array.
        self._index_set = frozenset(int(i) for i in self.indices)

    def size(self):
        """Return the number of datapoints in the selected side of the fold."""
        return len(self.indices)

    def get_data(self):
        """Yield the datapoints of the wrapped DataFlow whose position is in the fold."""
        for index, dp in enumerate(self.ds.get_data()):
            if index in self._index_set:
                yield dp
def get_data(train_or_test, val_fold_index=0, n_folds=5, cifar_classnum=10):
    """
    Build the batched DataFlow for one side of a K-Fold split of CIFAR.

    Args:
        train_or_test: 'train' selects the training side of the fold (with
            augmentation and prefetching); any other value selects the
            held-out validation side.
        val_fold_index: index of the fold held out for validation.
        n_folds: total number of folds.
        cifar_classnum: 10 selects ``dataset.Cifar10``, anything else
            ``dataset.Cifar100``.

    Returns:
        A DataFlow producing batches of 128 datapoints (the last, smaller
        batch is kept only for the validation side).
    """
    isTrain = train_or_test == 'train'
    cifar_cls = dataset.Cifar10 if cifar_classnum == 10 else dataset.Cifar100

    # Both sides of the fold must be cut from the SAME dataset, otherwise they
    # are not complementary partitions (previously the validation side was a
    # slice of the test split).  Source shuffling is disabled because
    # KFoldsDataset selects datapoints by stream position, so a shuffled
    # source would yield a different subset on every pass.
    # NOTE(review): the training order is now fixed; add a shuffling stage
    # after the fold selection if per-epoch shuffling is required.
    ds = cifar_cls('train', shuffle=False)
    ds = KFoldsDataset(ds, 'train' if isTrain else 'val',
                       fold_index=val_fold_index, n_folds=n_folds)

    if isTrain:
        augmentors = [
            imgaug.RandomCrop((30, 30)),
            imgaug.Flip(horiz=True),
            imgaug.Brightness(63),
            imgaug.Contrast((0.2, 1.8)),
            imgaug.GaussianDeform(
                [(0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2)],
                (30, 30), 0.2, 3),
            imgaug.MeanVarianceNormalize(all_channel=True),
        ]
    else:
        augmentors = [
            imgaug.CenterCrop((30, 30)),
            imgaug.MeanVarianceNormalize(all_channel=True),
        ]
    ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds, 128, remainder=not isTrain)
    if isTrain:
        ds = PrefetchDataZMQ(ds, 5)
    return ds
def get_config(val_fold_index=0, n_folds=5, cifar_classnum=10):
    """
    Assemble the TrainConfig for training on one cross-validation fold.

    Args:
        val_fold_index: index of the fold used for validation.
        n_folds: total number of folds.
        cifar_classnum: number of classes, forwarded to the model and to
            ``get_data``.

    Returns:
        A TrainConfig with the model, the training DataFlow, the callbacks
        (model saving, validation inference, learning-rate schedule) and
        ``max_epoch=2``.
    """
    # Same fold settings for both the training and the validation flow.
    fold_kwargs = dict(val_fold_index=val_fold_index,
                       n_folds=n_folds,
                       cifar_classnum=cifar_classnum)
    train_flow = get_data(train_or_test='train', **fold_kwargs)
    val_flow = get_data(train_or_test='test', **fold_kwargs)

    def lr_func(current_lr):
        # Stop training once the learning rate is already below 5e-3;
        # otherwise shrink it by a factor of 0.31.
        if current_lr < 5e-3:
            raise StopTraining()
        return current_lr * 0.31

    callbacks = [
        ModelSaver(),
        InferenceRunner(val_flow, ClassificationError()),
        StatMonitorParamSetter('learning_rate', 'val_error', lr_func,
                               threshold=0.001, last_k=10),
    ]
    return TrainConfig(
        model=Model(cifar_classnum),
        dataflow=train_flow,
        callbacks=callbacks,
        max_epoch=2,
    )
import os

# Restrict TensorFlow to the first GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

n_folds = 5

# Train one model per fold, each in its own log directory.
for val_fold_index in range(n_folds):
    # Build every fold in a brand-new graph.  The previous version called
    # tf.reset_default_graph() while inside an active `Graph.as_default()`
    # context, which TensorFlow does not allow for nested graphs.
    with tf.Graph().as_default():
        logger.set_logger_dir(os.path.join('train_log', 'cifar10_fold_%i' % val_fold_index))
        config = get_config(val_fold_index=val_fold_index, n_folds=n_folds, cifar_classnum=10)
        trainer = QueueInputTrainer(config)
        trainer.train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment