Skip to content

Instantly share code, notes, and snippets.

@aman-tiwari
Created April 12, 2018 22:37
Show Gist options
  • Save aman-tiwari/d187819d17a2ad15374da0ab2de60e93 to your computer and use it in GitHub Desktop.
Save aman-tiwari/d187819d17a2ad15374da0ab2de60e93 to your computer and use it in GitHub Desktop.
zmq ml template
import time

import numpy as np
import zmq
def _send_timed(socket, payload):
    """Send *payload* zero-copy on *socket*, print and return the elapsed seconds."""
    started = time.time()
    socket.send(payload, copy=False)
    elapsed = time.time() - started
    print("sending took: ", elapsed)
    return elapsed


def zmq_serv(bind_at="tcp://*:5566", img_shape=(480, 720, 3)):
    """Run a blocking ZeroMQ REP server that feeds received frames to a model.

    Each request is read as a raw ``uint8`` buffer and reshaped to
    ``(img_shape[0], img_shape[1], 4)``.  Only the first three channels are
    passed to the model; the model output is clipped to ``[0, 255]``, cast to
    ``uint8`` and has a constant fourth channel appended before being sent
    back as raw bytes.

    Replies are pipelined by one frame: the very first request is answered
    with its own result, every later request is answered immediately with the
    result of the *previous* request, and inference for the new frame runs
    after the reply has gone out.

    If handling a request fails before a reply was sent, an empty message is
    sent so the REQ/REP exchange stays in lock-step; the error is printed and
    the server keeps running.

    ``load_model`` and ``imresize`` are not defined in this block and must be
    provided by the surrounding module.

    Args:
        bind_at: ZeroMQ endpoint to bind the REP socket to.
        img_shape: ``(height, width, channels)`` passed to ``load_model`` and
            used as the expected model input shape.

    This function only returns by raising (e.g. ``KeyboardInterrupt``); the
    socket and context are closed on the way out.
    """
    # Requests carry one extra channel on top of the image's first two dims.
    input_shape = (img_shape[0], img_shape[1], 4)

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    try:
        socket.bind(bind_at)
        print("server init")

        model = load_model(img_shape)
        print("waiting for recv")

        # Constant fourth channel appended to every reply.
        # NOTE(review): np.ones gives a value of 1, not 255 -- confirm the
        # client does not interpret this as an (almost transparent) alpha.
        dummy_alpha = np.ones(
            (1, img_shape[0], img_shape[1], 1), dtype=np.uint8
        )
        result = None

        while True:
            recv = socket.recv(copy=False)
            # A REP socket allows exactly one send per recv.  Track it per
            # request so the error path never attempts a second send, which
            # would raise zmq.ZMQError and take the server down.
            replied = False
            try:
                buf = memoryview(recv)
                t = time.time()
                arr = np.frombuffer(buf, dtype=np.uint8)
                arr = arr.reshape(input_shape)

                send_time = 0.0
                if result is not None:
                    # Pipelined mode: answer with the previous frame's result
                    # right away, then compute the new one.
                    send_time = _send_timed(socket, result)
                    replied = True

                arr = arr[:, :, :3]
                if arr.shape != img_shape:
                    arr = imresize(arr, img_shape[:2])
                arr = arr[np.newaxis, :, :, :]
                arr = arr.astype(np.float32)

                # The concatenate below requires the model output to be
                # rank 4 with the same leading dims as dummy_alpha.
                result = model.forward(arr)
                result = result.clip(0, 255)
                result = result.astype(np.uint8)
                result = np.concatenate((result, dummy_alpha), axis=3)

                inference_time = time.time() - t
                print("inference took:", inference_time)

                if not replied:
                    # First successful frame: nothing older to send, so reply
                    # with the result that was just computed.
                    send_time = _send_timed(socket, result)
                    replied = True

                total_time = send_time + inference_time
                if total_time > 0:
                    print("theoretical fps: ", 1.0 / total_time)
            except Exception as e:
                if not replied:
                    socket.send(b"")
                print(e)
    finally:
        # linger=0 drops any unsent reply so context.term() cannot block.
        socket.close(linger=0)
        context.term()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment