PyTorch and dask.distributed latency
@stsievert · Last active July 5, 2018

Three quick benchmarks of round-trip message latency: raw TCP sockets (an echo client and server), dask.distributed's Pub/Sub, and torch.distributed's send/recv.
from __future__ import print_function
import socket
import time

# Create a TCP/IP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Connect the socket to the port where the server is listening
server_address = ('localhost', 10000)
print('connecting to {} port {}'.format(*server_address))
sock.connect(server_address)

N = 1000
start = time.time()
for i in range(N):
    # Send one byte and wait for it to be echoed back
    message = b'1'
    sock.sendall(message)

    # Look for the response
    amount_received = 0
    amount_expected = len(message)
    while amount_received < amount_expected:
        data = sock.recv(16)
        amount_received += len(data)
end = time.time()

# Each iteration is one round trip, i.e. 2*N one-way messages total
latency = (end - start) / (2 * N)
print("latency: {ms:0.3f}ms".format(ms=latency / 1e-3))

print('closing socket')
sock.close()
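One caveat when timing one-byte messages over TCP: Nagle's algorithm may coalesce small writes and inflate the measured round trip. An optional tweak (not part of the original gist) is to disable it on the client socket, right after sock.connect(...):

# Optional: disable Nagle's algorithm so each one-byte ping is sent
# immediately rather than being batched with later writes. This can
# noticeably change the measured latency for small payloads.
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)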
from __future__ import print_function
import socket

# Create a TCP/IP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Bind the socket to the port
server_address = ('localhost', 10000)
print('starting up on {} port {}'.format(*server_address))
sock.bind(server_address)

# Listen for incoming connections
sock.listen(1)

while True:
    # Wait for a connection
    print('waiting for a connection')
    connection, client_address = sock.accept()
    try:
        print('connection from', client_address)

        # Receive the data in small chunks and echo it back
        while True:
            data = connection.recv(16)
            if data:
                connection.sendall(data)
            else:
                break
    finally:
        # Clean up the connection
        print("breaking connection")
        connection.close()
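To reproduce the raw-socket number, the server and client above must run at the same time. A minimal sketch of a launcher, assuming the two scripts are saved under the hypothetical names echo_server.py and echo_client.py:

# run_tcp_latency.py -- start the echo server, run the client, clean up.
# Assumes the scripts above are saved as echo_server.py and echo_client.py.
import subprocess
import time

server = subprocess.Popen(["python", "echo_server.py"])
time.sleep(1)  # give the server a moment to bind and listen
try:
    subprocess.run(["python", "echo_client.py"], check=True)
finally:
    # The server loops on accept() forever, so stop it explicitly
    server.terminate()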
import numpy as np
from distributed import Client
from time import time, sleep
from distributed import Pub, Sub
import distributed
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process, Manager
import sys

print("torch version =", torch.__version__)
print("dask.distributed version =", distributed.__version__)
x = sys.version_info
print("python version =", ".".join([str(x.major), str(x.minor), str(x.micro)]))
# torch version = 0.4.0
# dask.distributed version = 1.21.8+40.g93c6c112
# python version = 3.6.5


def test_latency(c):
    """
    Test how quickly we can move messages back and forth.

    This is mostly a test of latency. Interestingly, this runs
    about 10x slower on Python 2.
    """
    def pingpong(a, b, start=False, n=1000, msg=1):
        # Subscribe to topic `a`, publish on topic `b`
        sub = Sub(a)
        pub = Pub(b)
        while not pub.subscribers:
            sleep(0.01)
        if start:
            pub.put(msg)  # other sub may not have started yet
        for i in range(n):
            msg = next(sub)
            pub.put(msg)
        return n

    x = np.random.random(1).astype('float32')
    n = 1000
    x = c.submit(pingpong, 'a', 'b', start=True, msg=x, n=n)
    y = c.submit(pingpong, 'b', 'a', n=n)

    start = time()
    c.gather([x, y])
    stop = time()

    # Divide by 2*n because 2*n messages are sent in total
    return (stop - start) / (2 * n)


from distributed import LocalCluster

cluster = LocalCluster(n_workers=2, threads_per_worker=1)
client = Client(cluster)

latency = test_latency(client)
print("dask.distributed latency = {ms:0.3f} ms".format(ms=latency / 1e-3))
# dask.distributed latency = 1.376 ms
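The LocalCluster and Client stay alive until the interpreter exits. When running the torch.distributed benchmark below in the same session, one optional cleanup step (my addition, not in the original gist) is to close them first so the two benchmarks don't compete for cores; Client.close and LocalCluster.close are both part of the distributed API:

# Shut down the dask scheduler and workers before launching the
# torch.distributed processes below.
client.close()
cluster.close()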
def run(rank, size, N=1000):
    start = time()
    tensor = torch.ones(1)
    assert tensor.dtype is torch.float32
    for i in range(N):
        if rank == 0:
            tensor += 1
            dist.send(tensor=tensor, dst=1)
        elif rank == 1:
            dist.recv(tensor, src=0)
    assert i == N - 1
    t = time() - start

    # N one-way messages were sent, so divide by N
    avg_latency = t / N
    print("PyTorch latency = {ms:0.3f} ms".format(ms=avg_latency / 1e-3))
    return avg_latency


def init_processes(rank, size, fn, return_dict, backend='tcp'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    sleep(5)
    return_dict[rank] = fn(rank, size)


size = 2
processes = []
manager = Manager()
return_dict = manager.dict()
for rank in range(size):
    p = Process(target=init_processes, args=(rank, size, run, return_dict))
    p.start()
    processes.append(p)
for proc in processes:
    proc.join()  # was p.join(), which joined only the last-started process
# PyTorch latency = 0.032 ms
# PyTorch latency = 0.032 ms
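The 'tcp' backend matches the torch 0.4.0 used above; it was removed in later PyTorch releases. On PyTorch >= 1.0 the same benchmark should work with the 'gloo' backend for CPU tensors, which also supports point-to-point send/recv. A sketch of the one-line change, not verified against every version:

# On newer PyTorch, 'tcp' is gone; 'gloo' is the usual CPU backend.
dist.init_process_group('gloo', rank=rank, world_size=size)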