Last active: July 5, 2018 15:28
PyTorch and dask.distributed latency
# Raw TCP baseline, client side: send one-byte messages and time the echoed replies.
from __future__ import print_function
import socket
import sys
import time

# Create a TCP/IP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Connect the socket to the port where the server is listening
server_address = ('localhost', 10000)
print('connecting to {} port {}'.format(*server_address))
sock.connect(server_address)

N = 1000
start = time.time()
for i in range(N):
    # Send data
    message = b'1'
    sock.sendall(message)

    # Look for the response
    amount_received = 0
    amount_expected = len(message)
    while amount_received < amount_expected:
        data = sock.recv(16)
        amount_received += len(data)
end = time.time()

latency = (end - start) / (2 * N)
print("latency: {ms:0.3f} ms".format(ms=latency / 1e-3))

print('closing socket')
sock.close()
# Raw TCP baseline, server side: echo every message straight back to the client.
from __future__ import print_function
import socket
import sys

# Create a TCP/IP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Bind the socket to the port
server_address = ('localhost', 10000)
print('starting up on {} port {}'.format(*server_address))
sock.bind(server_address)

# Listen for incoming connections
sock.listen(1)

while True:
    # Wait for a connection
    print('waiting for a connection')
    connection, client_address = sock.accept()
    try:
        print('connection from', client_address)

        # Receive the data in small chunks and retransmit it
        while True:
            data = connection.recv(16)
            if data:
                connection.sendall(data)
            else:
                break
    finally:
        # Clean up the connection
        print("breaking connection")
        connection.close()
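For reference, the two socket scripts above can be collapsed into one self-contained sketch that runs the echo server in a background thread and times the client loop against it. This is only a convenience sketch, not part of the original gist; it assumes the loopback interface is available and that port 10001 is free.

# Hypothetical combined version of the client/server pair above.
import socket
import threading
import time

def echo_server(server_address):
    # Accept one connection and echo everything back until the client disconnects.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(server_address)
    sock.listen(1)
    connection, _ = sock.accept()
    while True:
        data = connection.recv(16)
        if not data:
            break
        connection.sendall(data)
    connection.close()
    sock.close()

server_address = ('localhost', 10001)  # assumed to be a free port
server = threading.Thread(target=echo_server, args=(server_address,))
server.start()
time.sleep(0.5)  # give the server a moment to start listening

client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(server_address)

N = 1000
start = time.time()
for i in range(N):
    client.sendall(b'1')  # one-byte ping
    client.recv(16)       # wait for the echoed reply
end = time.time()

client.close()
server.join()
print("raw TCP latency = {ms:0.3f} ms".format(ms=(end - start) / (2 * N) / 1e-3))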
# Benchmark script: dask.distributed Pub/Sub latency vs. PyTorch distributed send/recv.
import numpy as np
from distributed import Client
from time import time, sleep
from distributed import Pub, Sub
import distributed
import os

import torch
import torch.distributed as dist
from torch.multiprocessing import Process, Manager

import sys

print("torch version =", torch.__version__)
print("dask.distributed version =", distributed.__version__)
x = sys.version_info
print("python version =", ".".join([str(x.major), str(x.minor), str(x.micro)]))
# torch version = 0.4.0
# dask.distributed version = 1.21.8+40.g93c6c112
# python version = 3.6.5


def test_latency(c):
    """
    Test how quickly we can move messages back and forth between two workers.
    This is mostly a test of latency.

    Interestingly, this runs 10x slower on Python 2.
    """
    def pingpong(a, b, start=False, n=1000, msg=1):
        sub = Sub(a)
        pub = Pub(b)

        while not pub.subscribers:
            sleep(0.01)

        if start:
            pub.put(msg)  # other sub may not have started yet

        for i in range(n):
            msg = next(sub)
            pub.put(msg)
        return n

    x = np.random.random(1).astype('float32')
    n = 1000
    x = c.submit(pingpong, 'a', 'b', start=True, msg=x, n=n)
    y = c.submit(pingpong, 'b', 'a', n=n)

    start = time()
    # yield c.gather([x, y])
    c.gather([x, y])
    stop = time()

    # Divide by 2*n because 2*n messages are sent in total
    return (stop - start) / (2 * n)


from distributed import LocalCluster
cluster = LocalCluster(n_workers=2, threads_per_worker=1)
client = Client(cluster)

latency = test_latency(client)
print("dask.distributed latency = {ms:0.3f} ms".format(ms=latency / 1e-3))
# dask.distributed latency = 1.376 ms


def run(rank, size, N=1000):
    start = time()
    tensor = torch.ones(1)
    assert tensor.dtype is torch.float32
    for i in range(N):
        if rank == 0:
            tensor += 1
            dist.send(tensor=tensor, dst=1)
        elif rank == 1:
            dist.recv(tensor, src=0)
    assert i == N - 1
    t = time() - start
    avg_latency = t / N
    print("PyTorch latency = {ms:0.3f} ms".format(ms=avg_latency / 1e-3))
    return avg_latency


def init_processes(rank, size, fn, return_dict, backend='tcp'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    sleep(5)
    return_dict[rank] = fn(rank, size)


size = 2
processes = []
manager = Manager()
return_dict = manager.dict()
for rank in range(size):
    p = Process(target=init_processes, args=(rank, size, run, return_dict))
    p.start()
    processes.append(p)

for proc in processes:
    proc.join()
# PyTorch latency = 0.032 ms
# PyTorch latency = 0.032 ms
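PyTorch releases from 1.0 onward no longer ship the 'tcp' backend used in init_processes above; the 'gloo' backend is the usual CPU replacement and supports point-to-point send/recv. The sketch below is a hedged adaptation, not the original benchmark: it also makes each iteration a full round trip (rank 0 sends and waits for the reply), so the elapsed time is divided by 2*N as in the socket and Pub/Sub tests. Port 29501 is an assumption.

# Hypothetical adaptation for PyTorch >= 1.0 using the gloo backend.
import os
import time
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def pingpong(rank, N=1000):
    tensor = torch.ones(1)
    start = time.time()
    for i in range(N):
        if rank == 0:
            dist.send(tensor=tensor, dst=1)  # ping
            dist.recv(tensor=tensor, src=1)  # wait for the pong
        else:
            dist.recv(tensor=tensor, src=0)
            dist.send(tensor=tensor, dst=0)
    elapsed = time.time() - start
    if rank == 0:
        print("PyTorch (gloo) latency = {ms:0.3f} ms".format(ms=elapsed / (2 * N) / 1e-3))

def init_process(rank, size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'  # assumed to be a free port
    dist.init_process_group('gloo', rank=rank, world_size=size)
    pingpong(rank)

if __name__ == '__main__':
    size = 2
    procs = [Process(target=init_process, args=(rank, size)) for rank in range(size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()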
I use https://github.com/dask/distributed/blob/12ddc080d1d876b0fce6fe4b2e863fe7c7b31543/distributed/tests/test_pubsub.py#L13 for the dask.distributed latency test.
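For a rougher point of comparison, one can also time plain submit/result round trips through the dask scheduler; this measures whole-task overhead (scheduling plus transfer) rather than just the Pub/Sub channel latency. A minimal sketch, not taken from the gist or from test_pubsub.py:

from time import time
from distributed import Client, LocalCluster

def identity(x):
    return x

cluster = LocalCluster(n_workers=1, threads_per_worker=1)
client = Client(cluster)

N = 100
start = time()
for i in range(N):
    # Each iteration is a full client -> scheduler -> worker -> client round trip.
    client.submit(identity, i, pure=False).result()
elapsed = time() - start
print("submit/result round trip = {ms:0.3f} ms".format(ms=elapsed / N / 1e-3))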