Performance tests for Pytorch LSTMs
""" | |
A series of speed tests on pytorch LSTMs. | |
- LSTM is fastest (no surprise) | |
- When you have to go timestep-by-timestep, LSTMCell is faster than LSTM | |
- Iterating using chunks is slightly faster than __iter__ or indexing depending on setup | |
**Results** | |
My Ubuntu server:
OS: posix, pytorch version: 0.4.0a0+67bbf58
Device: GeForce GTX 1080, CUDA: 9.1.85, CuDNN: 7102
Test setup: (200,32,40)->(200,32,256)
GPU Results
lstm_model: 6.118471ms forward, 7.881905ms backward
lstm_cell_model_iter: 11.778021ms forward, 30.820508ms backward
lstm_cell_model_range: 11.619006ms forward, 30.772598ms backward
lstm_cell_model_chunk: 11.829958ms forward, 30.854071ms backward
lstm_iterative_model: 25.623804ms forward, 65.606091ms backward
CPU Results
lstm_model: 136.207031ms forward, 185.528455ms backward
lstm_cell_model_iter: 119.665515ms forward, 185.308344ms backward
lstm_cell_model_range: 119.667651ms forward, 184.688207ms backward
lstm_cell_model_chunk: 119.828860ms forward, 186.508534ms backward
lstm_iterative_model: 134.790720ms forward, 206.971720ms backward
Testing took 57.328832149505615s

My Windows laptop:
OS: nt, pytorch version: 0.3.0b0+591e73e
Device: GeForce GTX 1070, CUDA: 8.0, CuDNN: 6021
Test setup: (200,32,40)->(200,32,256)
GPU Results
lstm_model: 14.585430ms forward, 30.203342ms backward
lstm_cell_model_iter: 51.728686ms forward, 126.096303ms backward
lstm_cell_model_range: 51.461510ms forward, 126.767268ms backward
lstm_cell_model_chunk: 50.380049ms forward, 123.715715ms backward
lstm_iterative_model: 159.778252ms forward, 458.398417ms backward
CPU Results
lstm_model: 180.477981ms forward, 543.114929ms backward
lstm_cell_model_iter: 179.986208ms forward, 540.873568ms backward
lstm_cell_model_range: 180.418623ms forward, 540.747839ms backward
lstm_cell_model_chunk: 179.434959ms forward, 539.445579ms backward
lstm_iterative_model: 212.276415ms forward, 597.369602ms backward
Testing took 201.54079699516296s
"""
import argparse
import os
import sys
import time
import timeit

import torch
import torch.cuda
from torch import nn
from torch.autograd import Variable

def time_speed(args, model, cuda, number, backward=False):
    def run():
        # Allocate the random input on the right device via the `out=` tensor
        # (a pytorch 0.3/0.4-era idiom).
        if cuda:
            out = torch.cuda.FloatTensor()
        else:
            out = torch.FloatTensor()
        x = Variable(torch.randn(args.batch_len, args.batch_size, args.dim_in, out=out))
        h = model(x)
        if backward:
            h.sum().backward()
    run()  # warm-up run (cuDNN autotuning, allocator caches)
    elapsed = 1000. * timeit.timeit(run, number=number) / number
    return elapsed

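# Caveat (not in the original gist): CUDA kernels launch asynchronously, so the
# timing above measures enqueue time plus whatever synchronization the workload
# itself forces. A stricter variant -- a minimal sketch, assuming the same args
# object -- would call torch.cuda.synchronize() so the clock only stops once
# all queued kernels have finished:
def time_speed_synced(args, model, cuda, number, backward=False):
    def run():
        if cuda:
            out = torch.cuda.FloatTensor()
        else:
            out = torch.FloatTensor()
        x = Variable(torch.randn(args.batch_len, args.batch_size, args.dim_in, out=out))
        h = model(x)
        if backward:
            h.sum().backward()
        if cuda:
            torch.cuda.synchronize()  # wait for all queued kernels before returning
    run()  # warm-up
    elapsed = 1000. * timeit.timeit(run, number=number) / number
    return elapsed
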
def lstm_model(args, cuda):
    lstm = nn.LSTM(args.dim_in, args.dim_out)
    if cuda:
        lstm.cuda()

    def fun(x):
        # nn.LSTM returns (output, (h_n, c_n)); only the full output is needed here.
        h, state = lstm(x)
        return h
    return fun

def lstm_cell_model_iter(args, cuda):
    lstm = nn.LSTMCell(args.dim_in, args.dim_out)
    if cuda:
        lstm.cuda()

    def fun(x):
        n = x.size(1)
        # Zero initial hidden and cell state, allocated on the same device as x.
        h0 = Variable(x.data.new(n, args.dim_out).zero_())
        state = (h0, h0)
        hs = []
        # Iterating a tensor walks dim 0, yielding (batch_size, dim_in) slices.
        for i in x:
            h, c = lstm(i, state)
            state = (h, c)
            hs.append(h)
        hs = torch.stack(hs, dim=0)
        return hs
    return fun

def lstm_cell_model_chunk(args, cuda):
    lstm = nn.LSTMCell(args.dim_in, args.dim_out)
    if cuda:
        lstm.cuda()

    def fun(x):
        n = x.size(1)
        h0 = Variable(x.data.new(n, args.dim_out).zero_())
        state = (h0, h0)
        hs = []
        # chunk keeps a leading dim of size 1 on each piece, hence the squeeze(0).
        for i in x.chunk(x.size(0), 0):
            h, c = lstm(i.squeeze(0), state)
            state = (h, c)
            hs.append(h)
        hs = torch.stack(hs, dim=0)
        return hs
    return fun

def lstm_cell_model_range(args, cuda):
    lstm = nn.LSTMCell(args.dim_in, args.dim_out)
    if cuda:
        lstm.cuda()

    def fun(x):
        n = x.size(1)
        h0 = Variable(x.data.new(n, args.dim_out).zero_())
        state = (h0, h0)
        hs = []
        # Indexing x[i] drops dim 0, like __iter__ above.
        for i in range(x.size(0)):
            h, c = lstm(x[i], state)
            state = (h, c)
            hs.append(h)
        hs = torch.stack(hs, dim=0)
        return hs
    return fun

def lstm_iterative_model(args, cuda):
    lstm = nn.LSTM(args.dim_in, args.dim_out)
    if cuda:
        lstm.cuda()

    def fun(x):
        state = None  # nn.LSTM initializes (h_0, c_0) to zeros when state is None
        hs = []
        # Call the full LSTM once per timestep, threading (h_n, c_n) through.
        for i in x.chunk(x.size(0), 0):
            h, state = lstm(i, state)
            hs.append(h)
        hs = torch.cat(hs, dim=0)
        return hs
    return fun

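# A plausible reason lstm_iterative_model is the slowest variant on the GPU:
# each chunk is a length-1 sequence, so every timestep pays the full nn.LSTM
# call overhead (Python dispatch plus, on GPU, cuDNN setup for a fresh
# sequence), while LSTMCell only does the single-step computation.
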
def time_speeds(args, cuda, number):
    def timer(model_fn_name):
        model_fn = globals()[model_fn_name]
        stats = []
        if not args.no_forward:
            fwd = time_speed(args, model_fn(args, cuda), cuda, number, backward=False)
            stats.append("{:.6f}ms forward".format(fwd))
        if not args.no_backward:
            # Backward timing includes the forward pass needed to build the graph.
            bwd = time_speed(args, model_fn(args, cuda), cuda, number, backward=True)
            stats.append("{:.6f}ms backward".format(bwd))
        print("{}: {}".format(model_fn_name, ", ".join(stats)))

    if not args.no_lstm:
        timer('lstm_model')
    if not args.no_lstm_cell_iter:
        timer('lstm_cell_model_iter')
    if not args.no_lstm_cell_range:
        timer('lstm_cell_model_range')
    if not args.no_lstm_cell_chunk:
        timer('lstm_cell_model_chunk')
    if not args.no_lstm_iterative:
        timer('lstm_iterative_model')

def run(args):
    print("OS: {}, pytorch version: {}".format(os.name, torch.__version__))
    if torch.cuda.is_available():
        from torch.backends import cudnn
        name = torch.cuda.get_device_name(torch.cuda.current_device())
        print("Device: {}, CUDA: {}, CuDNN: {}".format(name, cudnn.cuda, cudnn.version()))
    print("Test setup: ({},{},{})->({},{},{})".format(
        args.batch_len, args.batch_size, args.dim_in,
        args.batch_len, args.batch_size, args.dim_out
    ))
    starttime = time.time()
    if (not args.no_gpu) and torch.cuda.is_available():
        print("GPU Results")
        time_speeds(args, cuda=True, number=args.gpu_number)
    if not args.no_cpu:
        print("CPU Results")
        time_speeds(args, cuda=False, number=args.cpu_number)
    endtime = time.time()
    elapsed = endtime - starttime
    print("Testing took {}s".format(elapsed))

def parse_args(argv):
    parser = argparse.ArgumentParser(description='Pytorch LSTM Speedtest')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N', help='batch size')
    parser.add_argument('--batch-len', type=int, default=200, metavar='N', help='batch len')
    parser.add_argument('--dim-in', type=int, default=40, metavar='N', help='in dim')
    parser.add_argument('--dim-out', type=int, default=256, metavar='N', help='out dim')
    parser.add_argument('--cpu-number', type=int, default=20, metavar='N', help='iterations on CPU')
    parser.add_argument('--gpu-number', type=int, default=100, metavar='N', help='iterations on GPU')
    parser.add_argument('--no-lstm', action='store_true', help='disable LSTM test')
    parser.add_argument('--no-lstm-cell-iter', action='store_true', help='disable LSTMCell with iterator test')
    parser.add_argument('--no-lstm-cell-range', action='store_true', help='disable LSTMCell with slicing test')
    parser.add_argument('--no-lstm-cell-chunk', action='store_true', help='disable LSTMCell with chunks test')
    parser.add_argument('--no-lstm-iterative', action='store_true', help='disable LSTM iterative test')
    parser.add_argument('--no-gpu', action='store_true', help='disable GPU tests')
    parser.add_argument('--no-cpu', action='store_true', help='disable CPU tests')
    parser.add_argument('--no-forward', action='store_true', help='disable forward tests')
    parser.add_argument('--no-backward', action='store_true', help='disable backward tests')
    args = parser.parse_args(argv)
    assert not (args.no_forward and args.no_backward)
    assert not (args.no_gpu and args.no_cpu)
    return args

def main(argv):
    args = parse_args(argv)
    run(args)


if __name__ == '__main__':
    main(sys.argv[1:])
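
# Optional correctness check (not part of the original benchmark): a minimal
# sketch, assuming PyTorch >= 0.4 where Variables and tensors are merged,
# verifying that an LSTMCell loop reproduces nn.LSTM once the single-layer
# LSTM weights (weight_ih_l0 etc.) are copied into the cell.
def check_cell_matches_lstm(dim_in=4, dim_out=8, steps=5, batch=3):
    lstm = nn.LSTM(dim_in, dim_out)
    cell = nn.LSTMCell(dim_in, dim_out)
    # Share weights so both compute the same function.
    cell.weight_ih.data.copy_(lstm.weight_ih_l0.data)
    cell.weight_hh.data.copy_(lstm.weight_hh_l0.data)
    cell.bias_ih.data.copy_(lstm.bias_ih_l0.data)
    cell.bias_hh.data.copy_(lstm.bias_hh_l0.data)
    x = torch.randn(steps, batch, dim_in)
    h_full, _ = lstm(x)  # nn.LSTM defaults to zero initial state
    h = torch.zeros(batch, dim_out)
    c = torch.zeros(batch, dim_out)
    hs = []
    for xt in x:
        h, c = cell(xt, (h, c))
        hs.append(h)
    h_loop = torch.stack(hs, dim=0)
    print("LSTMCell loop matches nn.LSTM:", torch.allclose(h_full, h_loop, atol=1e-6))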