There is no need to convert the data to TFRecord format. You only need two txt files, one for training and one for testing, where each line has the form: image_path class_id.
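For example, a train.txt might look like this (the flower subdirectories and class ids below are hypothetical example entries, not from the gist):

/home/haoyu/Datasets/flowers/daisy/001.jpg 0
/home/haoyu/Datasets/flowers/rose/002.jpg 1
/home/haoyu/Datasets/flowers/tulip/003.jpg 2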
TensorFlow image classification example
dataset.py (the data-loading module imported by the training script below):

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.contrib.data import Dataset
from tensorflow.python.framework import dtypes
from tensorflow.python.framework.ops import convert_to_tensor


class ImageDataGenerator(object):
    '''
    Create a dataset from txt list files of filenames and labels.

    dataset = ImageDataGenerator(...)
    dataset = dataset.data
    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    '''

    def __init__(self, txt_file, batch_size, num_classes,
                 mode=tf.estimator.ModeKeys.TRAIN, image_size=None,
                 one_hot=True, buffer_scale=100):
        self.image_size = image_size
        self.one_hot = one_hot
        self.txt_file = txt_file
        self.num_classes = num_classes
        buffer_size = batch_size * buffer_scale

        # Read the filenames and class labels from the txt file(s)
        self._read_txt_file()
        # Number of samples
        self.data_size = len(self.labels)

        # Convert the lists to tensors
        self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string)
        self.labels = convert_to_tensor(self.labels, dtype=dtypes.int32)

        # Create the dataset
        data = Dataset.from_tensor_slices((self.img_paths, self.labels))

        # Map filenames to decoded images, with parallel threads and an output
        # buffer (TF 1.3 tf.contrib.data API; later versions of tf.data use
        # num_parallel_calls and prefetch instead)
        data = data.map(
            self._parse_function,
            num_threads=8,
            output_buffer_size=buffer_size
        )

        # For training, repeat indefinitely and shuffle
        if mode == tf.estimator.ModeKeys.TRAIN:
            data = data.repeat()
            data = data.shuffle(buffer_size=buffer_size)

        # Batch the data
        self.data = data.batch(batch_size)

    def _read_txt_file(self):
        """Read the txt file(s) into lists of filenames and labels."""
        self.img_paths = []
        self.labels = []
        if isinstance(self.txt_file, str):
            for line in open(self.txt_file, 'r'):
                items = line.split(' ')
                self.img_paths.append(items[0])
                self.labels.append(int(items[1]))
        elif isinstance(self.txt_file, list):
            for f in self.txt_file:
                for line in open(f, 'r'):
                    items = line.split(' ')
                    self.img_paths.append(items[0])
                    self.labels.append(int(items[1]))
        else:
            raise ValueError('Filename should be string or string list.')

    def _parse_function(self, filename, label):
        """Map a filename to decoded image data and a label."""
        # Optionally one-hot encode the label
        if self.one_hot:
            lab = tf.one_hot(label, self.num_classes)
        else:
            lab = label
        # Read and decode the image
        img = tf.read_file(filename)
        img = tf.image.decode_jpeg(img, channels=3)
        # Preprocess: per-image standardization (zero mean, unit variance)
        img = tf.image.per_image_standardization(img)
        # Resize to a fixed size if one was given
        if self.image_size is not None:
            img = tf.image.resize_images(img, self.image_size)
        return img, lab
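As a quick sanity check, the generator can also be exercised on its own, following the usage shown in the class docstring. This is a minimal sketch: the 'train.txt' path is a placeholder for one of your own list files.

import tensorflow as tf
from dataset import ImageDataGenerator

# 'train.txt' is a placeholder list file in the image_path class_id format.
dataset = ImageDataGenerator('train.txt', batch_size=4, num_classes=10,
                             mode=tf.estimator.ModeKeys.TRAIN,
                             image_size=[32, 32])
iterator = dataset.data.make_one_shot_iterator()
images, labels = iterator.get_next()

with tf.Session() as sess:
    imgs, labs = sess.run([images, labels])
    print(imgs.shape)  # (4, 32, 32, 3)
    print(labs.shape)  # (4, 10): labels are one-hot by default

The second file of the gist is the training script itself: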
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

import models.official.resnet.resnet_model as resnet_model
from dataset import ImageDataGenerator

# Parameters that are independent of the model
flags = tf.app.flags
flags.DEFINE_string(flag_name='data_dir', default_value='/tmp/cifar10_data', docstring='Path to the dataset.')
flags.DEFINE_string(flag_name='model_dir', default_value='/tmp/cifar10_model', docstring='Directory where the model is stored.')
flags.DEFINE_integer(flag_name='resnet_size', default_value=32, docstring='The size of the ResNet model to use.')
flags.DEFINE_integer(flag_name='train_steps', default_value=100000, docstring='Total number of training steps.')
flags.DEFINE_integer(flag_name='steps_per_eval', default_value=4000, docstring='Number of training steps between evaluations.')
flags.DEFINE_integer(flag_name='batch_size', default_value=128, docstring='Batch size.')
FLAGS = flags.FLAGS

HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10
NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

# Scale the learning rate linearly with the batch size. When the batch size
# is 128, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 128
_MOMENTUM = 0.9

# We use a weight decay of 0.0002, which performs better than the 0.0001
# that was originally suggested.
_WEIGHT_DECAY = 2e-4

_BATCHES_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size


def filenames(mode):
    if mode == tf.estimator.ModeKeys.TRAIN:
        return ["/home/haoyu/Datasets/flowers/train.txt"]
    elif mode == tf.estimator.ModeKeys.EVAL:
        return ["/home/haoyu/Datasets/flowers/test.txt"]
    else:
        raise ValueError('Invalid mode: %s' % mode)


def input_fn(mode):
    """Define the input pipeline."""
    dataset = ImageDataGenerator(filenames(mode), FLAGS.batch_size, NUM_CLASSES,
                                 mode=mode, image_size=[HEIGHT, WIDTH])
    dataset = dataset.data
    # Build the one-shot iterator
    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    return images, labels


def model_fn(features, labels, mode):
    """Define the model."""
    tf.summary.image('images', features, max_outputs=6)

    network = resnet_model.cifar10_resnet_v2_generator(FLAGS.resnet_size, NUM_CLASSES)
    inputs = tf.reshape(features, [-1, HEIGHT, WIDTH, DEPTH])
    logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Compute the loss and make it visible in TensorBoard
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    # Build the train op
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        # Decay the learning rate at epochs 100, 150, and 200.
        boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
        values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32), boundaries, values)

        # Log and visualize the learning rate
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        # Create the optimizer
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=_MOMENTUM)

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    accuracy = tf.metrics.accuracy(
        tf.argmax(labels, axis=1), predictions['classes'])
    metrics = {'accuracy': accuracy}

    # Log and visualize the training accuracy
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics)


def main(_):
    # Setting this environment variable gives a small speedup
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    cifar_classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.model_dir)

    for cycle in range(FLAGS.train_steps // FLAGS.steps_per_eval):
        tensors_to_log = {
            'learning_rate': 'learning_rate',
            'cross_entropy': 'cross_entropy',
            'train_accuracy': 'train_accuracy'
        }
        logging_hook = tf.train.LoggingTensorHook(
            tensors=tensors_to_log,
            every_n_iter=100)

        cifar_classifier.train(
            input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN),
            steps=FLAGS.steps_per_eval,
            hooks=[logging_hook])

        # Evaluate the model
        eval_results = cifar_classifier.evaluate(
            input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL))
        print(eval_results)


if '__main__' == __name__:
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
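Assuming the two files are saved as dataset.py and train.py (the training script's filename is not given in the gist), the flags defined above can be overridden on the command line, for example:

python train.py --model_dir=/tmp/flowers_model --batch_size=64 --train_steps=100000

Note that the data_dir flag is defined but never used: the paths to train.txt and test.txt are hard-coded in filenames().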
Nice! One small suggestion: could you turn this into a git project? That would make it a bit more convenient to use.