Benchmark speed and maximum batch size of neural networks from the torchvision model zoo. Requires Python 3.10-3.11 and PyTorch 2.0 (the `int | None` annotations need 3.10+).
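"""Benchmark throughput and maximum batch size of torchvision models.

Example invocation (flags per the argparse setup below; `bench.py` is a
placeholder name for this file):

    python bench.py --device 0 --year 2022 --params 1e6 5e7

For each selected model the script binary-searches the largest batch that
fits in VRAM and measures images/s, for train/infer in fp16/fp32, then
writes `batch_test.<year>-<count>.csv` and a matching `.svg` bar chart.
"""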
import os
import warnings
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from itertools import count
from time import perf_counter
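
# These three knobs must be set before `torch` is imported, hence the split
# import block: lazy CUDA module loading, the cudaMallocAsync allocator
# backend, and an NVML-based (fork-safe) CUDA availability check.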
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:cudaMallocAsync'
os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as tomp
from matplotlib.ticker import EngFormatter
from torch import nn
from torchvision import models
from tqdm.auto import tqdm
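
# Keep SVG text as editable text, share tensors via the file system (avoids
# "too many open files" across worker processes), and silence warnings.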
plt.rcParams['svg.fonttype'] = 'none'
tomp.set_sharing_strategy('file_system')
warnings.filterwarnings('ignore')

MAX_BATCH = 8192
TIMEOUT = 5  # Seconds to spend measuring speed for each run
NBITS = 5  # Only the top NBITS bits of the resulting batch size are precise
SHAPE = (224, 224)
OUTFILE = 'batch_test'

NETS: dict[str, Callable[..., nn.Module]] = {
    name: partial(models.get_model, name, weights=None)
    for name in models.list_models(module=models)
}
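# Inception-style nets return auxiliary logits by default, which would break
# the single-tensor forward pass assumed by `_step`; disable them.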
for _name in ('inception_v3', 'googlenet'):
    if _name in NETS:
        NETS[_name] = partial(NETS[_name], aux_logits=False)

YV_2_NET = {  # (arxiv year, torchvision version) -> model family
    # 14
    (2014, '0.1.6'): ['alexnet', 'vgg'],  # 1+8 / 9
    (2015, '0.1.6'): ['resnet'],  # 5
    # 7
    (2015, '0.1.8'): ['inception_v3'],  # 1
    (2016, '0.1.8'): ['densenet', 'squeezenet'],  # 4+2 / 6
    # 9
    (2014, '0.3.0'): ['googlenet'],  # 1
    (2016, '0.3.0'): ['resnext'],  # 3
    (2018, '0.3.0'): ['mobilenet_v2', 'shufflenet'],  # 1+4 / 5
    # 6
    (2016, '0.4.0'): ['wide_resnet'],  # 2
    (2018, '0.4.0'): ['mnasnet'],  # 4
    # 2
    (2019, '0.9.0'): ['mobilenet_v3'],  # 2
    # 23
    (2019, '0.11.0'): ['efficientnet'],  # 8
    (2020, '0.11.0'): ['regnet'],  # 15
    # 9
    (2020, '0.12.0'): ['vit'],  # 5
    (2022, '0.12.0'): ['convnext'],  # 4
    # 6
    (2021, '0.13.0'): ['efficientnet_v2', 'swin'],  # 3+3 / 6
    # 4
    (2022, '0.14.0'): ['maxvit', 'swin_v2'],  # 1+3 / 4
}
NETS = dict(sorted(NETS.items()))


def main() -> None:
    # Find min/max param counts
    metas = {
        name: models.get_model_weights(name).DEFAULT.meta for name in NETS
    }
    param_counts = {name: meta['num_params'] for name, meta in metas.items()}
    min_max_params = min(param_counts.values()), max(param_counts.values())

    # Parse args
    p = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    p.add_argument('-d', '--device', type=int, help='device id to use')
    p.add_argument(
        '-y',
        '--year',
        type=int,
        choices=sorted({y for y, _ in YV_2_NET}),
        help='release year')
    p.add_argument('-i', '--show', action='store_true')
    p.add_argument('-t', '--tag')
    p.add_argument('--no-fp16', action='store_false', dest='fp16')
    p.add_argument(
        '--params',
        type=float,
        nargs=2,
        default=min_max_params,
        help='min/max model parameter count')
    args = p.parse_args()
    min_params, max_params = args.params
    min_params = max(min_max_params[0], min_params)
    max_params = min(min_max_params[1], max_params)

    # Group models by their families
    print(f'Found {len(NETS)} models to test')
    yr_nets: dict[int, dict[str, Callable[..., nn.Module]]] = {}
    for net, net_fn in NETS.items():
        if not (min_params <= param_counts[net] <= max_params):
            continue
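        # A model belongs to the newest family whose name prefixes its own;
        # models matching no family fall into year -1.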
        year = max(
            (yr for (yr, _), fams in YV_2_NET.items()
             if any(map(net.startswith, fams))),
            default=-1,
        )
        yr_nets.setdefault(year, {})[net] = net_fn
    yr_nets = dict(sorted(yr_nets.items()))

    num_selected = sum(len(ns) for ns in yr_nets.values())
    print(f'Selected {num_selected} models:')
    print(
        *sorted((param_counts[n], n) for ns in yr_nets.values() for n in ns),
        sep='\n')

    # Run
    print(f'Using image shape {SHAPE} and {TIMEOUT}s per run')
    if num_selected < len(NETS):  # Merge all models into one group
        yr_nets = {0: {k: v for f in yr_nets.values() for k, v in f.items()}}
    with tqdm(yr_nets.items(), desc='measuring') as bar:
        for year, nets in bar:
            bar.set_postfix(year=year)
            if args.year not in (None, year):
                continue
            outfile = f'{OUTFILE}.{year}-{len(nets)}'
            if args.tag is not None:
                outfile = f'{outfile}.{args.tag}'
            _measure_n_plots(nets, outfile, args, metas)


def _measure_n_plots(nets: dict[str, Callable[..., nn.Module]], outfile: str,
                     args, metas):
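    """Benchmark every model in `nets`, dump the results to a CSV and draw
    a horizontal bar chart of batch sizes, fps and ImageNet-1k error."""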
    stats = []
    for name, net_fn in tqdm(nets.items()):
        net = net_fn()
        size = sum(t.numel() for ts in (net.parameters, net.buffers)
                   for t in ts()) / 1e6
        in1k_top1 = metas[name]['_metrics']['ImageNet-1K']['acc@1']
        in1k_top5 = metas[name]['_metrics']['ImageNet-1K']['acc@5']
        name = f'{name}/{size:.1f}M'

        if args.fp16:
            be16, se16 = _find_max_batch(net, fp16=True, dev=args.device)
            bt16, st16 = _find_max_batch(
                net,
                is_train=True,
                fp16=True,
                high=be16 or MAX_BATCH,
                dev=args.device,
            )
        else:  # Honor --no-fp16: skip the half-precision runs entirely
            (be16, se16), (bt16, st16) = (0, 0), (0, 0)
        be32, se32 = _find_max_batch(
            net,
            high=be16 or MAX_BATCH,
            dev=args.device,
        )
        bt32, st32 = _find_max_batch(
            net,
            is_train=True,
            high=be32,
            dev=args.device,
        )
        stats.append({
            'name': name,
            'train/batch': (bt16, bt32),
            'train/fps': (st16, st32),
            'infer/batch': (be16, be32),
            'infer/fps': (se16, se32),
            'in1k_err': (round(100 - in1k_top1, 3), round(100 - in1k_top5, 3)),
        })
        tqdm.write(f'{name}: '
                   f'{bt32}..{be16 or be32} bs, '
                   f'{st32}..{se16 or se32} fps')

    df = pd.DataFrame.from_records(stats)
    df = df.sort_values('in1k_err')
    # df = df.sort_values('train/batch')
    df.to_csv(f'{outfile}.csv', index=False)
    fig, ax = plt.subplots(figsize=(8, 16))
    ax.set_xscale('log', base=2)
    ax.set(xlim=(1, 16384))
    ax.xaxis.set_major_formatter(EngFormatter())
    ax.grid(True)

    xs = np.arange(df.shape[0], dtype=float)
    cmap = plt.get_cmap('tab10')
    labels = {
        'train/batch': ['fp16', 'fp32'],
        'train/fps': ['fp16', 'fp32'],
        'infer/batch': ['fp16', 'fp32'],
        'infer/fps': ['fp16', 'fp32'],
        'in1k_err': ['top@1', 'top@5'],
    }
    akwds = {
        'fontsize': 8,
        'xytext': (0, 0),
        'textcoords': 'offset points',
        'va': 'center',
    }
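
    # Layout: one bar group per metric on each model row, each `h` tall and
    # centred on the row. fp16 and fp32 share a slot (alpha 0.5 vs opaque),
    # with value labels anchored to opposite sides of each bar's end.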
    n = len(labels)
    h = 1 / (n + 1)
    offsets = np.linspace(h - .5, .5 - h, n).tolist()
    for i, (offset, (l_head, l_tails)) in enumerate(
            zip(offsets, labels.items())):
        bkwds = {'height': h, 'color': cmap(i)}
        fp16, fp32 = zip(*df[l_head].values.tolist())
        for l_tail, dat, alpha, ha in zip(l_tails, [fp16, fp32], [0.5, None],
                                          ['left', 'right']):
            label_ = f'{l_head}/{l_tail}'
            bar = ax.barh(xs - offset, dat, label=label_, alpha=alpha, **bkwds)
            for r in bar:
                xy = (r.get_width(), r.get_y() + r.get_height() / 2)
                ax.annotate(r.get_width(), xy=xy, ha=ha, **akwds)

    ax.set_yticks(xs)
    ax.set_yticklabels(df['name'].values.tolist())
    ax.legend()
    fig.tight_layout()
    fig.savefig(f'{outfile}.svg')
    if args.show:
        plt.show()


def _find_max_batch(net: nn.Module,
                    is_train: bool = False,
                    fp16: bool = False,
                    low: int = 0,
                    high: int = MAX_BATCH,
                    dev: int | None = None) -> tuple[int, int]:
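    """Binary-search the largest batch size that still fits in GPU memory.

    Each trial runs in a fresh subprocess so that CUDA memory is fully
    released between attempts. Returns `(max batch size, images/s)`, or
    `(0, 0)` if the device cannot run this configuration at all.
    """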
    speed = 0.
    exc_ = None
    rg = sorted({_ev_round(x, NBITS) for x in range(low, high + 1)})
    with tqdm(
            desc=(('infer', 'train')[is_train] + f'/fp{(32, 16)[fp16]}'),
            leave=False) as bar:
        while len(rg) > 2:
            mid_pos = len(rg) // 2
            batch = rg[mid_pos]
            try:
                speed_ = ProcessPoolExecutor(1).submit(
                    _test_speed, net, fp16, is_train, batch, dev).result()

            # CUDA device cannot run this configuration
            except _LegacyDeviceError:
                return 0, 0

            # Out of memory. Shrink
            except (
                    RuntimeError,
                    torch.cuda.OutOfMemoryError,  # type:ignore[misc]
            ) as exc:
                rg, exc_ = rg[:mid_pos + 1], exc

            # Memory is not exhausted yet. Grow
            else:
                rg, exc_, speed = rg[mid_pos:], None, max(speed, speed_)
            bar.set_postfix_str(f'range: {rg[0]}..{rg[-1]} @ {len(rg)} items')
            bar.update()

    if rg[0] == 0 and exc_ is not None:
        raise exc_ from None
    return rg[0], int(speed)


def _ev_round(x: int, bits: int = 5) -> int:
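    """Round `x` down, keeping only its top `bits` bits, e.g.
    `_ev_round(1000, 5) == 992`."""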
    nbits = max(x.bit_length() - bits, 0)
    return (x >> nbits) << nbits


def _test_speed(net: nn.Module,
                fp16: bool,
                is_train: bool,
                batch: int,
                dev: int | None = None) -> float:
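    """Run forward (and, when training, backward) passes for `TIMEOUT`
    seconds and return the achieved throughput in images/s.

    Raises `_LegacyDeviceError` if fp16 is requested on a pre-Volta GPU,
    and `OutOfMemoryError` once peak usage reaches the total VRAM.
    """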
    devs = list(range(torch.cuda.device_count())) if dev is None else [dev]
    dprops = tuple(map(torch.cuda.get_device_properties, devs))
    if len({dp.name for dp in dprops}) > 1:  # GPUs differ, use only the first
        devs, dprops = devs[:1], dprops[:1]
    if fp16 and dprops[0].major < 7:  # Forbid FP16 on pre-sm70 GPUs
        raise _LegacyDeviceError
    max_mem = sum(dp.total_memory for dp in dprops)

    if len(devs) > 1:
        net = torch.nn.DataParallel(net)
    net.cuda(dev).train(is_train)
    data = torch.rand(batch, 3, *SHAPE, device=f'cuda:{devs[0]}')
    do_step = partial(_step, net, data, is_train=is_train, fp16=fp16)

    with torch.set_grad_enabled(is_train):
        start = perf_counter()
        for n in count(batch, batch):
            do_step()

            # Forbid VRAM spilling into RAM, which would silently tank speed
            used_mem = sum(
                max(s[f'{t}_bytes.all.peak']
                    for t in ('active', 'allocated', 'reserved'))
                for s in map(torch.cuda.memory_stats, devs))
            if used_mem >= max_mem:
                raise torch.cuda.OutOfMemoryError  # type:ignore[misc]

            if (done := perf_counter() - start) > TIMEOUT:
                return n / done
    return 0.  # Unreachable; keeps the declared return type total


def _step(net: nn.Module, data: torch.Tensor, is_train: bool, fp16: bool):
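    """One forward pass (plus backward when training) under autocast.

    The trailing `loss.item()` forces a host-device sync, so the caller's
    wall-clock timing reflects completed GPU work.
    """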
    with torch.autocast('cuda', enabled=fp16):
        loss = net(data).sum()
    if is_train:
        net.zero_grad()
        loss.backward()
    loss.item()


class _LegacyDeviceError(Exception):
    pass


if __name__ == '__main__':
    main()