@jinyu121
Created October 3, 2017 07:17
TensorFlow image classification example

TensorFlow classification example

Data preparation

There is no need to convert the data to TFRecord format. All you need are two txt files (one for training, one for evaluation), where each line has the form `image_path class_id`.
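For example, a train.txt might look like the following (the image paths and class ids shown here are purely illustrative):

```
/data/flowers/daisy/001.jpg 0
/data/flowers/rose/002.jpg 1
/data/flowers/tulip/003.jpg 2
```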

dataset.py

```python
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib.data import Dataset
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import convert_to_tensor


class ImageDataGenerator(object):
    '''Build a dataset from a list of filenames.

    Usage:
        dataset = ImageDataGenerator(...)
        dataset = dataset.data
        iterator = dataset.make_one_shot_iterator()
        images, labels = iterator.get_next()
    '''

    def __init__(self, txt_file, batch_size, num_classes,
                 mode=tf.estimator.ModeKeys.TRAIN, image_size=None,
                 one_hot=True, buffer_scale=100):
        self.image_size = image_size
        self.one_hot = one_hot
        self.txt_file = txt_file
        self.num_classes = num_classes
        buffer_size = batch_size * buffer_scale

        # Read the filenames and class ids from the txt file(s)
        self._read_txt_file()
        # Number of samples
        self.data_size = len(self.labels)
        # Convert the lists to tensors
        self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string)
        self.labels = convert_to_tensor(self.labels, dtype=dtypes.int32)
        # Create the dataset
        data = Dataset.from_tensor_slices((self.img_paths, self.labels))
        # Map filenames to image tensors, with parallel reads and buffering
        data = data.map(
            self._parse_function,
            num_threads=8,
            output_buffer_size=buffer_size
        )
        # For training, repeat and shuffle
        if mode == tf.estimator.ModeKeys.TRAIN:
            data = data.repeat()
            data = data.shuffle(buffer_size=buffer_size)
        # Batch the data
        self.data = data.batch(batch_size)

    def _read_txt_file(self):
        """Read the txt file(s) into lists of filenames and labels."""
        self.img_paths = []
        self.labels = []
        if isinstance(self.txt_file, str):
            for line in open(self.txt_file, 'r'):
                items = line.strip().split(' ')
                self.img_paths.append(items[0])
                self.labels.append(int(items[1]))
        elif isinstance(self.txt_file, list):
            for f in self.txt_file:
                for line in open(f, 'r'):
                    items = line.strip().split(' ')
                    self.img_paths.append(items[0])
                    self.labels.append(int(items[1]))
        else:
            raise ValueError('Filename should be a string or a list of strings.')

    def _parse_function(self, filename, label):
        """Map a filename to decoded image data."""
        # Optionally one-hot encode the label
        if self.one_hot:
            lab = tf.one_hot(label, self.num_classes)
        else:
            lab = label
        # Read and decode the image
        img = tf.read_file(filename)
        img = tf.image.decode_jpeg(img, channels=3)
        # Preprocess: scale to zero mean and unit variance
        img = tf.image.per_image_standardization(img)
        # Resize to a fixed size if requested
        if self.image_size is not None:
            img = tf.image.resize_images(img, self.image_size)
        return img, lab
```
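A minimal usage sketch for this class, assuming a list file named `train.txt` in the format described above (the file name, batch size, and class count here are placeholders):

```python
import tensorflow as tf
from dataset import ImageDataGenerator

# Build a training dataset of 32x32 images from a hypothetical train.txt
generator = ImageDataGenerator('train.txt', batch_size=32, num_classes=10,
                               mode=tf.estimator.ModeKeys.TRAIN,
                               image_size=[32, 32])
iterator = generator.data.make_one_shot_iterator()
images, labels = iterator.get_next()

with tf.Session() as sess:
    # Pull one batch; with one_hot=True the labels have shape (32, 10)
    img_batch, lab_batch = sess.run([images, labels])
    print(img_batch.shape, lab_batch.shape)  # (32, 32, 32, 3) (32, 10)
```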
The training script (it imports the class above from dataset.py):

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

import models.official.resnet.resnet_model as resnet_model
from dataset import ImageDataGenerator

# Model-independent parameters
# (data_dir is not actually used below; the file lists are hard-coded in filenames())
flags = tf.app.flags
flags.DEFINE_string(flag_name='data_dir', default_value='/tmp/cifar10_data', docstring='Dataset directory')
flags.DEFINE_string(flag_name='model_dir', default_value='/tmp/cifar10_model', docstring='Directory to store the model')
flags.DEFINE_integer(flag_name='resnet_size', default_value=32, docstring='The size of the ResNet model to use.')
flags.DEFINE_integer(flag_name='train_steps', default_value=100000, docstring='Number of training steps')
flags.DEFINE_integer(flag_name='steps_per_eval', default_value=4000, docstring='Interval between evaluations')
flags.DEFINE_integer(flag_name='batch_size', default_value=128, docstring='Batch size')
FLAGS = flags.FLAGS

HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10
NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

# Scale the learning rate linearly with the batch size. When the batch size
# is 128, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 128
_MOMENTUM = 0.9

# We use a weight decay of 0.0002, which performs better than the 0.0001
# that was originally suggested.
_WEIGHT_DECAY = 2e-4

_BATCHES_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size


def filenames(mode):
    """Return the list file(s) for the given mode."""
    if mode == tf.estimator.ModeKeys.TRAIN:
        return ["/home/haoyu/Datasets/flowers/train.txt"]
    elif mode == tf.estimator.ModeKeys.EVAL:
        return ["/home/haoyu/Datasets/flowers/test.txt"]
    else:
        raise ValueError('Invalid mode: %s' % mode)


def input_fn(mode):
    """Define the input pipeline."""
    dataset = ImageDataGenerator(filenames(mode), FLAGS.batch_size, NUM_CLASSES,
                                 mode=mode, image_size=[HEIGHT, WIDTH])
    dataset = dataset.data
    # Build the iterator
    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    return images, labels


def model_fn(features, labels, mode):
    """Define the model."""
    tf.summary.image('images', features, max_outputs=6)

    network = resnet_model.cifar10_resnet_v2_generator(FLAGS.resnet_size, NUM_CLASSES)
    inputs = tf.reshape(features, [-1, HEIGHT, WIDTH, DEPTH])
    logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Compute the loss and log it for visualization
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    # Build the train op
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        # Decay the learning rate at epochs 100, 150, and 200
        boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
        values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32), boundaries, values)

        # Log and visualize the learning rate
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        # Create the optimizer
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=_MOMENTUM)

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    accuracy = tf.metrics.accuracy(
        tf.argmax(labels, axis=1), predictions['classes'])
    metrics = {'accuracy': accuracy}

    # Log and visualize the model accuracy
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics)


def main(_):
    # This environment variable gives a small speedup
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    cifar_classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.model_dir)

    for cycle in range(FLAGS.train_steps // FLAGS.steps_per_eval):
        tensors_to_log = {
            'learning_rate': 'learning_rate',
            'cross_entropy': 'cross_entropy',
            'train_accuracy': 'train_accuracy'
        }
        logging_hook = tf.train.LoggingTensorHook(
            tensors=tensors_to_log,
            every_n_iter=100)

        cifar_classifier.train(
            input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN),
            steps=FLAGS.steps_per_eval,
            hooks=[logging_hook])

        # Evaluate the model
        eval_results = cifar_classifier.evaluate(
            input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL))
        print(eval_results)


if '__main__' == __name__:
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
```
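Since the parameters are defined through `tf.app.flags`, they can be overridden on the command line, e.g. `python main.py --batch_size=64 --model_dir=/tmp/flowers_model` (the file name main.py is an assumption; use whatever you saved the script as). Note that with the default batch size of 128 and 50000 training examples, `_BATCHES_PER_EPOCH` is about 390.6, so the learning rate drops at roughly steps 39062, 58593, and 78125, all within the default 100000 training steps. The epoch-size constants are inherited from the CIFAR-10 example; if your flowers dataset has a different number of training images, the schedule shifts accordingly.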

ccnyou commented Jul 3, 2018

Nice work. One small suggestion: could you package this as a git project? That would make it easier to use.
