Skip to content

Instantly share code, notes, and snippets.

@SomeoneSerge
Last active June 4, 2021 10:16
Show Gist options
  • Save SomeoneSerge/cf82a0be18c9bc87995120874843a4cc to your computer and use it in GitHub Desktop.
An attempt to evaluate multiple PyTorch models in parallel on a single GPU, without threading or MPI.
import itertools
from copy import deepcopy
from functools import partial
from itertools import cycle
from pprint import pprint
from timeit import timeit
from typing import List, Tuple
import numpy as np
import pandas as pd
import torch
import torchvision
from tqdm import tqdm
@torch.no_grad()
def main(
    n_workers, n_proposals, template_model, device, collect_fq, collect_all, jit=True
):
    """Evaluate ``n_proposals`` parameter sets on one GPU using ``n_workers``
    model replicas, each bound to its own CUDA stream.

    Each "proposal" is a full state_dict; a worker loads it and runs a single
    forward pass. Finished outputs are harvested either eagerly from every
    stream (``collect_all``) or only from the stream that was just fed,
    at a cadence controlled by ``collect_fq``.

    Args:
        n_workers: number of model replicas / CUDA streams.
        n_proposals: number of state_dicts to evaluate.
        template_model: the model to replicate (deep-copied per worker).
        device: CUDA device to run on.
        collect_fq: harvest results whenever ``i % collect_fq == 0``
            (or ``i % n_workers == 0``).
        collect_all: if True, harvest pending results from *all* streams at a
            collection point; otherwise only from the current stream.
        jit: trace the template with ``torch.jit.trace`` before replicating.
    """
    # Fixed dummy input; also serves as the example input for jit tracing.
    x = torch.randn(1, 3, 224, 224).to(device)
    workers = [template_model]
    if jit:
        # check_trace=False: weights are overwritten per proposal, so the
        # traced output need not match the template's.
        workers = [torch.jit.trace(workers[0], x, check_trace=False)]
    # Replicate the (possibly traced) model once per worker.
    for _ in range(n_workers - 1):
        workers.append(deepcopy(workers[0]))
    # Per-worker list of pending (job_id, scalar-tensor) computations.
    computations = [[] for _ in workers]
    # NOTE(review): all proposals are copies of the same state_dict — fine
    # for benchmarking throughput, not for a real search.
    proposals = [workers[0].state_dict() for _ in range(n_proposals)]
    torch.cuda.synchronize()
    # Worker 0 runs on the default stream; every other worker gets its own.
    streams = [torch.cuda.default_stream(device=device)] + [
        torch.cuda.Stream(device=device) for _ in workers[:-1]
    ]
    results = {}
    pbar = tqdm(
        total=len(proposals),
        mininterval=0.4,
        smoothing=0.75,
        desc=f"{n_proposals=} {n_workers=} {collect_all=} {collect_fq=}",
    )

    def check_computation(computations: List[Tuple[int, torch.Tensor]]):
        # Drain one worker's pending list (LIFO). ``.item()`` blocks until
        # the GPU work producing that scalar has completed.
        while len(computations) > 0:
            job_id, output = computations.pop()
            results[job_id] = output.item()
            pbar.update(1)

    def eval_proposal(job_id, proposal, computations, worker, stream):
        # Enqueue weight load + forward on this worker's stream; keep only a
        # scalar (sum of outputs) to minimize what must be read back.
        # NOTE(review): load_state_dict inside the stream context — presumably
        # the copies are enqueued on ``stream``; verify for non-default streams.
        with torch.cuda.stream(stream):
            worker.load_state_dict(proposal)
            computations.append((job_id, worker(x).sum()))

    # Round-robin proposals over (pending-list, worker, stream) triples.
    for ((i, proposal), (stream_comps, worker, stream)) in zip(
        enumerate(proposals),
        cycle(
            list(zip(computations, workers, streams)),
        ),
    ):
        eval_proposal(i, proposal, stream_comps, worker, stream)
        # Only harvest at collection points; otherwise keep feeding streams.
        if not any([i % collect_fq == 0, i % n_workers == 0]):
            continue
        if collect_all:
            for c in computations:
                check_computation(c)
        else:
            check_computation(stream_comps)
    # Final drain: collect whatever is still pending on any stream.
    for comp in computations:
        check_computation(comp)
    pbar.close()
    assert len(results) == len(
        proposals
    ), f"Only finished {len(results)} / {len(proposals)}"
def search_space():
    """Yield benchmark configurations for ``main``.

    Sweeps worker counts 30 down to 1, both collection modes, and a few
    representative collection frequencies (1, half the workers, all workers,
    deduplicated). The ``collect_fq == 1`` / ``collect_all == False`` combo is
    skipped: collecting the current stream after every submission serializes
    everything anyway.
    """
    n_proposals = 300
    worker_counts = reversed(range(1, 31))
    for n_workers, collect_all in itertools.product(worker_counts, (True, False)):
        frequencies = np.unique([1, max(1, n_workers // 2), n_workers])
        for collect_fq in frequencies:
            if collect_fq == 1 and not collect_all:
                continue
            yield dict(
                n_proposals=n_proposals,
                n_workers=n_workers,
                collect_all=collect_all,
                collect_fq=collect_fq,
            )
if __name__ == "__main__":
    device = torch.device("cuda")
    # One template network; main() deep-copies it once per worker.
    template_model = (
        torchvision.models.resnet152(pretrained=False).to(device).eval()
    )
    rows = []
    for cfg in search_space():
        # Wall-clock two full runs of main() under this configuration.
        elapsed = timeit(
            partial(main, template_model=template_model, device=device, **cfg),
            number=2,
        )
        rows.append(
            dict(
                cfg,
                seconds_total=elapsed,
                seconds_per_proposal=elapsed / cfg["n_proposals"],
            )
        )
    # Fastest configurations first.
    pd.DataFrame(rows).sort_values(by="seconds_per_proposal").to_csv(
        "check_par.csv"
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment