Skip to content

Instantly share code, notes, and snippets.

@vfdev-5
Last active February 9, 2019 15:57
Show Gist options
  • Select an option

  • Save vfdev-5/05648eaf8ffb71a6f2dd0ee932e22dea to your computer and use it in GitHub Desktop.

Select an option

Save vfdev-5/05648eaf8ffb71a6f2dd0ee932e22dea to your computer and use it in GitHub Desktop.
k-folds training
# coding: utf-8
# Example cross-validation training convnet on CIFAR10
from tensorpack import *
import tensorflow as tf
import argparse
import numpy as np
import os
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
class Model(ModelDesc):
    """Convnet classifier for CIFAR images cropped to 30x30.

    Args:
        cifar_classnum: number of classes, i.e. the width of the final
            'linear' layer.
    """

    def __init__(self, cifar_classnum):
        super(Model, self).__init__()
        self.cifar_classnum = cifar_classnum

    def _get_inputs(self):
        """Describe the graph inputs: float32 NHWC images and int32 labels."""
        image_desc = InputDesc(tf.float32, (None, 30, 30, 3), 'input')
        label_desc = InputDesc(tf.int32, (None,), 'label')
        return [image_desc, label_desc]

    def _build_graph(self, inputs):
        """Build the network, the summaries and ``self.cost``."""
        image, label = inputs
        training = get_current_tower_context().is_training
        # Dropout keeps half of the units while training and all of them otherwise.
        keep_prob = tf.constant(0.5 if training else 1.0)

        if training:
            tf.summary.image("train_image", image, 10)

        # Inputs arrive as NHWC; switch to NCHW when a GPU is available.
        on_gpu = tf.test.is_gpu_available()
        if on_gpu:
            image = tf.transpose(image, [0, 3, 1, 2])
        data_format = 'NCHW' if on_gpu else 'NHWC'

        image = image / 4.0  # just to make range smaller

        with argscope(Conv2D, nl=BNReLU, use_bias=False, kernel_shape=3):
            with argscope([Conv2D, MaxPooling, BatchNorm], data_format=data_format):
                features = (
                    LinearWrap(image)
                    .Conv2D('conv1.1', out_channel=64)
                    .Conv2D('conv1.2', out_channel=64)
                    .MaxPooling('pool1', 3, stride=2, padding='SAME')
                    .Conv2D('conv2.1', out_channel=128)
                    .Conv2D('conv2.2', out_channel=128)
                    .MaxPooling('pool2', 3, stride=2, padding='SAME')
                    .Conv2D('conv3.1', out_channel=128, padding='VALID')
                    .Conv2D('conv3.2', out_channel=128, padding='VALID')
                )
                logits = (
                    features
                    .FullyConnected('fc0', 1024 + 512, nl=tf.nn.relu)
                    .tf.nn.dropout(keep_prob)
                    .FullyConnected('fc1', 512, nl=tf.nn.relu)
                    .FullyConnected('linear', out_dim=self.cifar_classnum, nl=tf.identity)()
                )

        per_sample_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label)
        xent_loss = tf.reduce_mean(per_sample_xent, name='cross_entropy_loss')

        # Track the training error as a moving-average summary.
        incorrect = symbf.prediction_incorrect(logits, label)
        add_moving_summary(tf.reduce_mean(incorrect, name='train_error'))

        # L2 weight decay on every variable matching 'fc.*/W'.
        wd_cost = regularize_cost('fc.*/W', l2_regularizer(4e-4), name='regularize_loss')
        add_moving_summary(xent_loss, wd_cost)

        add_param_summary(('.*/W', ['histogram']))  # histograms of all W
        self.cost = tf.add_n([xent_loss, wd_cost], name='cost')

    def _get_optimizer(self):
        """Return Adam driven by the 'learning_rate' scalar variable (initially 1e-2)."""
        learning_rate = symbf.get_scalar_var('learning_rate', 1e-2, summary=True)
        return tf.train.AdamOptimizer(learning_rate, epsilon=1e-3)
from sklearn.model_selection import KFold
class KFoldsDataset(ProxyDataFlow):
    """
    Expose one side (train or val) of a single K-Fold split of a DataFlow.

    Datapoints are selected by their position in the stream produced by
    ``ds.get_data()``. The wrapped DataFlow must therefore yield its
    datapoints in the same order on every pass (no shuffling); otherwise
    the selected subset changes between passes.
    """

    def __init__(self, ds, data_type, fold_index, n_folds):
        """
        Args:
            ds (DataFlow): input DataFlow
            data_type (str): 'train' or 'val'
            fold_index (int): index of the split to use, in ``[0, n_folds)``
            n_folds (int): number of folds for K-Folds train/val dataset extraction

        Raises:
            AssertionError: if ``fold_index`` is out of range or ``data_type``
                is not 'train' or 'val'.
        """
        # Range check: the previous `>= 0 or < n_folds` test was always true,
        # so a bad index silently left `self.indices` undefined.
        assert 0 <= fold_index < n_folds, \
            "fold_index must be in [0, %r), got %r" % (n_folds, fold_index)
        assert data_type in ['train', 'val'], \
            "data_type must be 'train' or 'val', got %r" % (data_type,)
        super(KFoldsDataset, self).__init__(ds)

        kfs = KFold(n_splits=n_folds)
        splits = kfs.split(range(ds.size()))
        for index, (train_indices, val_indices) in enumerate(splits):
            if index == fold_index:
                self.indices = train_indices if data_type == 'train' else val_indices
                break

        # Plain-int set for O(1) membership tests in get_data(); testing
        # `index in <ndarray>` would scan the whole array for every datapoint.
        self._selected = frozenset(int(i) for i in self.indices)

    def size(self):
        """Number of datapoints in the selected side of the split."""
        return len(self.indices)

    def get_data(self):
        """Yield the datapoints of the wrapped DataFlow whose position is selected."""
        selected = self._selected
        for index, dp in enumerate(self.ds.get_data()):
            if index in selected:
                yield dp
