Last active
March 27, 2018 05:32
-
-
Save CharStiles/e40cbade2534d5fb41c5ebd594d03cb7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Based on test.py from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.
Thanks to Aman Tiwari for the help.

To run:
    python ServerToProcessFaceSketch.py --dataroot ./darta --name face_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction AtoB --dataset_mode aligned --norm batch
'''
import os | |
import io | |
import time | |
import zmq | |
import random | |
import sys | |
import base64 | |
import util.util as util | |
from options.test_options import TestOptions | |
from data.data_loader import CreateDataLoader | |
from models.models import create_model | |
from util.visualizer import Visualizer | |
import numpy as np | |
from PIL import Image | |
from scipy.misc import imresize | |
from torch.autograd import Variable | |
from torchvision import transforms | |
# Preprocessing applied to every incoming PIL image before it reaches the
# generator: convert to a float tensor, then shift/scale each RGB channel
# with (x - 0.5) / 0.5.
prepare = transforms.Compose(
    [
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),  # [0, 1] -> [-1, 1]
    ]
)
# ZeroMQ REP socket listening on every interface; each request received in
# the main loop must be answered with exactly one reply.
context = zmq.Context()
port = "8080"
socket = context.socket(zmq.REP)
socket.bind("tcp://*:{0}".format(port))
# Parse the command-line options, then force the settings the upstream test
# pipeline requires regardless of what was passed on the command line.
opt = TestOptions().parse()
for _name, _value in (
    ("nThreads", 1),  # test code only supports nThreads = 1
    ("batchSize", 1),  # test code only supports batchSize = 1
    ("serial_batches", True),  # no shuffle
    ("no_flip", True),  # no flip
):
    setattr(opt, _name, _value)

data_loader = CreateDataLoader(opt)
# NOTE(review): `dataset` is not referenced elsewhere in this script; the
# loader may only be needed for its side effects on --dataroot. Confirm.
dataset = data_loader.load_data()
model = create_model(opt)
# Read the sample test image as raw bytes. The `with` block guarantees the
# file is closed even if read() raises (the original open/read/close leaked
# the handle on error).
# NOTE(review): `ff` is not referenced elsewhere in this script; confirm
# whether it is still needed.
with open("./darta/test/img.jpg", "rb") as f:
    ff = f.read()
def forward(model, inp):
    """Run the pix2pix generator on one image without tracking gradients.

    Args:
        model: model built by ``create_model``; only its ``netG`` is used.
        inp: PIL image to translate (the server passes a 256x256 RGB image).

    Returns:
        The generator output converted by ``util.tensor2im``; the server
        passes it straight to ``Image.fromarray``.
    """
    import torch  # the module itself; the file only imports torch.autograd

    # (C, H, W) -> (1, C, H, W): the generator expects a batch of one.
    batch = prepare(inp).cuda().unsqueeze(0)
    if hasattr(torch, "no_grad"):
        # PyTorch >= 0.4 ignores `volatile=True` (it only warns), so an
        # autograd graph would be built on every frame; disable it explicitly.
        with torch.no_grad():
            pred = model.netG(batch)
    else:
        # PyTorch < 0.4: a volatile Variable skips graph construction.
        pred = model.netG(Variable(batch, volatile=True))
    return util.tensor2im(pred.data)
FRAME_SIZE = (480, 320)  # (width, height) of frames exchanged with the client
NET_SIZE = (256, 256)  # (width, height) the generator is run at
FRAME_NBYTES = FRAME_SIZE[0] * FRAME_SIZE[1] * 3  # packed 8-bit RGB


def _process_frame(msg):
    """Translate one raw RGB frame and return the reply array.

    ``msg`` must hold exactly FRAME_NBYTES bytes of packed RGB pixels at
    FRAME_SIZE. The translated frame is also saved to 'is_art2.png'.
    """
    img = Image.frombytes('RGB', FRAME_SIZE, msg).resize(NET_SIZE, Image.BICUBIC)
    result = forward(model, img)
    im = Image.fromarray(result).resize(FRAME_SIZE)
    im.save('is_art2.png')
    return np.array(im)


while True:
    msg = socket.recv()
    if len(msg) != FRAME_NBYTES:
        # A short frame would make Image.frombytes raise and kill the server.
        # A REP socket must answer every request before the next recv(), so
        # send an empty reply to keep the request/reply lockstep intact.
        sys.stderr.write("ignoring frame of %d bytes (expected %d)\n"
                         % (len(msg), FRAME_NBYTES))
        socket.send(b"")
        continue
    reply = _process_frame(msg)
    # NOTE(review): these sleeps add 1.5 s of latency per frame and are not
    # needed by the REQ/REP protocol; confirm their purpose before removing.
    time.sleep(0.5)
    socket.send(reply)
    time.sleep(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment