Created
December 29, 2020 09:52
-
-
Save mengwangk/d4dcea07a9c9cc9875c24ae82cc8b5ba to your computer and use it in GitHub Desktop.
Serving ML Models - backend PGAN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torchvision.transforms as Transforms | |
# Use the GPU whenever CUDA is available; otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()

# Pretrained Progressive GAN loaded through torch.hub, using the
# 'celebAHQ-512' checkpoint (high-quality celebrity faces, 512 x 512 pixel
# output). For 256 x 256 pixel images, pass model_name='celebAHQ-256'
# instead.
model = torch.hub.load(
    'facebookresearch/pytorch_GAN_zoo:hub',
    'PGAN',
    model_name='celebAHQ-512',
    pretrained=True,
    useGPU=use_gpu,
)

# Size of the noise batch that pgan() asks the model to build.
num_images = 1
def pgan():
    """Generate a face image with the module-level PGAN model.

    Builds a batch of ``num_images`` noise inputs, runs the generator with
    gradient tracking disabled, and converts the first output of the batch
    into a PIL image.

    Returns:
        The first generated image, converted by ``Transforms.ToPILImage``.
    """
    # Normalize with mean -1 and std 2 computes (x + 1) / 2 per channel,
    # which maps the clamped [-1, 1] range onto [0, 1] before conversion.
    to_pil = Transforms.Compose([
        Transforms.Normalize((-1., -1., -1.), (2, 2, 2)),
        Transforms.ToPILImage(),
    ])

    latent, _ = model.buildNoiseData(num_images)
    with torch.no_grad():
        batch = model.test(latent)

    # Only the first image of the batch is converted and returned.
    first_image = batch[0].clamp(min=-1, max=1)
    return to_pil(first_image)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment