Horovod-PyTorch with Apex (look for "# Apex")
from __future__ import print_function

import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torchvision import models
import horovod.torch as hvd
import timeit
import numpy as np

# Apex
from apex import amp

# Benchmark settings
parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')
parser.add_argument('--model', type=str, default='resnet50',
                    help='model to benchmark')
parser.add_argument('--batch-size', type=int, default=32,
                    help='input batch size')
parser.add_argument('--num-warmup-batches', type=int, default=10,
                    help='number of warm-up batches that don\'t count towards benchmark')
parser.add_argument('--num-batches-per-iter', type=int, default=10,
                    help='number of batches per benchmark iteration')
parser.add_argument('--num-iters', type=int, default=10,
                    help='number of benchmark iterations')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

hvd.init()

if args.cuda:
    # Horovod: pin GPU to local rank.
    torch.cuda.set_device(hvd.local_rank())

cudnn.benchmark = True

# Set up standard model.
model = getattr(models, args.model)()

if args.cuda:
    # Move model to GPU.
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=compression)

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# Apex
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Set up fixed fake data
data = torch.randn(args.batch_size, 3, 224, 224)
target = torch.LongTensor(args.batch_size).random_() % 1000
if args.cuda:
    data, target = data.cuda(), target.cuda()


def benchmark_step():
    optimizer.zero_grad()
    output = model(data)
    loss = F.cross_entropy(output, target)
    # Apex: backpropagate through the scaled loss and finish Horovod's
    # allreduce inside the scale_loss context, then step without triggering
    # a second synchronize.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
        optimizer.synchronize()
    with optimizer.skip_synchronize():
        optimizer.step()


def log(s, nl=True):
    if hvd.rank() != 0:
        return
    print(s, end='\n' if nl else '')


log('Model: %s' % args.model)
log('Batch size: %d' % args.batch_size)
device = 'GPU' if args.cuda else 'CPU'
log('Number of %ss: %d' % (device, hvd.size()))

# Warm-up
log('Running warmup...')
timeit.timeit(benchmark_step, number=args.num_warmup_batches)

# Benchmark
log('Running benchmark...')
img_secs = []
for x in range(args.num_iters):
    time = timeit.timeit(benchmark_step, number=args.num_batches_per_iter)
    img_sec = args.batch_size * args.num_batches_per_iter / time
    log('Iter #%d: %.1f img/sec per %s' % (x, img_sec, device))
    img_secs.append(img_sec)

# Results
img_sec_mean = np.mean(img_secs)
img_sec_conf = 1.96 * np.std(img_secs)
log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf))
log('Total img/sec on %d %s(s): %.1f +-%.1f' %
    (hvd.size(), device, hvd.size() * img_sec_mean, hvd.size() * img_sec_conf))
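For comparison, the step without Apex (the pattern this gist modifies; the added lines are the ones marked "# Apex") would look like the following sketch, reusing the names defined in the script above:

def benchmark_step():
    optimizer.zero_grad()
    output = model(data)
    loss = F.cross_entropy(output, target)
    loss.backward()    # DistributedOptimizer's gradient hooks launch the allreduce here
    optimizer.step()   # step() waits for the allreduce before applying the update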
@alsrgv I think you are right; I had misunderstood.
However, something goes wrong when I set --fp16-allreduce. The error is shown below:
Traceback (most recent call last):
File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/root/tsukiko2/test/fp16_hvd.py", line 292, in <module>
timeit.timeit(benchmark_step, number=args.num_warmup_batches)
File "/usr/lib/python3.6/timeit.py", line 233, in timeit
return Timer(stmt, setup, timer, globals).timeit(number)
File "/usr/lib/python3.6/timeit.py", line 178, in timeit
timing = self.inner(it, self.timer)
File "<timeit-src>", line 6, in inner
File "/root/tsukiko2/test/fp16_hvd.py", line 274, in benchmark_step
optimizer.synchronize()
File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 157, in synchronize
p.grad.set_(self._compression.decompress(output, ctx))
RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()
Traceback (most recent call last):
File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/root/tsukiko2/test/fp16_hvd.py", line 292, in <module>
timeit.timeit(benchmark_step, number=args.num_warmup_batches)
File "/usr/lib/python3.6/timeit.py", line 233, in timeit
return Timer(stmt, setup, timer, globals).timeit(number)
File "/usr/lib/python3.6/timeit.py", line 178, in timeit
timing = self.inner(it, self.timer)
File "<timeit-src>", line 6, in inner
File "/root/tsukiko2/test/fp16_hvd.py", line 273, in benchmark_step
scaled_loss.backward()
File "/usr/local/lib/python3.6/dist-packages/torch/tensor.py", line 107, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 139, in hook
handle, ctx = self._allreduce_grad_async(p)
File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 122, in _allreduce_grad_async
handle = allreduce_async_(tensor_compressed, average=True, name=name)
File "/usr/local/lib/python3.6/dist-packages/horovod/torch/mpi_ops.py", line 176, in allreduce_async_
return _allreduce_async(tensor, tensor, average, name)
File "/usr/local/lib/python3.6/dist-packages/horovod/torch/mpi_ops.py", line 81, in _allreduce_async
name.encode() if name is not None else _NULL)
RuntimeError: Horovod has been shut down. This was caused by an exception on one of the ranks or an attempt to allreduce, allgather or broadcast a tensor after one of the ranks finished execution. If the shutdown was caused by an exception, you should see the exception in the log before the first shutdown message.
And when I run my own demo (shown below, followed by the launch command), I get this error:
from __future__ import division
from __future__ import print_function

import argparse
import time
import os
import sys

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision

# Horovod
import horovod.torch as hvd

parser = argparse.ArgumentParser()
parser.add_argument("--apex", action="store_true")
parser.add_argument("--opt_level", type=str, default="O1")
parser.add_argument("--fp16_allreduce", action="store_true")
args = parser.parse_args()

hvd.init()
world_size = hvd.size()
world_rank = hvd.rank()
local_rank = hvd.local_rank()

APEX = args.apex
if APEX:
    import apex
    if world_rank == 0:
        print("use apex")

DETERMINISTIC = True
if DETERMINISTIC:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(4)
    torch.cuda.manual_seed_all(4)
    torch.set_printoptions(precision=10)
else:
    cudnn.benchmark = True

time.sleep(0.1 * world_rank)
print("init [%2s/%2s]" % (world_rank, world_size))
time.sleep(0.1 * (world_size - world_rank))

torch.cuda.set_device(local_rank)
if world_rank == 0:
    print("set_device")

assert torch.backends.cudnn.enabled

model = torchvision.models.resnet50()
if world_rank == 0:
    print("model")

if APEX:
    SYNC_BN = True
    if SYNC_BN:
        model = apex.parallel.convert_syncbn_model(model)
        if world_rank == 0:
            print("convert_syncbn_model")

model = model.cuda()
if world_rank == 0:
    print("model.cuda")

optimizer = torch.optim.SGD(
    model.parameters(), 0.01,
    momentum=0.9,
    weight_decay=1e-4
)
if world_rank == 0:
    print("optimizer")

FP16_ALLREDUCE = args.fp16_allreduce
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    compression=hvd.Compression.fp16 if FP16_ALLREDUCE else hvd.Compression.none
)
if world_rank == 0:
    print("hvd.DistributedOptimizer")

hvd.broadcast_parameters(model.state_dict(), root_rank=0)
if world_rank == 0:
    print("hvd.broadcast_parameters")

hvd.broadcast_optimizer_state(optimizer, root_rank=0)
if world_rank == 0:
    print("hvd.broadcast_optimizer_state")

if APEX:
    OPT_LEVEL = args.opt_level
    model, optimizer = apex.amp.initialize(model, optimizer,
                                           opt_level=OPT_LEVEL)
    if world_rank == 0:
        print("apex.amp.initialize: %s" % OPT_LEVEL)

criterion = nn.CrossEntropyLoss().cuda()

torch.manual_seed(4)
torch.cuda.manual_seed_all(4)
_inputs = (torch.LongTensor(64, 3, 224, 224).random_().cuda() % 255).float().add_(-127.5).mul_(1/255)
targets = (torch.LongTensor(64).random_().cuda() % 1000)

_time = time.time()
epochs = 12
batches = 5
for epoch_idx in range(epochs):
    for batch_idx in range(batches):
        if world_rank == 0:
            print("\n\nepoch: %s, batch: %s" % (epoch_idx, batch_idx))

        seed = epoch_idx * batches + batch_idx
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        inputs = _inputs + (torch.LongTensor(64, 3, 224, 224).random_().cuda() % 255).float().add_(-127.5).mul_(1/25500)
        if world_rank == 0:
            print(inputs[0, :, 0, 0])

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        if APEX:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
                optimizer.synchronize()
            with optimizer.skip_synchronize():
                optimizer.step()
        else:
            loss.backward()
            optimizer.step()

        # If tensor requires gradient, then
        # tensor.cpu().detach() constructs the .cpu autograd edge, which soon gets destructed since the result is not stored.
        # tensor.detach().cpu() does not do this.
        # However, this is very fast so virtually they are the same.
        if world_rank == 0:
            print(targets.detach().cpu().numpy())
            print(outputs.detach().cpu().numpy().argmax(axis=1))

if world_rank == 0:
    print("time: %s" % (time.time() - _time))
/usr/local/bin/mpirun \
--allow-run-as-root \
-np 2 \
-H localhost:2 \
-bind-to none -map-by slot \
-x LD_LIBRARY_PATH -x PATH -x PYTHONPATH \
-mca pml ob1 -mca btl ^openib \
python -u -m test.fp16_hvd --apex --opt_level O2 --fp16_allreduce
Traceback (most recent call last):
File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/root/tsukiko2/test/fp16_hvd.py", line 169, in <module>
optimizer.synchronize()
File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 157, in synchronize
p.grad.set_(self._compression.decompress(output, ctx))
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'source'
Traceback (most recent call last):
File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/root/tsukiko2/test/fp16_hvd.py", line 169, in <module>
optimizer.synchronize()
File "/usr/local/lib/python3.6/dist-packages/horovod/torch/__init__.py", line 157, in synchronize
p.grad.set_(self._compression.decompress(output, ctx))
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'source'
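A possible workaround sketch for the Float/Half mismatch above. This is only an assumption based on the compress(tensor) -> (tensor, ctx) / decompress(tensor, ctx) interface visible in the traceback, not a confirmed fix: allreduce in fp16, but have decompress cast the result back to whatever dtype the gradient had before compression, so that p.grad.set_() sees a matching scalar type even under opt_level O2.

import torch


class DtypePreservingFP16(object):
    """Hypothetical fp16 allreduce compression that restores the original dtype."""

    @staticmethod
    def compress(tensor):
        ctx = tensor.dtype                     # remember the gradient's dtype
        if tensor.dtype.is_floating_point:
            tensor = tensor.to(torch.float16)  # allreduce in half precision
        return tensor, ctx

    @staticmethod
    def decompress(tensor, ctx):
        return tensor.to(ctx)                  # cast back to the original dtype


# Usage sketch: pass it where hvd.Compression.fp16 was passed, e.g.
# optimizer = hvd.DistributedOptimizer(optimizer,
#                                      named_parameters=model.named_parameters(),
#                                      compression=DtypePreservingFP16)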
Hi, @alsrgv @qingyu-wang
When I set fp16-allreduce to True, I get the following error:
<stderr>: optimizer.synchronize()
<stderr>: File "/usr/local/lib64/python3.6/site-packages/horovod/torch/__init__.py", line 178, in synchronize
<stderr>: optimizer.synchronize()
<stderr>: File "/usr/local/lib64/python3.6/site-packages/horovod/torch/__init__.py", line 178, in synchronize
<stderr>: p.grad.set_(self._compression.decompress(output, ctx))
<stderr>:RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach()
How can I solve this problem?
I also had this problem
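A blunter workaround sketch for the set_storage error (again an assumption, not a confirmed fix): since Apex AMP already manages mixed precision for the backward pass, skip Horovod's fp16 allreduce compression whenever Apex is active. This sidesteps the decompress/set_() path entirely, at the cost of allreducing gradients uncompressed. The --apex flag below mirrors the demo script above and is hypothetical for any other setup.

import argparse
import horovod.torch as hvd

parser = argparse.ArgumentParser()
parser.add_argument('--apex', action='store_true', default=False)
parser.add_argument('--fp16-allreduce', action='store_true', default=False)
args = parser.parse_args()

hvd.init()

# Don't stack Horovod's fp16 compression on top of Apex AMP: let AMP own the
# mixed-precision gradients and allreduce them without compression.
if args.apex and args.fp16_allreduce:
    if hvd.rank() == 0:
        print('Apex is enabled; ignoring --fp16-allreduce and using Compression.none')
    compression = hvd.Compression.none
else:
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none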
@un-knight, I don't think that's necessary, since np.std does compute the standard deviation according to https://docs.scipy.org/doc/numpy/reference/generated/numpy.std.html. That said, I'm not a statistician and I could be totally wrong :-)
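For context, a small sketch of the point under discussion (assuming the question was about population vs. sample standard deviation): np.std defaults to ddof=0, and the benchmark's 1.96 * np.std(img_secs) band covers roughly the central 95% of per-iteration throughputs if they are approximately normal. The numbers below are made up for illustration.

import numpy as np

img_secs = [380.2, 377.9, 381.5, 379.1, 378.8]   # hypothetical per-iteration throughputs

pop_std = np.std(img_secs)              # default ddof=0: population standard deviation
sample_std = np.std(img_secs, ddof=1)   # ddof=1: unbiased sample estimate

# The benchmark reports mean +- 1.96 * std, i.e. an approximate 95% band for a
# single iteration's throughput under a normality assumption.
print('%.1f +-%.1f (ddof=0) vs +-%.1f (ddof=1)'
      % (np.mean(img_secs), 1.96 * pop_std, 1.96 * sample_std))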