PyTorch Tech Share (July 06 2020) - Simple PyTorch distributed computation functionality testing with pytest-xdist
This post describes one of many possible approaches to testing custom distributed computation functionality by emulating multiple processes.
- Communication between the N processes of an application
- send/receive tensors between processes (a minimal sketch follows the definitions below)
- `WORLD_SIZE`: the number of processes used in the computation, e.g. the total number of GPUs across all machines. For example, with 2 machines of 4 GPUs each, `WORLD_SIZE` is `2 * 4 = 8`.
- `RANK`: unique process identifier, varying between `0` and `WORLD_SIZE - 1`.
- local rank: machine-wise process identifier, e.g. the GPU index within the node.
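
A minimal sketch of point-to-point communication, assuming the default process group has already been initialized (e.g. via `dist.init_process_group`):

```python
import torch
import torch.distributed as dist

def exchange():
    # Assumes dist.init_process_group(...) was already called
    if dist.get_rank() == 0:
        t = torch.arange(4, dtype=torch.float32)
        dist.send(t, dst=1)  # blocking send to rank 1
    elif dist.get_rank() == 1:
        t = torch.empty(4)
        dist.recv(t, src=0)  # blocking receive from rank 0
```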
Deep learning use cases:
- Metric computation (e.g. accuracy) in a distributed setting
- Helper tools that extend current PyTorch functionality
- Yet another distributed training framework (like `torch.distributed`, `horovod`, ...)
Accuracy metric implementation:

```python
import torch
import torch.distributed as dist


class Accuracy:

    def __init__(self):
        self.num_samples = 0
        self.num_correct = 0

    def update(self, y_pred, y):
        self.num_samples += y_pred.shape[0]
        self.num_correct += (y_pred == y).sum().item()

    def compute(self):
        # We need to collect `num_correct` and `num_samples`
        # across participating processes
        t = torch.tensor([self.num_correct, self.num_samples], dtype=torch.float64)
        dist.all_reduce(t)  # in-place, default reduction op is SUM
        num_correct, num_samples = t.tolist()
        return num_correct / num_samples
```
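
To illustrate what `all_reduce` does here, a sketch assuming an initialized default process group with 2 processes:

```python
import torch
import torch.distributed as dist

# rank 0 holds [1., 2.], rank 1 holds [3., 4.]
t = torch.tensor([1.0, 2.0]) if dist.get_rank() == 0 else torch.tensor([3.0, 4.0])
dist.all_reduce(t)  # in-place, default reduction op is SUM
# every rank now holds [4., 6.]
```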
To test the metric, we map each pytest-xdist worker to a process rank (`worker_id` is a fixture provided by pytest-xdist, taking values like `"gw0"`, `"gw1"`, ..., or `"master"` when run without workers). The concrete process-group initialization (gloo backend, an arbitrary free TCP port) and the rank-dependent test data below are one possible way to fill in the sketch:

```python
import os

import pytest
import torch
import torch.distributed as dist


@pytest.fixture()
def local_rank(worker_id):
    """Use a different rank in each xdist worker: gw0 -> 0, gw1 -> 1, ..."""
    if "gw" in worker_id:
        lrank = int(worker_id.replace("gw", ""))
    elif worker_id == "master":
        lrank = 0
    else:
        raise RuntimeError("Can not get rank from worker_id={}".format(worker_id))
    yield lrank


@pytest.fixture()
def distributed_context(local_rank):
    rank = local_rank
    world_size = int(os.environ["WORLD_SIZE"])
    # Initialize the default process group; gloo works on CPU.
    # The port is arbitrary, any free port works.
    dist.init_process_group(
        backend="gloo", init_method="tcp://localhost:2222",
        rank=rank, world_size=world_size,
    )
    yield {
        "local_rank": local_rank,
        "rank": rank,
        "world_size": world_size,
    }
    dist.destroy_process_group()


def test_accuracy(distributed_context):
    rank = distributed_context["rank"]
    world_size = distributed_context["world_size"]
    # Setup y_pred and y dependent on rank:
    # rank r predicts correctly on exactly (r + 1) of its 10 samples
    y = torch.zeros(10, dtype=torch.long)
    y_pred = torch.ones(10, dtype=torch.long)
    y_pred[: rank + 1] = 0

    acc = Accuracy()
    acc.update(y_pred, y)

    true_acc = sum(r + 1 for r in range(world_size)) / (10.0 * world_size)
    assert acc.compute() == true_acc
```
- Run the test on 4 processes:

```
WORLD_SIZE=4 pytest --dist=each --tx 4*popen//python=python3.7 -vvv tests/test_accuracy.py
```
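
Here `--dist=each` tells pytest-xdist to send each test to every worker, `--tx 4*popen//python=python3.7` creates 4 local worker subprocesses via execnet, and `WORLD_SIZE=4` is picked up by the `distributed_context` fixture. Each of the 4 workers then executes `test_accuracy` as one rank of the process group.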
How do other frameworks test distributed code? Horovod, for example, launches pytest under its own runner application (`-n 2` starts 2 processes, `-H localhost:2` allocates 2 slots on localhost, `--gloo` selects the gloo controller):

```
horovodrun -n 2 -H localhost:2 --gloo pytest -v test/test_torch.py
```