Last active
June 4, 2021 10:16
-
-
Save SomeoneSerge/cf82a0be18c9bc87995120874843a4cc to your computer and use it in GitHub Desktop.
An attempt to evaluate multiple PyTorch models in parallel on a single GPU, without threading or MPI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
from copy import deepcopy | |
from functools import partial | |
from itertools import cycle | |
from pprint import pprint | |
from timeit import timeit | |
from typing import List, Tuple | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torchvision | |
from tqdm import tqdm | |
@torch.no_grad()
def main(
    n_workers, n_proposals, template_model, device, collect_fq, collect_all, jit=True
):
    """Evaluate ``n_proposals`` state-dicts on copies of ``template_model``,
    dispatching each forward pass onto one of ``n_workers`` CUDA streams.

    Parameters:
        n_workers: number of model replicas / CUDA streams used round-robin.
        n_proposals: number of state-dicts to evaluate.
        template_model: the model to replicate (already on ``device``).
        device: CUDA device the benchmark runs on.
        collect_fq: collection cadence — results are drained when
            ``i % collect_fq == 0`` (or at each worker-cycle boundary).
        collect_all: if True, drain every stream's pending results at a
            collection point; otherwise only the current stream's.
        jit: if True, replace the template with a ``torch.jit.trace`` of it
            before replicating.
    """
    # Fixed dummy input; all proposals are scored on the same batch.
    x = torch.randn(1, 3, 224, 224).to(device)
    workers = [template_model]
    if jit:
        # check_trace=False: skip the (expensive) trace-verification pass.
        workers = [torch.jit.trace(workers[0], x, check_trace=False)]
    for _ in range(n_workers - 1):
        workers.append(deepcopy(workers[0]))
    # One pending-results list per worker/stream.
    computations = [[] for _ in workers]
    # NOTE(review): every "proposal" is a state_dict of the same (jitted)
    # template, so all evaluations are numerically identical — this is a
    # throughput benchmark, not a real search.
    proposals = [workers[0].state_dict() for _ in range(n_proposals)]
    torch.cuda.synchronize()
    # Worker 0 runs on the default stream; each extra worker gets its own.
    streams = [torch.cuda.default_stream(device=device)] + [
        torch.cuda.Stream(device=device) for _ in workers[:-1]
    ]
    results = {}
    pbar = tqdm(
        total=len(proposals),
        mininterval=0.4,
        smoothing=0.75,
        desc=f"{n_proposals=} {n_workers=} {collect_all=} {collect_fq=}",
    )

    def check_computation(computations: List[Tuple[int, torch.Tensor]]):
        # Drain one stream's pending (job_id, scalar-tensor) pairs.
        # ``.item()`` synchronizes the host with that result's stream.
        while len(computations) > 0:
            job_id, output = computations.pop()
            results[job_id] = output.item()
            pbar.update(1)

    def eval_proposal(job_id, proposal, computations, worker, stream):
        # Queue load + forward on the given stream; keep only the scalar sum
        # so the full output tensor can be freed.
        with torch.cuda.stream(stream):
            worker.load_state_dict(proposal)
            computations.append((job_id, worker(x).sum()))

    # Round-robin proposals over the (computations, worker, stream) triples.
    for ((i, proposal), (stream_comps, worker, stream)) in zip(
        enumerate(proposals),
        cycle(
            list(zip(computations, workers, streams)),
        ),
    ):
        eval_proposal(i, proposal, stream_comps, worker, stream)
        # Only collect at the configured cadence or at a worker-cycle boundary.
        if not any([i % collect_fq == 0, i % n_workers == 0]):
            continue
        if collect_all:
            for c in computations:
                check_computation(c)
        else:
            check_computation(stream_comps)
    # Final drain: collect whatever is still pending on every stream.
    for comp in computations:
        check_computation(comp)
    pbar.close()
    assert len(results) == len(
        proposals
    ), f"Only finished {len(results)} / {len(proposals)}"
def search_space():
    """Yield benchmark configurations as keyword dicts, largest worker
    counts first.

    For each worker count, every combination of ``collect_all`` and a
    de-duplicated set of collection frequencies (1, half the workers,
    all workers) is produced.
    """
    total_proposals = 300
    for workers in range(30, 0, -1):
        for gather_everything in [True, False]:
            for freq in np.unique([1, max(1, workers // 2), workers]):
                # collect_fq == 1 already drains after every proposal, so
                # pairing it with collect_all=False adds nothing new.
                if freq == 1 and not gather_everything:
                    continue
                yield dict(
                    n_proposals=total_proposals,
                    n_workers=workers,
                    collect_all=gather_everything,
                    collect_fq=freq,
                )
if __name__ == "__main__":
    # All configurations share a single GPU and one frozen ResNet-152
    # template (random weights — only throughput matters here).
    device = torch.device("cuda")
    template_model = (
        torchvision.models.resnet152(pretrained=False).to(device).eval()
    )

    rows = []
    for cfg in search_space():
        bench = partial(main, template_model=template_model, device=device, **cfg)
        # Two repetitions per configuration; ``elapsed`` is their total.
        elapsed = timeit(bench, number=2)
        record = dict(cfg)
        record["seconds_total"] = elapsed
        record["seconds_per_proposal"] = elapsed / cfg["n_proposals"]
        rows.append(record)

    # Fastest-per-proposal configurations first in the CSV.
    table = pd.DataFrame(rows).sort_values(by="seconds_per_proposal")
    table.to_csv("check_par.csv")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment