@heiner
Last active April 23, 2019 16:42
PyTorch CUDA Overlap Data Transfers
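A small speed test for overlapping host-to-device data transfers with GPU compute in PyTorch: worker threads run inference on an Atari-sized conv net while the main thread logs throughput in steps per second. Three interchangeable thread targets are provided: pinned-memory non_blocking copies on the default stream, an input kept resident on the device, and a per-thread CUDA stream.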
import itertools
import logging
import os
import sys
import threading
import time
import timeit
import torch
from torch import nn
from torch.nn import functional as F
logging.basicConfig(
    format=('[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
            '%(message)s'),
    level=0)
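# Number of worker threads is taken from the first CLI argument (default 1);
# batch size and the total step budget are fixed below.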
num_threads = int(sys.argv[1]) if len(sys.argv) > 1 else 1
batch_size = 60
total_steps = 1000000
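# Small Atari-style network: a conv feature extractor followed by a fully
# connected layer, with policy-logit and baseline heads; forward() also
# samples an action from the softmax policy.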
class Net(nn.Module):
    def __init__(self, num_actions=6):
        super(Net, self).__init__()
        self.feat_extract = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(3136, 512),  # 64 * 7 * 7 for 84x84 inputs. Could do with less hardcoding.
            nn.ReLU(),
        )
        core_output_size = self.fc[0].out_features
        self.policy = nn.Linear(core_output_size, num_actions)
        self.baseline = nn.Linear(core_output_size, 1)
    def forward(self, x):
        x = x.float() / 255.0
        x = self.feat_extract(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        policy_logits = self.policy(x)
        baseline = self.baseline(x)
        action = torch.multinomial(
            F.softmax(policy_logits, dim=1), num_samples=1)
        return action, policy_logits, baseline
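# Stand-in for a real data queue: an infinite iterator that keeps yielding
# the same pre-generated batch of frames.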
class IteratorQueue:
    def __init__(self, frame):
        self.frame = frame

    def __iter__(self):
        return self

    def __next__(self):
        return self.frame
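# main() wraps run() in torch.autograd.profiler.profile (disabled by default);
# when enabled, the Chrome trace is written to speed_test.json and gzipped.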
def main():
    filename = 'speed_test.json'
    with torch.autograd.profiler.profile(enabled=False, use_cuda=True) as prof:
        run()
    if prof:
        logging.info('Collecting trace and writing to \'%s.gz\'', filename)
        prof.export_chrome_trace(filename)
        os.system('gzip %s' % filename)
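# run() builds a fake frame batch, moves the model to the GPU if available,
# and starts num_threads worker threads; the main thread reports throughput
# in steps per second every 3 seconds until total_steps is reached.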
def run():
    shape = (batch_size, 4, 84, 84)
    frame = torch.randint(255, shape, dtype=torch.float32)
    queue = IteratorQueue(frame)
    should_stop = threading.Event()

    cpu_device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logging.info('Not using CUDA.')
        device = cpu_device

    model = Net(num_actions=6)
    model = model.to(device=device)

    step = 0
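    # Variant 1: copy each pinned batch host-to-device with non_blocking=True
    # on the default stream; the t.cpu() copies at the end block until the GPU
    # work is done. The default-argument Lock is created once, so all threads
    # running this target share the same lock around the step counter.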
    def target(lock=threading.Lock()):
        nonlocal step
        with torch.no_grad():
            for input in queue:
                input = input.pin_memory()
                input = input.to(device, non_blocking=True)
                output = model(input)
                output = [t.cpu() for t in output]
                with lock:
                    step += batch_size
                if should_stop.is_set():
                    break
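    # Variant 2: keep a single input batch resident on the device and only run
    # the forward pass, synchronizing explicitly after each step; this measures
    # compute without per-step host-to-device transfers.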
    def direct_target(lock=threading.Lock()):
        nonlocal step
        input = next(queue)
        input = input.to(device)
        with torch.no_grad():
            while not should_stop.is_set():
                output = model(input)
                torch.cuda.synchronize()
                with lock:
                    step += batch_size
                if should_stop.is_set():
                    break
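    # Variant 3: give each thread its own CUDA stream so copies and kernels
    # issued from different threads can overlap; stream.synchronize() makes
    # sure the non_blocking device-to-host copies have finished before the
    # step counter is bumped.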
    def stream_target(lock=threading.Lock()):
        nonlocal step
        stream = torch.cuda.Stream()
        with torch.no_grad():
            with torch.cuda.stream(stream):
                for input in queue:
                    input = input.pin_memory()
                    input = input.to(device, non_blocking=True)
                    output = model(input)
                    output = [t.to(cpu_device, non_blocking=True)
                              for t in output]
                    stream.synchronize()
                    with lock:
                        step += batch_size
                    if should_stop.is_set():
                        break
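    # Only stream_target is exercised below; swap in target or direct_target
    # to compare the other transfer strategies.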
    threads = [threading.Thread(target=stream_target)
               for _ in range(num_threads)]

    for thread in threads:
        thread.start()

    try:
        while step < total_steps:
            start_time = timeit.default_timer()
            start_step = step
            time.sleep(3)
            end_step = step
            logging.info(
                'Step %i @ %.1f SPS.',
                end_step, (end_step - start_step) / (
                    timeit.default_timer() - start_time))
    except KeyboardInterrupt:
        pass

    should_stop.set()
    for thread in threads:
        thread.join()

    if torch.cuda.is_available():
        torch.cuda.cudart().cudaProfilerStop()
if __name__ == '__main__':
    main()