Last active
November 15, 2023 05:35
-
-
Save humpydonkey/b1e5b4cde32738913d99a50a957a7b4b to your computer and use it in GitHub Desktop.
Visualize a list of images in a grid
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import PIL.Image | |
from PIL import ImageFont | |
from torchvision.utils import make_grid as torch_make_grid | |
from torchvision.transforms.functional import pil_to_tensor, to_pil_image | |
from pathlib import Path | |
import platform | |
def make_grid(
    images: list[PIL.Image.Image | str],
    max_images: int = 300,
    thumbnail_size: tuple[int, int] = (256, 256),
):
    """Tile up to ``max_images`` images into a single grid image.

    Each entry of ``images`` may be an already-loaded ``PIL.Image.Image`` or a
    path to an image file, and the two kinds may be mixed. Each selected image
    is shrunk with ``Image.thumbnail`` (aspect ratio preserved) to fit inside
    ``thumbnail_size``. It is then converted to a tensor, and the tensors are
    laid out with ``torchvision.utils.make_grid``.

    Args:
        images: Images or image file paths. Only the first ``max_images``
            entries are used; the rest are never opened.
        max_images: Maximum number of entries placed in the grid (>= 1).
        thumbnail_size: ``(width, height)`` box each image is shrunk to fit.

    Returns:
        The grid as a ``PIL.Image.Image``.

    Raises:
        ValueError: If ``images`` is empty or ``max_images`` is less than 1.

    Note:
        Caller-supplied ``Image`` objects are copied before resizing, so they
        are left unmodified.
        NOTE(review): ``torchvision.utils.make_grid`` expects a list of
        same-sized tensors. Thumbnails of differing aspect ratio or channel
        count (e.g. RGB mixed with RGBA) will likely fail to stack.
        Confirm whether padding or mode conversion is needed for your inputs.
    """
    if max_images < 1:
        raise ValueError(f"max_images must be >= 1, got {max_images}")
    if not images:
        raise ValueError("make_grid() requires at least one image")

    # Truncate before loading anything so entries past the cap are never opened.
    selected = images[:max_images]
    print(f"Making grid of {len(selected)} images")

    tensors = []
    for item in selected:
        if isinstance(item, PIL.Image.Image):
            # Image.thumbnail() resizes in place; copy so the caller's image
            # is not silently shrunk.
            thumb = item.copy()
            thumb.thumbnail(thumbnail_size)
            tensors.append(pil_to_tensor(thumb))
        else:
            # Anything that is not an Image is treated as a path. The context
            # manager closes the file handle that Image.open() keeps open for
            # lazy loading; thumbnail() loads the pixel data before it closes.
            with PIL.Image.open(item) as opened:
                opened.thumbnail(thumbnail_size)
                tensors.append(pil_to_tensor(opened))
    return to_pil_image(torch_make_grid(tensors))
def add_label(
    image: PIL.Image.Image,
    label: str,
    position: tuple[int, int] | None = None,
    font_path: str | None = None,
    font_size: int = 28,
    fill: tuple[int, ...] | int | str = (0, 0, 0),
):
    """Draw ``label`` onto ``image`` in place and return the same image.

    Args:
        image: Image to draw on. It is modified in place.
        label: Text to draw.
        position: ``(x, y)`` of the text's top-left corner. Defaults to 10 px
            from the left edge and 70 px above the bottom edge. The y value
            is clamped to 0 for images shorter than 70 px.
        font_path: TrueType font file to use. When omitted, a platform
            default is chosen (Linux: Noto Mono, macOS: Arial).
        font_size: Font size passed to ``ImageFont.truetype``.
        fill: Text colour. Defaults to black.

    Returns:
        ``image``, the same object that was passed in, after drawing.

    Raises:
        ValueError: If ``font_path`` is omitted on a platform other than
            Linux or macOS.
        OSError: Propagated from ``ImageFont.truetype`` when the font file
            cannot be opened.
    """
    # PIL.ImageDraw is never imported at the top of this file. Accessing it as
    # an attribute of PIL only works if another module has already imported
    # the submodule, so import it explicitly.
    from PIL import ImageDraw

    if font_path is None:
        system = platform.system()
        # Below path is only tested for Ubuntu 22.04
        if system == "Linux":
            font_path = "/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf"
        elif system == "Darwin":
            font_path = "/System/Library/Fonts/Supplemental/Arial.ttf"
        else:
            raise ValueError(f"Unsupported platform: {system}")
    font = ImageFont.truetype(font_path, font_size)

    if position is None:
        # Clamp so images shorter than 70 px still show the label on-canvas.
        position = (10, max(0, image.height - 70))
    ImageDraw.Draw(image).text(position, label, fill=fill, font=font)
    return image
# display(add_label(make_grid(output), "test")) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment