import torch
import pytest
import PIL.Image
import numpy as np

from diffusers.utils import PIL_INTERPOLATION
# `batch` = batch_size * num_images_per_prompt
# When the input is a single image, or a tensor with b == 1, repeat it `batch` times.
# When the input is multiple images, or a tensor with b != 1, tile it as needed
# to match the specified `batch` size. (A usage sketch follows the function below.)
def preprocess(image, width, height, batch):
    def batch_adjust(image, batch):
        assert image.dim() == 4
        if image.shape[0] != batch:
            # torch.Tensor.repeat tiles the whole batch, so [x0, x1] with
            # batch=4 becomes [x0, x1, x0, x1]. Guard against silent
            # truncation when batch is not a multiple of the input size.
            assert batch % image.shape[0] == 0
            return image.repeat(batch // image.shape[0], 1, 1, 1)
        else:
            return image

    if isinstance(image, torch.Tensor):
        return batch_adjust(image, batch)
    elif isinstance(image, PIL.Image.Image):
        image = [image]
    if isinstance(image[0], PIL.Image.Image):
        image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = image.astype(np.float32) / 255.0
        image = image[:, :, :, ::-1]  # RGB -> BGR (reverse the channel axis)
        image = image.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        # `::-1` leaves the array with negative strides, which
        # torch.from_numpy() cannot handle, so copy into a contiguous
        # array first.
        image = torch.from_numpy(image.copy())
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return batch_adjust(image, batch)
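
# Usage sketch (illustrative, not part of the tested code): in a diffusers-style
# img2img pipeline, `batch` would typically be batch_size * num_images_per_prompt,
# and the resulting NCHW tensor would then be fed to the VAE encoder. The
# variable names below are assumptions for illustration only.
if __name__ == "__main__":
    batch_size = 1
    num_images_per_prompt = 2
    init_image = PIL.Image.new(mode='RGB', size=(512, 512))
    out = preprocess(init_image, 512, 512, batch=batch_size * num_images_per_prompt)
    print(out.shape)  # torch.Size([2, 3, 512, 512])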
@pytest.mark.parametrize("w,h,b", [(512, 512, 1), (512, 512, 2), (512, 768, 1), (512, 768, 2)])
def test_single_pil_image_input(w, h, b):
    image = PIL.Image.new(mode='RGB', size=(w, h))
    image = preprocess(image, w, h, batch=b)
    assert isinstance(image, torch.Tensor)
    assert image.shape == (b, 3, h, w)


@pytest.mark.parametrize("w,h,b", [(512, 512, 1), (512, 512, 2), (512, 768, 1), (512, 768, 2)])
def test_single_tensor_input(w, h, b):
    image = torch.randn((1, 3, h, w))
    image = preprocess(image, w, h, batch=b)
    assert isinstance(image, torch.Tensor)
    assert image.shape == (b, 3, h, w)


@pytest.mark.parametrize("w,h,b", [(512, 512, 2), (512, 768, 2)])
def test_array_pil_image_input(w, h, b):
    image = [PIL.Image.new(mode='RGB', size=(w, h), color=(0, 0, 0)),
             PIL.Image.new(mode='RGB', size=(w, h), color=(255, 255, 255))]
    image = preprocess(image, w, h, batch=b)
    assert isinstance(image, torch.Tensor)
    assert image.shape == (2, 3, h, w)
    assert image[0, 0, 0, 0] < .1 and image[1, 0, 0, 0] > .9


@pytest.mark.parametrize("w,h,b", [(512, 512, 2), (512, 768, 2)])
def test_array_tensor_input(w, h, b):
    image = [torch.zeros((1, 3, h, w)),
             torch.ones((1, 3, h, w))]
    image = preprocess(image, w, h, batch=b)
    assert isinstance(image, torch.Tensor)
    assert image.shape == (2, 3, h, w)
    assert image[0, 0, 0, 0] < .1 and image[1, 0, 0, 0] > .9


@pytest.mark.parametrize("w,h,b", [(512, 512, 2), (512, 768, 2)])
def test_batched_tensor_input(w, h, b):
    image = torch.cat([torch.zeros((1, 3, h, w)), torch.ones((1, 3, h, w))], dim=0)
    image = preprocess(image, w, h, batch=b)
    assert isinstance(image, torch.Tensor)
    assert image.shape == (2, 3, h, w)
    assert image[0, 0, 0, 0] < .1 and image[1, 0, 0, 0] > .9


@pytest.mark.parametrize("w,h,b", [(512, 512, 4)])
def test_array_pil_image_input2(w, h, b):
    image = [PIL.Image.new(mode='RGB', size=(w, h), color=(0, 0, 0)),
             PIL.Image.new(mode='RGB', size=(w, h), color=(255, 255, 255))]
    image = preprocess(image, w, h, batch=b)
    assert isinstance(image, torch.Tensor)
    assert image.shape == (4, 3, h, w)
    assert image[0, 0, 0, 0] < .1 and image[1, 0, 0, 0] > .9 and image[2, 0, 0, 0] < .1 and image[3, 0, 0, 0] > .9
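
# Sketch of a negative case for the divisibility guard in batch_adjust
# (assumption: a batch that is not a multiple of the input batch size is
# rejected rather than silently truncated).
@pytest.mark.parametrize("w,h,b", [(512, 512, 3)])
def test_non_multiple_batch_raises(w, h, b):
    image = [torch.zeros((1, 3, h, w)), torch.ones((1, 3, h, w))]
    with pytest.raises(AssertionError):
        preprocess(image, w, h, batch=b)

# To run: pytest <this file> -v  (e.g. save as test_preprocess.py)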