Skip to content

Instantly share code, notes, and snippets.

@takuma104
Last active February 25, 2023 06:07
Show Gist options
  • Save takuma104/9b8499e3cdbb5cae3335eb1843b485d6 to your computer and use it in GitHub Desktop.
Save takuma104/9b8499e3cdbb5cae3335eb1843b485d6 to your computer and use it in GitHub Desktop.
import torch
import pytest
import PIL.Image
import numpy as np
from diffusers.utils import (
PIL_INTERPOLATION,
)
# `batch` = batch_size * num_images_per_prompt
# When the input is a single image or a tensor with b == 1, repeat it `batch` times.
# When the input is multiple images or a tensor with b != 1, tile the whole set
# (e.g. [x, y] -> [x, y, x, y]) `batch // b` times; `batch` is expected to be a
# multiple of b so that the result matches the specified `batch` size.
def preprocess(image, width, height, batch):
    """Convert a control-image input into a float tensor holding exactly ``batch`` images.

    Accepted inputs:
      * a single ``PIL.Image.Image``;
      * a list/tuple of ``PIL.Image.Image`` objects;
      * a 4-D ``torch.Tensor`` laid out as ``(n, c, h, w)``;
      * a list/tuple of 4-D ``torch.Tensor`` objects (concatenated along dim 0).

    PIL images are converted to RGB, resized to ``(width, height)`` with Lanczos
    resampling, scaled from ``[0, 255]`` to ``[0, 1]``, flipped from RGB to BGR
    channel order and transposed to ``(n, 3, height, width)`` float32.
    Tensor inputs are used as-is (``width``/``height`` are not applied to them).

    Args:
        image: The input image(s), in one of the forms listed above.
        width: Target width in pixels for PIL inputs.
        height: Target height in pixels for PIL inputs.
        batch: Required size of the leading dimension of the result
            (``batch_size * num_images_per_prompt``).

    Returns:
        A ``torch.Tensor`` whose first dimension equals ``batch``. An input of
        ``n`` images is tiled as a whole ``batch // n`` times, e.g. ``[x, y]``
        with ``batch=4`` becomes ``[x, y, x, y]``. If ``n == batch`` the
        tensor is returned unchanged.

    Raises:
        ValueError: If the input list is empty, the resulting tensor is not 4-D,
            or ``batch`` is not a positive multiple of the number of images.
        TypeError: If ``image`` is not one of the accepted forms.
    """

    def _match_batch(tensor):
        # Tile the whole input batch so that its leading dimension equals `batch`.
        if tensor.dim() != 4:
            raise ValueError(
                f"expected a 4-D (n, c, h, w) tensor, got shape {tuple(tensor.shape)}"
            )
        n = tensor.shape[0]
        if n == batch:
            return tensor
        # repeat() with a factor of batch // n only reaches `batch` when n divides
        # it. Otherwise the old code silently returned the wrong size, an empty
        # tensor (n > batch), or hit ZeroDivisionError (n == 0).
        if n == 0 or batch < n or batch % n != 0:
            raise ValueError(
                f"cannot expand {n} image(s) to a batch of {batch}: "
                "batch must be a positive multiple of the number of images"
            )
        return tensor.repeat(batch // n, 1, 1, 1)

    if isinstance(image, torch.Tensor):
        return _match_batch(image)

    # Normalize to a list; a lone non-sequence input (normally one PIL image)
    # becomes a one-element list.
    images = list(image) if isinstance(image, (list, tuple)) else [image]
    if not images:
        raise ValueError("image list must not be empty")

    if all(isinstance(item, torch.Tensor) for item in images):
        return _match_batch(torch.cat(images, dim=0))

    if all(isinstance(item, PIL.Image.Image) for item in images):
        # convert("RGB") guarantees 3 channels so the flip/transpose below is
        # valid for grayscale or RGBA inputs as well.
        arrays = [
            np.array(
                item.convert("RGB").resize(
                    (width, height), resample=PIL_INTERPOLATION["lanczos"]
                )
            )
            for item in images
        ]
        pixels = np.stack(arrays, axis=0).astype(np.float32) / 255.0  # (n, h, w, 3)
        pixels = pixels[:, :, :, ::-1]  # RGB -> BGR
        pixels = pixels.transpose(0, 3, 1, 2)  # (n, 3, h, w)
        # torch.from_numpy cannot wrap the negative strides left by ::-1,
        # so materialize a contiguous copy first.
        return _match_batch(torch.from_numpy(np.ascontiguousarray(pixels)))

    raise TypeError(
        "image must be a PIL.Image.Image, a torch.Tensor, or a list of either, "
        f"got {type(image).__name__}"
    )
@pytest.mark.parametrize(
    "w,h,b",
    [(512, 512, 1), (512, 512, 2), (512, 768, 1), (512, 768, 2)],
)
def test_single_pil_image_input(w, h, b):
    """A lone PIL image becomes a (b, 3, h, w) tensor."""
    blank = PIL.Image.new(mode='RGB', size=(w, h))
    result = preprocess(blank, w, h, batch=b)
    assert isinstance(result, torch.Tensor)
    expected_shape = (b, 3, h, w)
    assert tuple(result.shape) == expected_shape
@pytest.mark.parametrize(
    "w,h,b",
    [(512, 512, 1), (512, 512, 2), (512, 768, 1), (512, 768, 2)],
)
def test_single_tensor_input(w, h, b):
    """A (1, 3, h, w) tensor is repeated up to a batch of b."""
    source = torch.randn((1, 3, h, w))
    result = preprocess(source, w, h, batch=b)
    assert isinstance(result, torch.Tensor)
    expected_shape = (b, 3, h, w)
    assert tuple(result.shape) == expected_shape
@pytest.mark.parametrize("w,h,b", [(512, 512, 2), (512, 768, 2)])
def test_array_pil_image_input(w, h, b):
    """A list of PIL images whose length equals the batch keeps its order."""
    images = [
        PIL.Image.new(mode='RGB', size=(w, h), color=(0, 0, 0)),
        PIL.Image.new(mode='RGB', size=(w, h), color=(255, 255, 255)),
    ]
    result = preprocess(images, w, h, batch=b)
    assert isinstance(result, torch.Tensor)
    # Expect the parametrized batch size rather than a hard-coded literal, so the
    # assertion cannot drift from the parameters.
    assert result.shape == (b, 3, h, w)
    # Black image first (~0.0), white image second (~1.0).
    assert result[0, 0, 0, 0] < .1 and result[1, 0, 0, 0] > .9
@pytest.mark.parametrize("w,h,b", [(512, 512, 2), (512, 768, 2)])
def test_array_tensor_input(w, h, b):
    """A list of (1, 3, h, w) tensors is concatenated in order."""
    tensors = [
        torch.zeros((1, 3, h, w)),
        torch.ones((1, 3, h, w)),
    ]
    result = preprocess(tensors, w, h, batch=b)
    assert isinstance(result, torch.Tensor)
    # Expect the parametrized batch size rather than a hard-coded literal.
    assert result.shape == (b, 3, h, w)
    # Zeros tensor first, ones tensor second.
    assert result[0, 0, 0, 0] < .1 and result[1, 0, 0, 0] > .9
@pytest.mark.parametrize("w,h,b", [(512, 512, 2), (512, 768, 2)])
def test_batched_tensor_input(w, h, b):
    """A pre-batched (2, 3, h, w) tensor whose size matches the batch is unchanged."""
    batched = torch.cat([torch.zeros((1, 3, h, w)), torch.ones((1, 3, h, w))], dim=0)
    result = preprocess(batched, w, h, batch=b)
    assert isinstance(result, torch.Tensor)
    # Expect the parametrized batch size rather than a hard-coded literal.
    assert result.shape == (b, 3, h, w)
    # Zeros slice first, ones slice second.
    assert result[0, 0, 0, 0] < .1 and result[1, 0, 0, 0] > .9
@pytest.mark.parametrize("w,h,b", [(512, 512, 4)])
def test_array_pil_image_input2(w, h, b):
    """Two PIL images expanded to a larger batch are tiled as [black, white, black, white, ...]."""
    images = [
        PIL.Image.new(mode='RGB', size=(w, h), color=(0, 0, 0)),
        PIL.Image.new(mode='RGB', size=(w, h), color=(255, 255, 255)),
    ]
    result = preprocess(images, w, h, batch=b)
    assert isinstance(result, torch.Tensor)
    assert result.shape == (b, 3, h, w)
    # Check the alternating pattern over the whole batch, not just four fixed
    # indices, so the test stays valid for other even values of b.
    for index in range(b):
        if index % 2 == 0:
            assert result[index, 0, 0, 0] < .1, f"batch entry {index} should be black"
        else:
            assert result[index, 0, 0, 0] > .9, f"batch entry {index} should be white"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment