Skip to content

Instantly share code, notes, and snippets.

@solaris33
Last active May 6, 2019 05:48
Show Gist options
  • Save solaris33/d849eb7c98353984f52e02eaacafc326 to your computer and use it in GitHub Desktop.
Save solaris33/d849eb7c98353984f52e02eaacafc326 to your computer and use it in GitHub Desktop.
#-*- coding: utf-8 -*-
"""
TFRecords Example
Reference : https://www.tensorflow.org/tutorials/load_data/tf_records
Author : solaris33
Project URL : http://solarisailab.com/archives/2603
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
# Helper functions that wrap a plain value in the tf.train.Feature type used by tf.Example.
def _bytes_feature(value):
    """Return a tf.train.Feature holding `value` in a one-element bytes_list.

    Args:
        value: A byte string, or a text (unicode) string. Text is encoded as
            UTF-8 first so callers do not have to encode it themselves.

    Returns:
        A tf.train.Feature whose bytes_list contains the single value.
    """
    if isinstance(value, type(u'')):
        # BytesList stores bytes; convert text to its UTF-8 byte representation.
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    """Return a tf.train.Feature holding a float / double `value` in a one-element float_list."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Return a tf.train.Feature holding a bool / enum / int / uint `value` in a one-element int64_list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
# Print the text form of each kind of feature. The comment after each call
# shows the expected output.
print(_bytes_feature(b'test_string'))
# bytes_list {
#   value: "test_string"
# }
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
# bytes_list {
#   value: "test_bytes"
# }
print(_float_feature(np.exp(1)))
# float_list {
#   value: 2.71828174591
# }
print(_int64_feature(True))
# int64_list {
#   value: 1
# }
print(_int64_feature(1))
# int64_list {
#   value: 1
# }

# A feature message can be converted to a binary string with SerializeToString.
feature = _float_feature(np.exp(1))
serialized_feature = feature.SerializeToString()
print(serialized_feature)
def serialize_example(feature0, feature1, feature2, feature3):
    """Build a tf.Example message from four values and return it serialized.

    Args:
        feature0: Value stored under 'feature0' as an int64 feature.
        feature1: Value stored under 'feature1' as an int64 feature.
        feature2: Value stored under 'feature2' as a bytes feature.
        feature3: Value stored under 'feature3' as a float feature.

    Returns:
        The binary string produced by SerializeToString on the tf.train.Example.
    """
    # Wrap every value in the tf.Example-compatible type for its slot.
    wrapped0 = _int64_feature(feature0)
    wrapped1 = _int64_feature(feature1)
    wrapped2 = _bytes_feature(feature2)
    wrapped3 = _float_feature(feature3)

    # Dictionary of feature name -> wrapped value, as tf.train.Features expects.
    features = tf.train.Features(feature={
        'feature0': wrapped0,
        'feature1': wrapped1,
        'feature2': wrapped2,
        'feature3': wrapped3,
    })

    # Assemble the tf.train.Example message and serialize it for writing to a file.
    return tf.train.Example(features=features).SerializeToString()
# Serialize one hand-written observation into a binary string with serialize_example.
serialized_example = serialize_example(False, 4, b'goat', 0.9876)
print(serialized_example)

# Decode the serialized data back into a tf.train.Example message.
example_proto = tf.train.Example()
example_proto.ParseFromString(serialized_example)
print(example_proto)
filename = 'test.tfrecord'

# Number of observations to generate.
n_observations = int(1e4)

# Boolean feature: False or True.
feature0 = np.random.choice([False, True], n_observations)

# Integer feature: values 0 through 4.
feature1 = np.random.randint(0, 5, n_observations)

# String feature: the entry of `strings` selected by each integer in feature1.
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature drawn from a standard normal distribution.
feature3 = np.random.randn(n_observations)

# Serialize every observation as a tf.Example and write it to the tfrecord file.
with tf.python_io.TFRecordWriter(filename) as writer:
    for observation in zip(feature0, feature1, feature2, feature3):
        writer.write(serialize_example(*observation))
# Read the records back from the tfrecord file.
record_iterator = tf.python_io.tf_record_iterator(path=filename)
for string_record in record_iterator:
    example = tf.train.Example.FromString(string_record)
    print(example)
    # For demo purposes, print only the first record and stop.
    break

# Map the example object's features to a Python dictionary.
print(dict(example.features.feature))

# Sample outputs below; the values vary per run because feature3 is random.
feature3_entry = example.features.feature['feature3']
print(feature3_entry)
# float_list {
#   value: 1.63795161247
# }
print(feature3_entry.float_list.value)
# [1.6379516124725342]
# Download the sample images and show them on screen.
cat_in_snow = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')

cat_in_snow_image = Image.open(cat_in_snow)
plt.imshow(cat_in_snow_image)
plt.show()

williamsburg_bridge_image = Image.open(williamsburg_bridge)
plt.imshow(williamsburg_bridge_image, cmap='gray')
plt.show()

# Label for each image, keyed by the local path returned by get_file.
image_labels = {
    cat_in_snow: 0,
    williamsburg_bridge: 1,
}

# As an example, load the cat image's raw bytes and its label. The image
# information (height, width, depth, image_raw) and label are stored later.
# The with-block closes the file handle as soon as the bytes are read.
with open(cat_in_snow, 'rb') as image_file:
    image_string = image_file.read()
label = image_labels[cat_in_snow]

# Open an InteractiveSession; image_example uses it to decode images.
sess = tf.InteractiveSession()
# Build the dictionary of feature values for one image.
def image_example(image_string, label):
    """Create a tf.train.Example describing one JPEG image.

    The image is decoded through the module-level session `sess` only to read
    its shape; the stored 'image_raw' feature is the encoded input itself.
    NOTE(review): every call builds a new tf.image.decode_jpeg op before
    running it.

    Args:
        image_string: Encoded JPEG bytes passed to tf.image.decode_jpeg.
        label: Value stored as the int64 'label' feature.

    Returns:
        A tf.train.Example with 'height', 'width', 'depth', 'label' and
        'image_raw' features.
    """
    decoded_image = sess.run(tf.image.decode_jpeg(image_string))
    image_shape = decoded_image.shape

    # The first three entries of the decoded shape are stored as height, width, depth.
    features = tf.train.Features(feature={
        'height': _int64_feature(image_shape[0]),
        'width': _int64_feature(image_shape[1]),
        'depth': _int64_feature(image_shape[2]),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(image_string),
    })
    return tf.train.Example(features=features)
# Print the first 15 lines of the example's text form, then an ellipsis.
example_text = str(image_example(image_string, label))
preview_lines = example_text.split('\n')[:15]
for line in preview_lines:
    print(line)
print('...')
# Write the example images (cat, bridge) to the images.tfrecords file.
with tf.python_io.TFRecordWriter('images.tfrecords') as writer:
    for filename, label in image_labels.items():
        # Read inside a with-block so each image file is closed right away
        # instead of staying open until garbage collection.
        with open(filename, 'rb') as image_file:
            image_string = image_file.read()
        tf_example = image_example(image_string, label)
        writer.write(tf_example.SerializeToString())
# Dictionary describing each stored feature: four scalar int64 values plus
# the scalar string holding the raw image bytes.
image_feature_description = {
    name: tf.FixedLenFeature([], tf.int64)
    for name in ('height', 'width', 'depth', 'label')
}
image_feature_description['image_raw'] = tf.FixedLenFeature([], tf.string)
def _parse_image_function(example_proto):
    """Parse one serialized tf.Example using image_feature_description.

    Args:
        example_proto: The serialized tf.Example handed to tf.parse_single_example.

    Returns:
        The value returned by tf.parse_single_example for that record.
    """
    parsed_features = tf.parse_single_example(
        serialized=example_proto,
        features=image_feature_description,
    )
    return parsed_features
# Create a tf.data.TFRecordDataset over the written file and parse each record.
raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
print(parsed_image_dataset)

# Iterate over the dataset and show every image on screen.
iterator = parsed_image_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
try:
    while True:
        image_features = sess.run(next_element)

        # Print the height, width and depth values.
        print('height :', image_features['height'])
        print('width :', image_features['width'])
        print('depth :', image_features['depth'])

        # Convert the image, stored as an encoded binary string, to a PIL Image.
        encoded_jpg_io = io.BytesIO(image_features['image_raw'])
        image = Image.open(encoded_jpg_io)

        # Images with a depth of 1 are drawn with the gray colormap.
        if image_features['depth'] == 1:
            plt.imshow(image, cmap='gray')
        else:
            plt.imshow(image)
        plt.show()
except tf.errors.OutOfRangeError:
    # Normal loop exit: sess.run raises this once the one-shot iterator is exhausted.
    pass
finally:
    # Nothing after this point uses the session, so release its resources.
    sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment