Created
November 29, 2018 10:06
-
-
Save elect000/130acbdb0a3779910082593db4296254 to your computer and use it in GitHub Desktop.
train-test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
config.yaml

training:
  tfrecords:
    - 'imgclassification/dataset/img_dataset_0_[len=600]_train.tfrecord'
    - 'imgclassification/dataset/img_dataset_1_[len=600]_train.tfrecord'
  length:
    - 600
    - 600
validation:
  tfrecords:
    - 'imgclassification/dataset/img_dataset_0_[len=30]_test.tfrecord'
    - 'imgclassification/dataset/img_dataset_1_[len=30]_test.tfrecord'
  length:
    - 30
    - 30
keep_prob: 0.8
train_dir: 'imgclassification/train_log'
num_threads: 4
"""
import math | |
import tensorflow as tf | |
import yaml | |
from tqdm import tqdm | |
def inference(images, config, num_class):
    """Build the CNN classifier graph and return per-class logits.

    :param images: float tensor [batch_size x image_size x image_size x image_channel]
    :param config: model config dict; reads config['keep_prob'] — the probability
                   of KEEPING a unit (so the dropout rate is 1 - keep_prob)
    :param num_class: number of output classes (e.g. 2)
    :return: logits tensor [batch_size, num_class] — unscaled scores, suitable for
             tf.nn.softmax_cross_entropy_with_logits_v2 and tf.argmax
    """
    with tf.variable_scope('conv1') as scope:
        conv = tf.layers.conv2d(
            inputs=images,
            filters=32,
            kernel_size=[3, 3],
            padding='SAME',
            activation=tf.nn.relu
        )
        conv = tf.layers.conv2d(
            inputs=conv,
            filters=64,
            kernel_size=[3, 3],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        # BUG FIX: tf.layers.dropout's `rate` is the fraction of units to DROP,
        # not to keep; passing keep_prob (0.8) directly dropped 80% of units.
        # NOTE(review): these dropout layers use the default training=False, so
        # dropout is inert at run time — confirm whether a `training` flag
        # should be threaded through from the caller.
        drop = tf.layers.dropout(pool, rate=1.0 - config['keep_prob'], name=scope.name)
    with tf.variable_scope('conv2') as scope:
        conv = tf.layers.conv2d(
            inputs=drop,
            filters=128,
            kernel_size=[3, 3],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        conv = tf.layers.conv2d(
            inputs=pool,
            filters=128,
            kernel_size=[2, 2],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        drop = tf.layers.dropout(pool, rate=0.25, name=scope.name)
        conv = tf.layers.conv2d(
            inputs=drop,
            filters=128,
            kernel_size=[2, 2],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        conv = tf.layers.conv2d(
            inputs=pool,
            filters=128,
            kernel_size=[2, 2],
            padding='SAME',
            activation=tf.nn.relu
        )
        pool = tf.layers.max_pooling2d(conv, pool_size=[2, 2], strides=2, padding='SAME')
        drop = tf.layers.dropout(pool, rate=0.25, name=scope.name)
        print(drop.shape)  # debug: expected (?, 1, 1, 128) before flattening
    with tf.variable_scope('fully_connected') as scope:
        # assumes the spatial dims reduce to 1x1 after the five 2x2 poolings
        # (i.e. image_size == 32) — TODO confirm against Image.image_size
        flat = tf.reshape(drop, [-1, 1 * 1 * 128])
        fc = tf.layers.dense(inputs=flat, units=1500, activation=tf.nn.relu)
        drop = tf.layers.dropout(fc, rate=0.5)
        # BUG FIX: emit raw logits (activation=None).  loss() applies
        # softmax_cross_entropy_with_logits_v2, which performs softmax
        # internally; the original softmax here meant softmax ran twice,
        # mis-scaling the loss and flattening gradients.  accuracy()'s
        # argmax is unaffected because softmax is monotonic.
        logits = tf.layers.dense(inputs=drop, units=num_class, activation=None, name=scope.name)
    return logits
def loss(logits, labels, weights):
    """Class-weighted softmax cross-entropy, averaged over the batch.

    :param logits: unscaled class scores, [batch, num_class]
    :param labels: one-hot labels, [batch, num_class]
    :param weights: per-class weight vector, [num_class]
    :return: scalar loss tensor (also logged as the 'cross_entropy' summary)
    """
    # Each example's weight is the weight of its one-hot class.
    per_example_weight = tf.reduce_sum(weights * labels, axis=1)
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
    mean_loss = tf.reduce_mean(per_example_loss * per_example_weight)
    tf.summary.scalar('cross_entropy', mean_loss)
    return mean_loss
def training(loss, learning_rate):
    """Create an Adam train op minimizing `loss`.

    :param loss: scalar loss tensor to minimize
    :param learning_rate: Adam learning rate
    :return: the training op
    """
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=1e-08)
    return optimizer.minimize(loss)
def accuracy(logits, labels):
    """Fraction of examples whose argmax prediction matches the one-hot label.

    :param logits: class scores, [batch, num_class]
    :param labels: one-hot labels, [batch, num_class]
    :return: scalar accuracy tensor in [0, 1]
    """
    predicted_class = tf.argmax(logits, axis=1)
    true_class = tf.argmax(labels, axis=1)
    hits = tf.cast(tf.equal(predicted_class, true_class), tf.float32)
    return tf.reduce_mean(hits)
def get_tfrecord_serialized(tfrecord_path):
    """Read one serialized example from a TFRecord file via the legacy
    TF1 queue-runner input pipeline.

    :param tfrecord_path: path to a single .tfrecord file
    :return: a string tensor holding one serialized tf.Example
    """
    file_queue = tf.train.string_input_producer([tfrecord_path], name='queue')
    _, serialized_example = tf.TFRecordReader().read(file_queue)
    return serialized_example
def parse_records(dataset):
    """Parse one serialized tf.Example into its raw byte-string fields.

    :param dataset: scalar string tensor holding a serialized tf.Example
    :return: (label_bytes, image_bytes) — both still raw; decode downstream
    """
    feature_spec = {
        'label': tf.FixedLenFeature((), tf.string),
        'image': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(dataset, features=feature_spec)
    return parsed['label'], parsed['image']
def _read_images(root_config):
    """Build a dataset map-fn decoding raw (label, image) bytes to float32 tensors.

    Shapes come from root_config: the label becomes a [num_class] one-hot
    vector, the image a [image_size, image_size, image_channel] tensor.

    :param root_config: project config with 'Model' and 'Image' sections
    :return: callable (label_bytes, image_bytes) -> (image, label)
    """
    num_class = root_config['Model']['num_class']
    size = root_config['Image']['image_size']
    channels = root_config['Image']['image_channel']

    def read_images(label, image):
        decoded_label = tf.reshape(tf.decode_raw(label, tf.float32), shape=[num_class])
        decoded_image = tf.reshape(tf.decode_raw(image, tf.float32),
                                   shape=[size, size, channels])
        return decoded_image, decoded_label

    return read_images
def create_dataset_iterator(config, root_config):
    """Build training/validation TFRecord datasets plus reinitializable iterators.

    :param config: model config with 'training'/'validation' tfrecord paths and
                   per-file example counts ('length'), plus 'num_threads'
    :param root_config: project config; reads Model.batch_size and Model.epoch
    :return: (training_dataset, training_iterator,
              validation_dataset, validation_iterator)
    """
    read_image = _read_images(root_config)
    training_dataset = tf.data.TFRecordDataset(config['training']['tfrecords'])
    training_dataset = training_dataset.map(parse_records, config['num_threads'])
    training_dataset = training_dataset.map(read_image, config['num_threads'])
    # BUG FIX: shuffle BEFORE batching.  The original .batch(...).shuffle(...)
    # only permuted whole batches, so examples within a batch kept file order
    # (all-one-class batches with per-class tfrecord files).  Buffer size =
    # full dataset length gives a uniform shuffle.
    training_dataset = training_dataset \
        .shuffle(sum(config['training']['length'])) \
        .batch(root_config['Model']['batch_size']) \
        .repeat(root_config['Model']['epoch'])
    training_iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                                        training_dataset.output_shapes)
    validation_dataset = tf.data.TFRecordDataset(config['validation']['tfrecords'])
    validation_dataset = validation_dataset.map(parse_records, config['num_threads'])
    validation_dataset = validation_dataset.map(read_image, config['num_threads'])
    # Validation uses a single batch holding the whole set, repeated forever.
    validation_dataset = validation_dataset \
        .shuffle(sum(config['validation']['length'])) \
        .batch(sum(config['validation']['length'])) \
        .repeat(-1)
    validation_iterator = tf.data.Iterator.from_structure(validation_dataset.output_types,
                                                          validation_dataset.output_shapes)
    return training_dataset, training_iterator, validation_dataset, validation_iterator
def main():
    """Train the CNN classifier from config-driven TFRecord datasets (TF1 graph mode)."""
    # BUG FIX: yaml.safe_load instead of yaml.load — yaml.load without an
    # explicit Loader is deprecated and can execute arbitrary Python embedded
    # in the config file.
    with open('imgclassification/model/config.yaml', 'r', encoding='utf-8') as yml:
        config = yaml.safe_load(yml)
    with open('imgclassification/config.yaml', 'r', encoding='utf-8') as yml:
        root_config = yaml.safe_load(yml)
    training_dataset, training_iterator, validation_dataset, validation_iterator = create_dataset_iterator(config,
                                                                                                           root_config)
    train_init_op = training_iterator.make_initializer(training_dataset)
    valid_init_op = validation_iterator.make_initializer(validation_dataset)
    training_batch = training_iterator.get_next()
    validation_batch = validation_iterator.get_next()
    # Batches are fed through placeholders so train/validation share one graph.
    images_placeholder = tf.placeholder(tf.float32, shape=(None,
                                                           root_config['Image']['image_size'],
                                                           root_config['Image']['image_size'],
                                                           root_config['Image']['image_channel']))
    labels_placeholder = tf.placeholder(tf.float32, shape=(None,
                                                           root_config['Model']['num_class']))
    weights = tf.constant([1.0, 1.0])  # uniform class weights (assumes 2 classes)
    logits = inference(images_placeholder, config, root_config['Model']['num_class'])
    loss_value = loss(logits, labels_placeholder, weights=weights)
    acc = accuracy(logits, labels_placeholder)
    learning_rate = 1e-4
    with tf.name_scope('train'):
        train_op = training(loss_value, learning_rate)
        acc_summary_train_op = tf.summary.scalar('train_acc', acc)
        loss_summary_train_op = tf.summary.scalar('train_loss', loss_value)
    with tf.name_scope('validation'):  # FIX: scope was misspelled 'valudation'
        acc_summary_val_op = tf.summary.scalar('val_acc', acc)
        loss_summary_val_op = tf.summary.scalar('val_loss', loss_value)
    print('[INFO]: CREATE SESSION')
    with tf.Session() as sess:
        sess.run(train_init_op)
        sess.run(valid_init_op)
        summary_writer = tf.summary.FileWriter(config['train_dir'], sess.graph)
        sess.run(tf.global_variables_initializer())
        step = 0
        while True:
            try:
                step += 1
                # One pass over the training data, holding back one batch for
                # the summary evaluation below.
                for _ in tqdm(
                        range(math.floor(sum(config['training']['length']) / root_config['Model']['batch_size']) - 1)):
                    images, labels = sess.run(training_batch)
                    sess.run(train_op, feed_dict={
                        images_placeholder: images,
                        labels_placeholder: labels
                    })
                images, labels = sess.run(training_batch)
                res = sess.run([acc_summary_train_op, loss_summary_train_op], feed_dict={
                    images_placeholder: images,
                    labels_placeholder: labels,
                })
                for summary in res:
                    summary_writer.add_summary(summary, step)
                images, labels = sess.run(validation_batch)
                res = sess.run([acc_summary_val_op, loss_summary_val_op], feed_dict={
                    images_placeholder: images,
                    labels_placeholder: labels,
                })
                for summary in res:
                    summary_writer.add_summary(summary, step)
            except tf.errors.OutOfRangeError:
                # Training dataset exhausted (epoch repeat limit reached).
                break
        print("[INFO] TRAINING FINISH")


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment