Last active
November 18, 2021 14:32
-
-
Save luuil/6d17831019bfe40cc2181406c3741fce to your computer and use it in GitHub Desktop.
TFRecord creation and inspection utilities, e.g. counting records and visualizing the stored examples.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io
import os
from typing import Callable

import numpy as np
import tensorflow as tf
from PIL import Image
def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``.

    Args:
        value: The payload to store. When it has the same type as
            ``tf.constant(0)`` it is first converted with ``.numpy()``.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``value`` as its
        only element.
    """
    tensor_type = type(tf.constant(0))
    if isinstance(value, tensor_type):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def int64_feature(value):
    """Wrap a single bool / enum / int / uint value in a ``tf.train.Feature``.

    Args:
        value: The value to store as the only element of the int64 list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``value``.
    """
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def read_txt(file):
    """Read a text file and return its lines without trailing whitespace.

    Args:
        file: Path of the text file to read.

    Returns:
        A list with one string per line of the file; trailing whitespace
        (including the newline) is stripped from every entry.

    Raises:
        AssertionError: If ``file`` does not exist.
    """
    assert os.path.exists(file), f'not exists: {file}'
    with open(file) as handle:
        return [raw.rstrip() for raw in handle]
def serialize_tfrecord_example(size=(256, 320), scale=1.2):
    """Build a function that serializes one image/mask pair as a ``tf.train.Example``.

    Args:
        size: Base target size. It is multiplied element-wise by ``scale``,
            truncated to integers and handed to ``PIL.Image.Image.resize``.
        scale: Factor applied to ``size`` before resizing.

    Returns:
        A callable ``_serialize(pimg, pmask) -> bytes`` that loads the image
        and mask, resizes both, JPEG-encodes them and returns the serialized
        ``tf.train.Example``.
    """
    import sys

    def _is_linux():
        # The original code called an `is_linux()` helper that is neither
        # defined nor imported in this module, so it raised NameError.
        return sys.platform.startswith('linux')

    def _path(p1, p2):
        """Replace backslashes with forward slashes in both paths on Linux."""
        if not _is_linux():
            return p1, p2
        return p1.replace('\\', '/'), p2.replace('\\', '/')

    def _to_jpg_bytes(pil: Image.Image) -> bytes:
        """Encode a PIL image as JPEG and return the encoded bytes."""
        with io.BytesIO() as buf:
            pil.save(buf, format='JPEG')
            return buf.getvalue()

    def _serialize(pimg: str, pmask: str) -> bytes:
        """Serialize the image at ``pimg`` and the mask at ``pmask``.

        When ``pmask`` does not exist an all-zero mask with the image's size
        is used instead.
        """
        pimg, pmask = _path(pimg, pmask)
        # resize() wants a plain tuple of Python ints, not a numpy array.
        sz = tuple(int(v) for v in (np.array(size) * scale).astype(int))

        # convert('RGB') drops an alpha channel if there is one and also
        # handles grayscale/palette images, for which slicing the array with
        # [..., :3] would cut the width axis instead of the channel axis.
        # The context manager closes the underlying file handle.
        with Image.open(pimg) as src:
            img = src.convert('RGB')

        if os.path.exists(pmask):
            with Image.open(pmask) as src:
                mask = src.convert('L')
        else:
            print(f'gt not exist, create zeros for: {pimg}')
            mask = Image.new('L', img.size, 0)

        img = img.resize(sz)
        mask = mask.resize(sz)
        w, h = img.size

        feature = {
            'height': int64_feature(h),
            'width': int64_feature(w),
            'image': bytes_feature(_to_jpg_bytes(img)),  # must be jpg
            # NOTE(review): JPEG is lossy, so mask values may not round-trip
            # exactly; the format is kept so existing readers keep working.
            'mask': bytes_feature(_to_jpg_bytes(mask)),
            'image_path': bytes_feature(pimg.encode()),
            'mask_path': bytes_feature(pmask.encode()),
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    return _serialize
class TFRecordCreator(object):
    """Write image/mask pairs into one or more TFRecord files.

    Args:
        out: Output TFRecord path. With more than one split, the extension is
            dropped and the files are named ``<name>.<i>.tfrecord``.
        example_serialize_func: Callable taking ``(image_path, mask_path)``
            and returning the serialized record to write.
        splits: Number of output files to spread the records over.

    Raises:
        ValueError: If ``splits`` is smaller than 1.
    """

    def __init__(self, out, example_serialize_func: Callable, splits=1):
        if splits < 1:
            raise ValueError(f'splits must be >= 1, got: {splits}')
        self._out = out
        self._serialize_func = example_serialize_func
        self._splits = splits

    @staticmethod
    def chunks(lst, n):
        """Yield ``n`` consecutive slices of ``lst`` with near-equal length.

        The first ``len(lst) % n`` slices hold one element more than the
        rest; slices are empty when ``lst`` has fewer than ``n`` elements.
        """
        k, m = divmod(len(lst), n)
        for i in range(n):
            yield lst[i*k+min(i, m):(i+1)*k+min(i+1, m)]

    def _create_from_list(self, image_list, mask_list, out_tfr):
        """Serialize every image/mask pair and write it to ``out_tfr``."""
        with tf.io.TFRecordWriter(out_tfr) as writer:
            for image_path, mask_path in zip(image_list, mask_list):
                writer.write(self._serialize_func(image_path, mask_path))

    def create_splits(self, image_list, mask_list):
        """Distribute the pairs over the configured number of TFRecord files.

        Raises:
            ValueError: If the two lists differ in length. Both lists are
                chunked independently, so a mismatch would pair images with
                the wrong masks without any error.
        """
        if len(image_list) != len(mask_list):
            raise ValueError(
                f'image/mask count mismatch: {len(image_list)} != {len(mask_list)}')

        if self._splits == 1:
            tfr_outs = [self._out]
        else:
            name = os.path.splitext(self._out)[0]
            tfr_outs = [f'{name}.{i}.tfrecord' for i in range(self._splits)]

        img_lst_gen = self.chunks(image_list, self._splits)
        msk_lst_gen = self.chunks(mask_list, self._splits)
        for img_lst, msk_lst, tfr in zip(img_lst_gen, msk_lst_gen, tfr_outs):
            self._create_from_list(img_lst, msk_lst, tfr)

    def create_from_image_index_file(self, image_index_file, label_index_file=None):
        """Create the TFRecord file(s) from index files listing one path per line.

        Args:
            image_index_file: Text file with one image path per line.
            label_index_file: Text file with one mask path per line. Defaults
                to ``image_index_file`` with its extension replaced by
                ``.label``. If the file is empty, every image is paired with
                an empty mask path.
        """
        image_list = read_txt(image_index_file)
        if label_index_file is None:
            label_file = os.path.splitext(image_index_file)[0] + '.label'
        else:
            label_file = label_index_file
        mask_list = read_txt(label_file)
        if not mask_list:
            mask_list = [''] * len(image_list)
        self.create_splits(image_list, mask_list)
def main_tfr_create(pair=None):
    """Create TFRecord output from the ``x.txt`` / ``x.label`` index files.

    The records are written through ``TFRecordCreator`` with 3 splits and
    ``x.tfrecord`` as the output path.

    Args:
        pair: Unused. It now defaults to ``None`` because the script entry
            point calls this function without arguments, which raised
            ``TypeError`` while the parameter was required.
    """
    out = 'x.tfrecord'
    image_file = 'x.txt'
    mask_file = 'x.label'
    create = TFRecordCreator(out, serialize_tfrecord_example(), 3)
    create.create_from_image_index_file(image_file, mask_file)
# Script entry point: build the TFRecord file(s) when run directly.
if __name__ == "__main__":
    main_tfr_create()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from typing import Callable | |
import tensorflow as tf | |
class TFRecordInspector(object):
    """Inspect a TFRecord file: count its records or iterate parsed batches.

    Args:
        tfr: Path of the TFRecord file; it must exist.

    Raises:
        AssertionError: If ``tfr`` does not exist.
    """

    def __init__(self, tfr) -> None:
        super().__init__()
        assert os.path.exists(tfr), f'not exists: {tfr}'
        self._tfr = tfr

    def records_count(self) -> int:
        """Return the number of records yielded by ``tf.io.tf_record_iterator``."""
        return sum(1 for _ in tf.io.tf_record_iterator(self._tfr))

    def inspect(self, example_parse_func: Callable, batch=1):
        """Yield the evaluated batches of the file one by one.

        Args:
            example_parse_func: Function mapped over every serialized record
                of the dataset.
            batch: Batch size applied to the mapped dataset.

        Yields:
            The result of ``sess.run`` for each batch, until the iterator
            signals ``tf.errors.OutOfRangeError``.
        """
        records = tf.data.TFRecordDataset(self._tfr)
        batched = records.map(example_parse_func).batch(batch)
        next_batch = batched.make_one_shot_iterator().get_next()

        # Batches are evaluated inside a tf.Session; the session stays open
        # for as long as the generator is being consumed.
        with tf.Session() as sess:
            batch_no = 0
            while True:
                try:
                    d = sess.run(next_batch)
                except tf.errors.OutOfRangeError:
                    break
                batch_no += 1
                print(f'batch({batch_no}): {type(d)}')
                yield d
def parse_tfrecord_example(debug=False):
    """Build a function that parses one serialized example into tensors.

    Args:
        debug: When true, the returned function additionally returns the
            stored ``(image_path, mask_path)`` pair.

    Returns:
        A callable taking a serialized example and returning
        ``(image, mask)`` or, with ``debug``, ``(image, mask, (image_path,
        mask_path))``. Image and mask are cast to ``tf.float32``.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.int64) for key in ('height', 'width')
    }
    feature_spec.update({
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ('image', 'mask', 'image_path', 'mask_path')
    })

    def _decode(features):
        """Decode image and mask from a parsed example.

        Args:
            features: parsed tfr example
        Return: img, mask (plus the path pair when ``debug`` is set)
        """
        raw_image = tf.image.decode_image(features['image'])
        raw_mask = tf.image.decode_image(features['mask'])

        rows = tf.cast(features['height'], tf.int32)
        cols = tf.cast(features['width'], tf.int32)

        # The stored height/width fix the shapes: 3 channels for the image,
        # 1 channel for the mask.
        shaped_image = tf.reshape(raw_image, [rows, cols, 3])
        shaped_mask = tf.reshape(raw_mask, [rows, cols, 1])

        image = tf.cast(shaped_image, dtype=tf.float32)
        mask = tf.cast(shaped_mask, dtype=tf.float32)

        if not debug:
            return image, mask
        return image, mask, (features['image_path'], features['mask_path'])

    def _parse(example_proto):
        """Parse a serialized example with the feature spec and decode it."""
        parsed = tf.io.parse_single_example(example_proto, feature_spec)
        return _decode(parsed)

    return _parse