def get_data(train_or_test, val_fold_index=0, n_folds=5, cifar_classnum=10):
    """
    Build the augmented, batched DataFlow for one side of a K-Fold split.

    Both sides of the split are taken from the CIFAR *training* set, so that
    the validation fold is exactly the part of the data the training folds
    leave out.

    Args:
        train_or_test (str): 'train' selects the training folds (random
            augmentation, shuffled, prefetched). Any other value (e.g. 'test')
            selects the held-out validation fold (center crop only).
        val_fold_index (int): index of the split whose validation fold is held out.
        n_folds (int): number of folds.
        cifar_classnum (int): 10 for CIFAR-10, anything else for CIFAR-100.

    Returns:
        DataFlow: batches of 128 datapoints; the last incomplete batch is kept
        only for the validation side.
    """
    isTrain = train_or_test == 'train'
    cifar_cls = dataset.Cifar10 if cifar_classnum == 10 else dataset.Cifar100

    # KFoldsDataset picks datapoints by their position in the stream, so the
    # source must produce a fixed order: disable the dataset's own shuffling.
    ds = cifar_cls('train', shuffle=False)
    ds = KFoldsDataset(ds, 'train' if isTrain else 'val',
                       fold_index=val_fold_index, n_folds=n_folds)

    if isTrain:
        # Re-introduce randomness in sample order, after the fold was fixed.
        ds = LocallyShuffleData(ds, 10000)
        augmentors = [
            imgaug.RandomCrop((30, 30)),
            imgaug.Flip(horiz=True),
            imgaug.Brightness(63),
            imgaug.Contrast((0.2, 1.8)),
            imgaug.GaussianDeform(
                [(0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2)],
                (30, 30), 0.2, 3),
            imgaug.MeanVarianceNormalize(all_channel=True),
        ]
    else:
        augmentors = [
            imgaug.CenterCrop((30, 30)),
            imgaug.MeanVarianceNormalize(all_channel=True),
        ]

    ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds, 128, remainder=not isTrain)
    if isTrain:
        ds = PrefetchDataZMQ(ds, 5)
    return ds
def get_config(val_fold_index=0, n_folds=5, cifar_classnum=10):
    """
    Assemble the TrainConfig for training on one fold.

    Args:
        val_fold_index (int): fold index forwarded to ``get_data``.
        n_folds (int): number of folds forwarded to ``get_data``.
        cifar_classnum (int): number of classes, forwarded to ``get_data``
            and to ``Model``.

    Returns:
        TrainConfig: model, training dataflow and callbacks, limited to 2 epochs.
    """
    fold_kwargs = dict(val_fold_index=val_fold_index,
                       n_folds=n_folds,
                       cifar_classnum=cifar_classnum)
    # prepare dataset
    train_flow = get_data(train_or_test='train', **fold_kwargs)
    val_flow = get_data(train_or_test='test', **fold_kwargs)

    def decay_or_stop(lr):
        # Handed to StatMonitorParamSetter below: returns the next learning
        # rate, or ends training once the current rate is below 5e-3.
        if lr < 5e-3:
            raise StopTraining()
        return lr * 0.31

    model = Model(cifar_classnum)
    callbacks = [
        ModelSaver(),
        InferenceRunner(val_flow, ClassificationError()),
        StatMonitorParamSetter('learning_rate', 'val_error', decay_or_stop,
                               threshold=0.001, last_k=10),
    ]
    return TrainConfig(
        model=model,
        dataflow=train_flow,
        callbacks=callbacks,
        max_epoch=2,
    )
import os
# Run the whole cross-validation on GPU 0 only.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
n_folds = 5
for val_fold_index in range(n_folds):
    # Train every fold in its own, brand-new default graph.
    # tf.reset_default_graph() must not be used while a graph is installed
    # via `as_default()` (nested graph contexts cannot be cleared that way),
    # so a fresh tf.Graph() per iteration replaces the previous
    # "one outer graph + reset inside the loop" arrangement.
    with tf.Graph().as_default():
        logger.set_logger_dir(os.path.join('train_log', 'cifar10_fold_%i' % val_fold_index))
        config = get_config(val_fold_index=val_fold_index, n_folds=n_folds, cifar_classnum=10)
        trainer = QueueInputTrainer(config)
        trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment