Last active
July 8, 2020 16:23
-
-
Save andfoy/c9b75470c6557771d55d96ed40dbafd4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import unittest | |
import sys | |
import torch | |
import torchvision | |
from PIL import Image | |
from torchvision.io.image import read_png, decode_png, read_jpeg, decode_jpeg | |
import numpy as np | |
# Root of the test fixtures: an "assets" folder next to this file
# (absolute, so it does not depend on the current working directory).
IMAGE_ROOT = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
)
# The assets/fakedata/imagefolder subtree, searched by the PNG tests.
IMAGE_DIR = os.path.join(
    IMAGE_ROOT,
    "fakedata",
    "imagefolder",
)
def get_images(directory, img_ext):
    """Yield paths of files under *directory* whose extension equals *img_ext*.

    The tree is walked recursively with ``os.walk``. Sub-directories and file
    names are visited in sorted order so the yielded sequence is reproducible
    regardless of filesystem listing order. The comparison is exact
    (case-sensitive), and *img_ext* must include the leading dot (e.g. ".png")
    because that is what ``os.path.splitext`` returns.

    Raises:
        AssertionError: if *directory* is not an existing directory. Raised
            explicitly rather than via ``assert`` so the check is not stripped
            under ``python -O`` (a missing directory would otherwise yield
            nothing and let callers pass vacuously). Because this is a
            generator, the check fires when iteration starts.
    """
    if not os.path.isdir(directory):
        raise AssertionError(f"not a directory: {directory!r}")
    for root, dirs, files in os.walk(directory):
        # Top-down os.walk honours in-place edits to `dirs`, so sorting here
        # fixes the order in which sub-directories are visited.
        dirs.sort()
        for name in sorted(files):
            if os.path.splitext(name)[1] == img_ext:
                yield os.path.join(root, name)
class ImageTester(unittest.TestCase):
    """Check torchvision's native JPEG/PNG read and decode ops against PIL.

    Each test decodes every fixture image both with the torchvision op under
    test and with PIL (converted via NumPy to a tensor) and requires the two
    tensors to be exactly equal. The decode tests also check that malformed
    inputs are rejected with the expected exception types.
    """

    @staticmethod
    def _load_pil(img_path):
        """Return the image at *img_path* decoded by PIL, as a torch tensor.

        ``Image.open`` is used as a context manager so the file handle is
        closed once the pixel data has been copied into the NumPy array.
        """
        with Image.open(img_path) as img:
            return torch.from_numpy(np.array(img))

    @staticmethod
    def _read_bytes(img_path):
        """Return the raw contents of *img_path* as a 1-D uint8 tensor."""
        size = os.path.getsize(img_path)
        return torch.from_file(img_path, dtype=torch.uint8, size=size)

    def _assert_same_image(self, actual, expected):
        """Fail unless *actual* equals *expected* exactly.

        On mismatch, the failure message reports either the differing shapes
        or the largest absolute per-element difference. This replaces ad-hoc
        printing of error metrics.
        """
        if actual.equal(expected):
            return
        if actual.shape != expected.shape:
            self.fail(
                f"shape mismatch: {tuple(actual.shape)} != {tuple(expected.shape)}"
            )
        max_diff = (actual.float() - expected.float()).abs().max().item()
        self.fail(f"pixel mismatch: max abs difference {max_diff}")

    def test_read_jpeg(self):
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            # subTest reports every failing file instead of stopping at the first.
            with self.subTest(img_path=img_path):
                img_pil = self._load_pil(img_path)
                img_ljpeg = read_jpeg(img_path)
                self._assert_same_image(img_ljpeg, img_pil)

    def test_decode_jpeg(self):
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            with self.subTest(img_path=img_path):
                img_pil = self._load_pil(img_path)
                img_ljpeg = decode_jpeg(self._read_bytes(img_path))
                self._assert_same_image(img_ljpeg, img_pil)

        # Input validation does not depend on the fixtures, so it is checked
        # once, outside the loop; it therefore runs even if no .jpg is found.
        with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
        with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."):
            decode_jpeg(torch.empty((100, ), dtype=torch.float16))
        # Zero-filled (not uninitialized) bytes keep this input deterministic;
        # 100 zero bytes carry no JPEG header, so decoding should fail.
        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.zeros(100, dtype=torch.uint8))

    def test_read_png(self):
        # Check across every .png under IMAGE_DIR.
        for img_path in get_images(IMAGE_DIR, ".png"):
            with self.subTest(img_path=img_path):
                img_pil = self._load_pil(img_path)
                img_lpng = read_png(img_path)
                self._assert_same_image(img_lpng, img_pil)

    def test_decode_png(self):
        for img_path in get_images(IMAGE_DIR, ".png"):
            with self.subTest(img_path=img_path):
                img_pil = self._load_pil(img_path)
                img_lpng = decode_png(self._read_bytes(img_path))
                self._assert_same_image(img_lpng, img_pil)

        # Malformed-input checks run once, independent of the fixtures.
        with self.assertRaises(ValueError):
            # 0-dimensional tensor.
            decode_png(torch.empty((), dtype=torch.uint8))
        with self.assertRaises(RuntimeError):
            # Bytes drawn from {3, 4} can never form a valid PNG stream.
            decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's CLI runner,
    # which discovers and runs the TestCase classes in this module.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment