Last active
September 18, 2021 19:24
-
-
Save kyoto-cheng/7652664f71078d57d536bac893c4b0d0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# import resources | |
%matplotlib inline | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import torch.optim as optim | |
from torchvision import transforms, models | |
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# VGG19 is used purely as a fixed feature extractor: only its "features"
# stack is kept (the "classifier" portion is not needed), and every weight is
# frozen because the only thing being optimized is the target image.
vgg = models.vgg19(pretrained=True).features
vgg.requires_grad_(False)
vgg.to(device)
# Load an image from disk and turn it into a normalized, batched tensor whose
# longer edge is at most ``max_size`` pixels (unless an explicit ``shape`` is
# requested).
def load_image(img_path, max_size=400, shape=None,
               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Load ``img_path`` as a normalized ``(1, 3, H, W)`` float tensor.

    Args:
        img_path: Path of the image file to open.
        max_size: Upper bound, in pixels, on the longer edge of the result.
            Images already within the bound keep their original size; larger
            images are shrunk (large images slow down processing). Aspect
            ratio is preserved.
        shape: Optional explicit target passed straight to
            ``transforms.Resize`` -- typically ``(height, width)`` such as
            ``content.shape[-2:]`` so a style image matches a content image.
            When given, it overrides ``max_size``.
        mean, std: Per-channel normalization constants (the defaults are the
            standard ImageNet mean/std values).

    Returns:
        A float tensor of shape ``(1, 3, H, W)``.
    """
    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    if shape is not None:
        size = shape
    else:
        width, height = image.size  # PIL reports (width, height)
        # Shrink so the longer edge fits within max_size and never upscale.
        # Passing a bare int to Resize would instead match the *shorter*
        # edge, which can push the longer edge past max_size (and upscale
        # images that were already small enough).
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # convert('RGB') already guarantees exactly 3 channels; the [:3] slice is
    # a defensive no-op. unsqueeze(0) adds the batch dimension.
    return in_transform(image)[:3, :, :].unsqueeze(0)
# Placeholder paths -- point these at real image files before running.
content = load_image('images/your_content_image').to(device)
# Load the style image at exactly the content image's (height, width) so both
# tensors share spatial dimensions, which keeps the later code simpler.
style = load_image(
    'images/your_style_image',
    shape=content.shape[-2:],
).to(device)
# Helper that un-normalizes an image tensor and converts it into a NumPy
# array suitable for display with matplotlib.
def im_convert(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Convert a normalized image tensor into an ``(H, W, C)`` array in [0, 1].

    Args:
        tensor: Image tensor of shape ``(1, C, H, W)`` or ``(C, H, W)``, on
            any device; it may require grad. It is not modified. For a batch
            larger than one, only the first image is converted.
        mean, std: Per-channel normalization constants to undo.

    Returns:
        A NumPy array of shape ``(H, W, C)`` clipped to ``[0, 1]``.
    """
    # Detach first so nothing is recorded by autograd, then move to the CPU.
    # No clone is needed: every NumPy operation below returns a new array,
    # so the caller's tensor is never written to.
    image = tensor.detach().cpu()
    # Drop only the batch axis. A blanket squeeze() would also collapse a
    # size-1 height, width, or channel axis and break the transpose below.
    if image.dim() == 4:
        image = image[0]
    image = image.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    image = image * np.asarray(std) + np.asarray(mean)
    return image.clip(0, 1)
# Put the content image (left) and the style image (right) next to each
# other for a quick visual check.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
ax1, ax2 = axes
ax1.imshow(im_convert(content))
ax2.imshow(im_convert(style))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment