-
-
Save vfdev-5/ce60680119b5d867167b420714da8944 to your computer and use it in GitHub Desktop.
Benchmarking ignite master branch vs metrics_impl on metrics.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
To run the CPU benchmark: `CUDA_VISIBLE_DEVICES="" python benchmark.py --name cpu`
To run the GPU benchmark: `CUDA_VISIBLE_DEVICES=0 python benchmark.py --name cuda`
To run the distributed benchmark: `python -u -m torch.distributed.launch --nproc_per_node=2 --use_env benchmark.py --name dist`
'''
import argparse | |
import time | |
import math | |
from functools import partial | |
import torch | |
from ignite import metrics | |
from ignite.engine import Engine | |
import ignite.distributed as idist | |
import pandas as pd | |
# Module-level globals, resolved once at import time and read by the data
# generators (`device`) and by `main` (`rank`).
rank = idist.get_rank()
device = idist.device()
def ci(vals):
    """Return the sample mean of ``vals`` and the half-width of its 95% confidence interval.

    Args:
        vals: 1-D tensor of measurements; its length along dim 0 is the sample count.

    Returns:
        Tuple ``(mean, margin_of_err)`` of Python floats, where the margin is
        ``1.96 * std / sqrt(n)`` (normal approximation).
    """
    samples = vals.detach().cpu()
    n_samples = vals.shape[0]
    sample_std = samples.std().item()
    sample_mean = samples.mean().item()
    # 1.96 is the two-sided 95% quantile of the standard normal distribution.
    half_width = 1.96 * sample_std / math.sqrt(n_samples)
    return sample_mean, half_width
def benchmark(f, n_tests=1000):
    """Time ``n_tests`` calls of ``f`` and summarise them with ``ci``.

    Args:
        f: zero-argument callable to time.
        n_tests: number of timed calls.

    Returns:
        Whatever ``ci`` returns for the tensor of per-call durations in seconds
        (mean and 95% margin of error).
    """
    use_cuda = torch.cuda.is_available()
    times = torch.zeros(n_tests)
    if use_cuda:
        # Drain GPU work queued before the benchmark so it is not billed to the first sample.
        torch.cuda.synchronize()
    for i in range(n_tests):
        # perf_counter is monotonic and high-resolution, unlike the adjustable wall clock.
        start = time.perf_counter()
        f()
        if use_cuda:
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize()
        times[i] = time.perf_counter() - start
    return ci(times)
def get_accuracy_data(offset, n_classes=10):
    """Build random classification data for the accuracy benchmark.

    Args:
        offset: number of samples per process; the total is ``offset * world_size``.
        n_classes: number of classes.

    Returns:
        ``(y_true, y_preds)`` on ``device``: integer labels drawn from ``[0, n_classes)``
        and a ``(n_samples, n_classes)`` tensor of uniform random scores.
    """
    n_samples = offset * idist.get_world_size()
    labels = torch.randint(0, n_classes, size=(n_samples,))
    scores = torch.rand(n_samples, n_classes)
    return labels.to(device), scores.to(device)
def get_loss_data(offset, n_classes=10):
    """Build random targets and scores for the loss benchmark.

    Args:
        offset: number of samples per process; the total is ``offset * world_size``.
        n_classes: number of classes.

    Returns:
        ``(y_true, y_preds)`` on ``device``: integer targets drawn from ``[0, n_classes)``
        and a ``(n_samples, n_classes)`` tensor of uniform random values.
    """
    n_samples = offset * idist.get_world_size()
    targets = torch.randint(0, n_classes, size=(n_samples,))
    outputs = torch.rand(n_samples, n_classes)
    return targets.to(device), outputs.to(device)
def get_mae_data(offset):
    """Build deterministic regression data for the mean-absolute-error benchmark.

    Args:
        offset: number of samples per process; the total is ``offset * world_size``.

    Returns:
        ``(y_true, y_preds)`` on ``device``: the float sequence ``0, 1, ..., n_samples - 1``
        and an all-ones float tensor of the same length.
    """
    n_samples = offset * idist.get_world_size()
    targets = torch.arange(0, n_samples, dtype=torch.float)
    predictions = torch.ones(n_samples, dtype=torch.float)
    return targets.to(device), predictions.to(device)
def get_mpd_data(offset):
    """Build random vectors for the mean-pairwise-distance benchmark.

    Args:
        offset: number of samples per process; the total is ``offset * world_size``.

    Returns:
        ``(y_true, y_preds)`` on ``device``: two independent ``(n_samples, 10)``
        tensors of uniform random values.
    """
    shape = (offset * idist.get_world_size(), 10)
    targets = torch.rand(*shape)
    predictions = torch.rand(*shape)
    return targets.to(device), predictions.to(device)
def get_mse_data(offset):
    """Build deterministic regression data for the mean-squared-error benchmark.

    Args:
        offset: number of samples per process; the total is ``offset * world_size``.

    Returns:
        ``(y_true, y_preds)`` on ``device``: the float sequence ``0, 1, ..., n_samples - 1``
        and an all-ones float tensor of the same length.
    """
    n_samples = offset * idist.get_world_size()
    targets = torch.arange(0, n_samples, dtype=torch.float)
    predictions = torch.ones(n_samples, dtype=torch.float)
    return targets.to(device), predictions.to(device)
def get_multiclass_pr_data(offset, n_classes=10):
    """Build random multiclass data for the precision/recall benchmarks.

    Args:
        offset: number of samples per process; the total is ``offset * world_size``.
        n_classes: number of classes.

    Returns:
        ``(y_true, y_preds)`` on ``device``: integer labels drawn from ``[0, n_classes)``
        and a ``(n_samples, n_classes)`` tensor of uniform random scores.
    """
    n_samples = offset * idist.get_world_size()
    labels = torch.randint(0, n_classes, size=(n_samples,))
    scores = torch.rand(n_samples, n_classes)
    return labels.to(device), scores.to(device)
def get_multilabel_pr_data(offset, n_classes=10):
    """Build random binary multilabel data for the precision/recall benchmarks.

    Args:
        offset: number of samples per process; the total is ``offset * world_size``.
        n_classes: size of the second (label) dimension.

    Returns:
        ``(y_true, y_preds)`` on ``device``: two independent 0/1 integer tensors
        of shape ``(n_samples, n_classes, 6, 8)``.
    """
    shape = (offset * idist.get_world_size(), n_classes, 6, 8)
    targets = torch.randint(0, 2, size=shape)
    predictions = torch.randint(0, 2, size=shape)
    return targets.to(device), predictions.to(device)
def get_topk_data(offset, n_classes=10):
    """Build random classification data for the top-k accuracy benchmark.

    Args:
        offset: number of samples per process; the total is ``offset * world_size``.
        n_classes: number of classes.

    Returns:
        ``(y_true, y_preds)`` on ``device``: integer labels drawn from ``[0, n_classes)``
        and a ``(n_samples, n_classes)`` tensor of uniform random scores.
    """
    n_samples = offset * idist.get_world_size()
    labels = torch.randint(0, n_classes, size=(n_samples,))
    scores = torch.rand(n_samples, n_classes)
    return labels.to(device), scores.to(device)
# Registry of benchmark cases. Each entry is
# (result name, metric factory accepting a `device` keyword, data generator).
benchmarks = (
    # Classification / regression metrics.
    ('acc', metrics.Accuracy, get_accuracy_data),
    ('loss', partial(metrics.Loss, torch.nn.NLLLoss()), get_loss_data),
    ('mae', metrics.MeanAbsoluteError, get_mae_data),
    ('mpd', metrics.MeanPairwiseDistance, get_mpd_data),
    ('mse', metrics.MeanSquaredError, get_mse_data),
    # Precision variants.
    ('prec_multiclass_avg', partial(metrics.Precision, average=True), get_multiclass_pr_data),
    ('prec_multiclass', partial(metrics.Precision, average=False), get_multiclass_pr_data),
    (
        'prec_multilabel_avg',
        partial(metrics.Precision, average=True, is_multilabel=True),
        get_multilabel_pr_data,
    ),
    # Recall variants.
    ('rec_multiclass_avg', partial(metrics.Recall, average=True), get_multiclass_pr_data),
    ('rec_multiclass', partial(metrics.Recall, average=False), get_multiclass_pr_data),
    (
        'rec_multilabel_avg',
        partial(metrics.Recall, average=True, is_multilabel=True),
        get_multilabel_pr_data,
    ),
    # Top-k accuracy with k=5.
    ('topk_acc', partial(metrics.TopKCategoricalAccuracy, k=5), get_topk_data),
)
def main(s=50, n_iters=100, n_epochs=3, n_tests=1000, run_name=''):
    """Benchmark every entry of ``benchmarks`` and report the timings.

    For each metric and each metric device ('cpu', plus 'cuda' when available) an
    ``Engine`` is built whose process function returns one batch of ``s`` samples
    from this rank's shard of the generated data; ``engine.run`` is then timed
    ``n_tests`` times.

    Args:
        s: number of samples per batch.
        n_iters: number of iterations per epoch.
        n_epochs: number of epochs per timed ``engine.run`` call.
        n_tests: number of timed runs per (metric, metric device) pair.
        run_name: label stored in the 'Run Name' column; also the stem of the CSV file.

    Side effects:
        On rank 0 only, writes ``{run_name}.csv`` and prints the table as markdown.
    """
    offset = n_iters * s  # samples per rank
    if rank == 0:
        print(f'Device: {idist.device()}')

    devices = [('cpu', 'cpu')]
    if torch.cuda.is_available():
        devices += [('cuda', idist.device())]

    # Loop invariants: the engine's "dataset" is just the batch indices, and every
    # rank reads its own contiguous shard starting at rank * offset.
    data = list(range(n_iters))
    shard_start = rank * offset

    series_list = []
    for name, metric_cls, get_data in benchmarks:
        y_true, y_preds = get_data(offset=offset)

        def update(engine, i):
            # Batch i of this rank's shard, returned as (y_pred, y).
            lo = shard_start + i * s
            hi = lo + s
            return y_preds[lo:hi], y_true[lo:hi]

        for metric_device_name, metric_device in devices:
            engine = Engine(update)
            metric_cls(device=metric_device).attach(engine, name)
            mean, margin = benchmark(lambda: engine.run(data=data, max_epochs=n_epochs), n_tests=n_tests)
            series_list.append([name, run_name, metric_device_name, mean, margin])

    # Only rank 0 reports: otherwise every process would write the same CSV file
    # concurrently and print a duplicate table.
    if rank == 0:
        df = pd.DataFrame(
            series_list,
            columns=['Metric Name', 'Run Name', 'Metric Device', 'Time Mean', 'Time Margin of Error'],
        )
        df.to_csv(f'{run_name}.csv')
        print(df.to_markdown())
if __name__ == '__main__':
    # Script entry point: parse the command-line options, seed torch's RNG so the
    # generated data is reproducible, then run the benchmark suite.
    cli = argparse.ArgumentParser()
    for flag, fallback in (('--size', 50), ('--n_iters', 100), ('--n_epochs', 3), ('--n_tests', 200)):
        cli.add_argument(flag, type=int, default=fallback)
    cli.add_argument('--name', type=str, default='default')
    cli.add_argument('--seed', type=int, default=42)
    opts = cli.parse_args()

    torch.manual_seed(opts.seed)
    main(
        s=opts.size,
        n_iters=opts.n_iters,
        n_epochs=opts.n_epochs,
        n_tests=opts.n_tests,
        run_name=opts.name,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment