Created
February 20, 2020 14:40
-
-
Save dlibenzi/c9868a1090f6f8ef9d79d2cfcbadd8ab to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from PIL import Image | |
import numpy as np | |
import hashlib | |
import os | |
import sys | |
import torch | |
import torch_xla.utils.tf_record_reader as tfrr | |
# Sample dump of a single record, one "label value" pair per line, in the
# same form readem() prints each example. Kept purely as a reference for the
# keys the rest of this script reads; nothing below uses this variable.
a = """
image/class/label tensor([82])
image/class/synset n01796340
image/channels tensor([3])
image/object/bbox/label tensor([], dtype=torch.int64)
image/width tensor([900])
image/format JPEG
image/height tensor([600])
image/class/text ptarmigan
image/object/bbox/ymin tensor([])
image/encoded tensor([ -1, -40, -1, ..., -30, -1, -39], dtype=torch.int8)
image/object/bbox/ymax tensor([])
image/object/bbox/xmin tensor([])
image/filename n01796340_812.JPEG
image/object/bbox/xmax tensor([])
image/colorspace RGB
"""
def decode(ex):
    """Decode the encoded image stored in one example.

    Prints the MD5 digest of the encoded image bytes as a side effect
    (``HASH = <hex>``).

    Args:
        ex: Mapping for a single example. The keys read here are
            ``'image/width'`` and ``'image/height'`` (values exposing
            ``.item()``), ``'image/encoded'`` (value exposing ``.numpy()``),
            and ``'image/colorspace'`` / ``'image/format'`` (strings).

    Returns:
        A ``(tensor, image)`` tuple: the decoded pixels as a torch tensor and
        the corresponding PIL image.
    """
    w = ex['image/width'].item()
    h = ex['image/height'].item()
    imgb = ex['image/encoded'].numpy().tobytes()

    # Fingerprint of the raw encoded payload, printed for manual comparison.
    print('HASH = {}'.format(hashlib.md5(imgb).hexdigest()))

    # The stored format name (e.g. 'JPEG') is lower-cased to select the PIL
    # decoder; 'RGB' and None are forwarded to it as decoder arguments.
    image = Image.frombytes(ex['image/colorspace'], (w, h), imgb,
                            ex['image/format'].lower(), 'RGB', None)

    # np.array() copies the pixels into a writable buffer. np.asarray() on a
    # PIL image yields a read-only array, which torch.from_numpy() warns
    # about and which makes in-place writes to the tensor undefined.
    npa = np.array(image)
    return torch.from_numpy(npa), image
def readem(path, img_path=None):
    """Read every example from a TFRecord file, print it, and decode it.

    For each example, all ``label<TAB>value`` pairs are printed and the image
    is decoded with ``decode()``. When ``img_path`` is given, each decoded
    image is also saved there as ``<index>.jpg``, where the index is the
    zero-based position of the example in the file. The total number of
    decoded examples is printed at the end.

    Args:
        path: Path passed to ``TfRecordReader``.
        img_path: Optional directory for the decoded images. It is created
            if it does not exist. When falsy, no images are written.
    """
    # Fields the reader should hand back as strings rather than tensors.
    transforms = {
        'image/filename': 'STR',
        'image/class/synset': 'STR',
        'image/format': 'STR',
        'image/class/text': 'STR',
        'image/colorspace': 'STR',
    }
    r = tfrr.TfRecordReader(path, compression='', transforms=transforms)

    if img_path:
        # Without this, image.save() fails with FileNotFoundError on the
        # first example whenever the output directory is missing.
        os.makedirs(img_path, exist_ok=True)

    count = 0
    while True:
        ex = r.read_example()
        if not ex:
            # A falsy result from the reader ends the loop.
            break
        print('\n')
        for lbl, data in ex.items():
            print('{}\t{}'.format(lbl, data))
        _, image = decode(ex)
        if img_path:
            image.save(os.path.join(img_path, str(count) + '.jpg'))
        count += 1
    print('\n\nDecoded {} samples'.format(count))
if __name__ == '__main__':
    # Only run when executed as a script, and report a missing argument
    # with a usage message instead of a bare IndexError.
    if len(sys.argv) < 2:
        sys.exit('Usage: {} TFRECORD_PATH'.format(sys.argv[0]))
    readem(sys.argv[1], img_path='/tmp/tf_images')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment