Skip to content

Instantly share code, notes, and snippets.

@luuil
Last active November 18, 2021 14:32
Show Gist options
  • Save luuil/6d17831019bfe40cc2181406c3741fce to your computer and use it in GitHub Desktop.
Save luuil/6d17831019bfe40cc2181406c3741fce to your computer and use it in GitHub Desktop.
TFRecord creation and inspection utilities, e.g. counting records and visualizing their contents.
import io
import os
from typing import Callable

import numpy as np
import tensorflow as tf
from PIL import Image
def bytes_feature(value):
    """Wrap a single string / bytes value in a bytes_list Feature."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList cannot unpack a string held by an EagerTensor, so pull
        # the raw value out first.
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def int64_feature(value):
    """Wrap a single bool / enum / int / uint value in an int64_list Feature."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
def read_txt(file):
    """Return the lines of a text file with trailing whitespace stripped.

    Args:
        file: Path of the text file to read.

    Returns:
        A list with one entry per line; blank lines become empty strings.

    Raises:
        AssertionError: If ``file`` does not exist.
    """
    assert os.path.exists(file), f'not exists: {file}'
    with open(file) as handle:
        # Iterating the handle yields the same lines as readlines().
        return list(map(str.rstrip, handle))
def serialize_tfrecord_example(size=(256, 320), scale=1.2):
    """Build a serializer that packs an image/mask pair into a tf.train.Example.

    Args:
        size: Base target size. After scaling it is handed to
            ``Image.resize``, which interprets it as ``(width, height)``.
        scale: Factor applied to ``size`` before resizing.

    Returns:
        A function ``(pimg, pmask) -> bytes`` that loads both files, resizes
        them and returns the serialized Example. The Example stores
        ``height``, ``width``, the JPEG-encoded ``image`` and ``mask``, and
        both paths.
    """
    # Pillow expects a plain tuple of ints; the size is identical for every
    # pair, so compute it once.
    target_size = tuple(int(v) for v in (np.array(size) * scale).astype(int))

    def _normalize(path: str) -> str:
        # Index files may contain Windows-style separators; convert them on
        # every platform where backslash is not a path separator.
        return path if os.name == 'nt' else path.replace('\\', '/')

    def _to_jpg_bytes(pil: Image.Image) -> bytes:
        with io.BytesIO() as buf:
            pil.save(buf, format='JPEG')
            return buf.getvalue()

    def _serialize(pimg: str, pmask: str) -> bytes:
        pimg, pmask = _normalize(pimg), _normalize(pmask)

        # convert('RGB') drops an alpha channel if present and also handles
        # grayscale / palette inputs, which plain channel slicing does not.
        img = Image.open(pimg).convert('RGB')

        if os.path.exists(pmask):
            mask = Image.open(pmask).convert('L')
        else:
            print(f'gt not exist, create zeros for: {pimg}')
            # All-zero single-channel mask matching the source image size.
            mask = Image.new('L', img.size)

        img = img.resize(target_size)
        mask = mask.resize(target_size)

        w, h = img.size
        feature = {
            'height': int64_feature(h),
            'width': int64_feature(w),
            'image': bytes_feature(_to_jpg_bytes(img)),  # must be jpg
            # NOTE(review): JPEG is lossy, so mask values may not survive
            # exactly; confirm this is acceptable for the consumer.
            'mask': bytes_feature(_to_jpg_bytes(mask)),
            'image_path': bytes_feature(pimg.encode()),
            'mask_path': bytes_feature(pmask.encode()),
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    return _serialize
class TFRecordCreator(object):
    """Write image/mask pairs into one or several TFRecord files.

    Args:
        out: Output TFRecord path. With more than one split the extension is
            replaced by ``.<index>.tfrecord`` for each shard.
        example_serialize_func: Callable ``(image_path, mask_path)`` returning
            the serialized record to write.
        splits: Number of output files to spread the records over.

    Raises:
        ValueError: If ``splits`` is smaller than 1.
    """

    def __init__(self, out, example_serialize_func: Callable, splits=1):
        if splits < 1:
            raise ValueError(f'splits must be >= 1, got {splits}')
        self._out = out
        self._serialize_func = example_serialize_func
        self._splits = splits

    @staticmethod
    def chunks(lst, n):
        """Yield ``n`` successive chunks of ``lst`` whose sizes differ by at most one."""
        k, m = divmod(len(lst), n)
        for i in range(n):
            # The first ``m`` chunks receive one extra element each.
            yield lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]

    def _create_from_list(self, image_list, mask_list, out_tfr):
        """Serialize every (image, mask) pair and write it to ``out_tfr``."""
        with tf.io.TFRecordWriter(out_tfr) as writer:
            for image_path, mask_path in zip(image_list, mask_list):
                writer.write(self._serialize_func(image_path, mask_path))

    def create_splits(self, image_list, mask_list):
        """Distribute the pairs over the configured number of TFRecord files.

        Raises:
            ValueError: If the two lists differ in length. Chunking them
                independently would otherwise pair images with the wrong masks.
        """
        if len(image_list) != len(mask_list):
            raise ValueError(
                f'image/mask count mismatch: {len(image_list)} != {len(mask_list)}')

        if self._splits == 1:
            tfr_outs = [self._out]
        else:
            name = os.path.splitext(self._out)[0]
            tfr_outs = [f'{name}.{i}.tfrecord' for i in range(self._splits)]

        image_chunks = self.chunks(image_list, self._splits)
        mask_chunks = self.chunks(mask_list, self._splits)
        for img_lst, msk_lst, tfr in zip(image_chunks, mask_chunks, tfr_outs):
            self._create_from_list(img_lst, msk_lst, tfr)

    def create_from_image_index_file(self, image_index_file, label_index_file=None):
        """Create the TFRecord file(s) from index files listing one path per line.

        Args:
            image_index_file: Text file listing the image paths.
            label_index_file: Text file listing the mask paths. Defaults to
                ``image_index_file`` with its extension replaced by ``.label``.
        """
        image_list = read_txt(image_index_file)
        if label_index_file is None:
            label_file = os.path.splitext(image_index_file)[0] + '.label'
        else:
            label_file = label_index_file
        mask_list = read_txt(label_file)
        if not mask_list:
            # No labels at all: use empty mask paths, one per image.
            mask_list = [''] * len(image_list)
        self.create_splits(image_list, mask_list)
def main_tfr_create(pair=None):
    """Create three TFRecord shards from the sample index files.

    Args:
        pair: Unused. Optional so the function can be called without
            arguments, as the script entry point does.
    """
    out = 'x.tfrecord'
    image_file = 'x.txt'
    mask_file = 'x.label'
    create = TFRecordCreator(out, serialize_tfrecord_example(), 3)
    create.create_from_image_index_file(image_file, mask_file)


if __name__ == '__main__':
    main_tfr_create()
import os
from typing import Callable
import tensorflow as tf
class TFRecordInspector(object):
    """Count and iterate over the records stored in a TFRecord file.

    NOTE(review): relies on graph-mode APIs (``tf.Session``,
    ``make_one_shot_iterator``, ``tf.io.tf_record_iterator``); presumably
    this targets TensorFlow 1.x, confirm before running on 2.x.
    """

    def __init__(self, tfr) -> None:
        super().__init__()
        assert os.path.exists(tfr), f'not exists: {tfr}'
        self._tfr = tfr

    def records_count(self) -> int:
        """Return the number of records by iterating over the whole file."""
        return sum(1 for _ in tf.io.tf_record_iterator(self._tfr))

    def inspect(self, example_parse_func: Callable, batch=1):
        """Yield parsed batches until the dataset is exhausted.

        Args:
            example_parse_func: Function mapped over each serialized example.
            batch: Number of examples per yielded batch.

        Yields:
            The result of evaluating one batch in the session.
        """
        pipeline = tf.data.TFRecordDataset(self._tfr)
        pipeline = pipeline.map(example_parse_func).batch(batch)
        next_batch = pipeline.make_one_shot_iterator().get_next()

        with tf.Session() as sess:
            b = 0
            while True:
                try:
                    d = sess.run(next_batch)
                except tf.errors.OutOfRangeError:
                    # Raised once the one-shot iterator has no more data.
                    break
                b += 1
                print(f'batch({b}): {type(d)}')
                yield d
def parse_tfrecord_example(debug=False):
    """Create a parser that turns a serialized Example into image/mask tensors.

    Args:
        debug: When true, the parser additionally returns the stored
            ``(image_path, mask_path)`` pair.

    Returns:
        A function taking one serialized example and returning
        ``(image, mask)`` as float32 tensors reshaped to
        ``[height, width, 3]`` and ``[height, width, 1]``; with ``debug`` a
        third element ``(image_path, mask_path)`` is appended.
    """
    int_keys = ('height', 'width')
    string_keys = ('image', 'mask', 'image_path', 'mask_path')
    schema = {key: tf.io.FixedLenFeature([], tf.int64) for key in int_keys}
    schema.update(
        {key: tf.io.FixedLenFeature([], tf.string) for key in string_keys})

    def _to_float_tensor(encoded, rows, cols, channels):
        # Decode the stored bytes, pin the shape, then convert to float32.
        decoded = tf.image.decode_image(encoded)
        shaped = tf.reshape(decoded, [rows, cols, channels])
        return tf.cast(shaped, dtype=tf.float32)

    def _parse(example_proto):
        record = tf.io.parse_single_example(example_proto, schema)
        rows = tf.cast(record['height'], tf.int32)
        cols = tf.cast(record['width'], tf.int32)
        image = _to_float_tensor(record['image'], rows, cols, 3)
        mask = _to_float_tensor(record['mask'], rows, cols, 1)
        if not debug:
            return image, mask
        return image, mask, (record['image_path'], record['mask_path'])

    return _parse
def main_tfrecord_parse():
    """Show every image/mask pair of the sample TFRecord file with OpenCV.

    Each pair is displayed until a key is pressed. With ``debug`` enabled the
    stored ``(image_path, mask_path)`` of each example is printed as well.
    """
    # Local imports: only this visualization entry point needs them.
    import cv2
    import numpy as np

    tfr = r"x.tfrecord"
    debug = False
    tfri = TFRecordInspector(tfr)
    for d in tfri.inspect(parse_tfrecord_example(debug=debug), batch=2):
        if debug:
            # The paths arrive as a nested (image_paths, mask_paths) tuple;
            # pair them per example instead of zipping the tuple itself.
            images, masks, (image_paths, mask_paths) = d
            paths = list(zip(image_paths, mask_paths))
        else:
            images, masks = d
            paths = [None] * len(images)

        for i, m, p in zip(images, masks, paths):
            if debug:
                print(p)
            cv2.imshow('img', cv2.cvtColor(np.uint8(i), cv2.COLOR_RGB2BGR))
            cv2.imshow('mask', np.uint8(m))
            cv2.waitKey(0)
def main_tfr_count():
    """Print the sample TFRecord path followed by its record count."""
    tfr_path = r"x.tfrecord"
    inspector = TFRecordInspector(tfr_path)
    total = inspector.records_count()
    print(tfr_path, total)


# Script entry point: visualize the records, then print how many there are.
if __name__ == '__main__':
    main_tfrecord_parse()
    main_tfr_count()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment