PyTorch CUDA Overlap Data Transfers
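A benchmark for overlapping host-to-device data transfers with GPU compute in PyTorch, combining pinned host memory, non_blocking copies, and one CUDA stream per worker thread. A minimal sketch of that transfer-overlap pattern is shown first; the tensor shapes and the toy computation in it are illustrative only, and a CUDA-capable GPU is assumed. The full benchmark script follows.

# Minimal sketch of the overlap pattern, not part of the benchmark script itself.
import torch

device = torch.device('cuda:0')
stream = torch.cuda.Stream()        # dedicated stream, independent of the default stream

batch = torch.randn(60, 4, 84, 84)  # illustrative host tensor
batch = batch.pin_memory()          # page-locked memory allows truly asynchronous copies

with torch.cuda.stream(stream):
    batch_gpu = batch.to(device, non_blocking=True)  # H2D copy is queued; the CPU does not wait
    out = (batch_gpu * 2).sum()                      # compute runs on the same stream, after the copy
    out_cpu = out.to('cpu', non_blocking=True)       # asynchronous D2H copy of the result

stream.synchronize()                # wait for copy, compute and copy-back before reading out_cpu
print(out_cpu.item())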
import logging
import os
import sys
import threading
import time
import timeit

import torch
from torch import nn
from torch.nn import functional as F

logging.basicConfig(
    format=('[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
            '%(message)s'),
    level=0)

num_threads = int(sys.argv[1]) if len(sys.argv) > 1 else 1
batch_size = 60
total_steps = 1000000
class Net(nn.Module):
    """Small Atari-style conv net producing an action, policy logits and a baseline."""

    def __init__(self, num_actions=6):
        super(Net, self).__init__()
        self.feat_extract = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(3136, 512),  # Could do with less hardcoding.
            nn.ReLU(),
        )
        core_output_size = self.fc[0].out_features
        self.policy = nn.Linear(core_output_size, num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def forward(self, x):
        x = x.float() / 255.0
        x = self.feat_extract(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        policy_logits = self.policy(x)
        baseline = self.baseline(x)
        action = torch.multinomial(
            F.softmax(policy_logits, dim=1), num_samples=1)
        return action, policy_logits, baseline
class IteratorQueue:
    """Endless iterator that keeps yielding the same frame."""

    def __init__(self, frame):
        self.frame = frame

    def __iter__(self):
        return self

    def __next__(self):
        return self.frame
def main():
    filename = 'speed_test.json'
    with torch.autograd.profiler.profile(enabled=False, use_cuda=True) as prof:
        run()
    if prof:
        logging.info('Collecting trace and writing to \'%s.gz\'', filename)
        prof.export_chrome_trace(filename)
        os.system('gzip %s' % filename)
def run():
    shape = (batch_size, 4, 84, 84)
    frame = torch.randint(255, shape, dtype=torch.float32)
    queue = IteratorQueue(frame)
    should_stop = threading.Event()

    cpu_device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logging.info('Not using CUDA.')
        device = cpu_device

    model = Net(num_actions=6)
    model = model.to(device=device)

    step = 0

    def target(lock=threading.Lock()):
        # Pinned host memory + non_blocking copies on the default stream.
        nonlocal step
        with torch.no_grad():
            for input in queue:
                input = input.pin_memory()
                input = input.to(device, non_blocking=True)
                output = model(input)
                output = [t.cpu() for t in output]
                with lock:
                    step += batch_size
                if should_stop.is_set():
                    break

    def direct_target(lock=threading.Lock()):
        # No per-step transfers: the input lives on the device throughout.
        nonlocal step
        input = next(queue)
        input = input.to(device)
        with torch.no_grad():
            while not should_stop.is_set():
                output = model(input)
                torch.cuda.synchronize()
                with lock:
                    step += batch_size
                if should_stop.is_set():
                    break

    def stream_target(lock=threading.Lock()):
        # One CUDA stream per thread, so copies and kernels of different
        # threads can overlap instead of serializing on the default stream.
        nonlocal step
        stream = torch.cuda.Stream()
        with torch.no_grad():
            with torch.cuda.stream(stream):
                for input in queue:
                    input = input.pin_memory()
                    input = input.to(device, non_blocking=True)
                    output = model(input)
                    output = [t.to(cpu_device, non_blocking=True)
                              for t in output]
                    stream.synchronize()
                    with lock:
                        step += batch_size
                    if should_stop.is_set():
                        break

    threads = [threading.Thread(target=stream_target)
               for _ in range(num_threads)]

    for thread in threads:
        thread.start()

    try:
        while step < total_steps:
            start_time = timeit.default_timer()
            start_step = step
            time.sleep(3)
            end_step = step
            logging.info(
                'Step %i @ %.1f SPS.',
                end_step, (end_step - start_step) / (
                    timeit.default_timer() - start_time))
    except KeyboardInterrupt:
        pass

    should_stop.set()
    for thread in threads:
        thread.join()

    if torch.cuda.is_available():
        torch.cuda.cudart().cudaProfilerStop()


if __name__ == '__main__':
    main()
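The script takes an optional first command-line argument for the number of worker threads (default 1) and logs steps per second roughly every three seconds. Only stream_target is launched; target (pinned memory plus non_blocking copies on the default stream) and direct_target (input kept on the device, so no per-step transfers) are alternative workers that can be benchmarked by changing the target= argument of threading.Thread.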