def main_tfrecord_parse():
    """Display every image/mask pair stored in ``x.tfrecord`` with OpenCV.

    Each pair is shown in the ``img`` and ``mask`` windows and the function
    waits for a key press before moving on. With ``debug`` enabled, the
    stored paths of each pair are printed first.
    """
    import cv2
    # numpy is not imported at the top of this script, so import it here.
    import numpy as np

    tfr = r"x.tfrecord"
    debug = False
    tfri = TFRecordInspector(tfr)
    for batch in tfri.inspect(parse_tfrecord_example(debug=debug), batch=2):
        if debug:
            # The third element is a pair of per-sample arrays
            # (image_paths, mask_paths). Zipping it directly against the
            # images would misalign the paths and cap the batch at 2 items,
            # so pair the paths per sample first.
            images, masks, (image_paths, mask_paths) = batch
            paths = list(zip(image_paths, mask_paths))
        else:
            images, masks = batch
            paths = None

        for idx, (i, m) in enumerate(zip(images, masks)):
            if paths is not None:
                print(paths[idx])
            cv2.imshow('img', cv2.cvtColor(np.uint8(i), cv2.COLOR_RGB2BGR))
            cv2.imshow('mask', np.uint8(m))
            cv2.waitKey(0)

    # Close the preview windows once all batches have been shown.
    cv2.destroyAllWindows()
def main_tfr_count():
    """Print the path ``x.tfrecord`` followed by its number of records."""
    tfr_path = r"x.tfrecord"
    count = TFRecordInspector(tfr_path).records_count()
    print(tfr_path, count)
# Script entry point: visualize the stored examples, then print the record count.
if __name__ == "__main__":
    main_tfrecord_parse()
    main_tfr_count()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment