Skip to content

Instantly share code, notes, and snippets.

@kyoto-cheng
Last active September 18, 2021 19:24
Show Gist options
Save kyoto-cheng/7652664f71078d57d536bac893c4b0d0 to your computer and use it in GitHub Desktop.
# import resources
%matplotlib inline
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms, models
# Keep only the convolutional "features" part of VGG19; the "classifier" head is unused.
# torchvision >= 0.13 replaced `pretrained=True` with the `weights=` API (the old flag is
# deprecated and emits a warning); IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
# Older torchvision releases lack VGG19_Weights, so fall back to the legacy flag there.
if hasattr(models, "VGG19_Weights"):
    vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
else:
    vgg = models.vgg19(pretrained=True).features

# Freeze all VGG parameters: only the target image is optimized, never the network.
for param in vgg.parameters():
    param.requires_grad_(False)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)
def load_image(img_path, max_size=400, shape=None):
    """Load an image file as a normalized, batched tensor.

    Args:
        img_path: Path to the image file (anything ``PIL.Image.open`` accepts).
        max_size: Upper bound, in pixels, on the longer side of the result.
            Images already within the bound keep their own size; they are
            never upscaled. The aspect ratio is preserved.
        shape: Optional explicit output size handed directly to
            ``transforms.Resize`` (e.g. a ``(height, width)`` pair such as
            another image tensor's ``shape[-2:]``). When given, it overrides
            ``max_size``.

    Returns:
        A float tensor with a leading batch axis of size 1, i.e.
        ``(1, 3, H, W)``, normalized with the mean/std constants below.
    """
    # The context manager closes the file handle. convert('RGB') loads the pixel
    # data and drops any alpha/palette information, so we always get 3 channels.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    if shape is not None:
        size = shape
    else:
        # transforms.Resize(int) scales the *shorter* edge to that value, so passing
        # max(image.size) would upscale small images and let the longer edge of large
        # images exceed max_size. Compute an explicit (height, width) instead.
        width, height = image.size
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    # ToTensor yields (C, H, W). The [:3] slice is only a safeguard, because
    # convert('RGB') already guarantees 3 channels. unsqueeze(0) adds the batch axis.
    return in_transform(image)[:3, :, :].unsqueeze(0)
# Content image: loaded with load_image's default arguments, then moved to `device`.
content = load_image('images/your_content_image').to(device)

# Style image: resized to the content tensor's trailing (height, width) dims so both
# images share the same spatial size, which keeps the later code simpler.
content_hw = content.shape[-2:]
style = load_image('images/your_style_image', shape=content_hw).to(device)
def im_convert(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Convert a normalized image tensor into a NumPy array suitable for ``imshow``.

    This does not display anything; it only prepares the data.

    Args:
        tensor: Image tensor shaped ``(C, H, W)`` or ``(1, C, H, W)``, on any device,
            and possibly requiring grad.
        mean: Per-channel mean that was subtracted during normalization.
        std: Per-channel std that was divided out during normalization.

    Returns:
        A float ``(H, W, C)`` NumPy array with normalization undone and values
        clipped to ``[0, 1]``.
    """
    # Detach first so the copy is not tracked by autograd, then move it to the CPU
    # so that .numpy() is allowed.
    image = tensor.detach().to("cpu").clone().numpy()
    if image.ndim == 4:
        # Drop only the batch axis. A blanket squeeze() would also collapse a height
        # or width of 1 and break the transpose below.
        image = image.squeeze(0)
    image = image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C) for matplotlib
    image = image * np.array(std) + np.array(mean)
    return image.clip(0, 1)
# Plot the content image (left) and the style image (right) in one wide figure.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
left_img, right_img = im_convert(content), im_convert(style)
ax1.imshow(left_img)
ax2.imshow(right_img)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment