Created
March 8, 2023 13:55
-
-
Save iiSeymour/bc3fac8280e0259bc22c922d8631a84c to your computer and use it in GitHub Desktop.
LSTM.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
import torch

# Problem size: BATCH_SIZE sequences of TIME_SERIES_LEN steps pushed through a
# NUM_LAYERS-deep LSTM whose input and hidden widths are both LAYER_SIZE.
BATCH_SIZE = 64
LAYER_SIZE = 384
TIME_SERIES_LEN = 1600
NUM_LAYERS = 5

# Untimed passes run before measuring, then the passes that are timed.
WARMUP_ROUNDS = 5
BENCHMARK_ROUNDS = 10

# How many independent model/input pairs are driven on separate CUDA streams.
NUM_STREAMS = 5

# Enable cuDNN's algorithm autotuning (torch.backends.cudnn.benchmark).
torch.backends.cudnn.benchmark = True
def calculate_flops(batch_size, num_features, num_layers, time_series_len, num_gates):
    """Return the matmul FLOPs of one forward pass through a stacked LSTM.

    Only the gate matrix multiplications are counted: bias terms,
    activations and the element-wise cell update are ignored. Every layer is
    assumed to take ``num_features`` inputs and hold ``num_features`` hidden
    units, which matches ``torch.nn.LSTM(LAYER_SIZE, LAYER_SIZE, NUM_LAYERS)``
    as used by this benchmark.

    Args:
        batch_size: sequences processed per pass.
        num_features: input width and hidden width of every layer.
        num_layers: number of stacked layers; each layer does the same work.
        time_series_len: time steps per sequence.
        num_gates: gate blocks per cell (4 for an LSTM).

    Returns:
        Total FLOPs as an int, counting one multiply-add as 2 operations.
    """
    num_weight_matrices = 2  # hidden-hidden weights as well as the input weights
    flops_per_layer = (
        time_series_len * batch_size * num_features * num_features
        * num_gates * num_weight_matrices * 2  # 2 FLOPs per multiply-add
    )
    # Every stacked layer repeats the same per-step matmuls.
    return flops_per_layer * num_layers
# One CUDA stream, one model and one input tensor per concurrent workload.
streams = [torch.cuda.Stream() for _ in range(NUM_STREAMS)]
# Instantiate lstms: stacked, bias-free fp16 LSTMs in eval mode on the GPU.
lstms = [
    torch.nn.LSTM(LAYER_SIZE, LAYER_SIZE, NUM_LAYERS, bias=False, batch_first=False)
    .cuda()
    .half()
    .eval()
    for _ in range(NUM_STREAMS)
]
# Create some input data, laid out (time, batch, features) because batch_first=False.
datas = [torch.rand(TIME_SERIES_LEN, BATCH_SIZE, LAYER_SIZE).cuda().half() for _ in range(NUM_STREAMS)]


def _run_all_streams(rounds):
    """Enqueue `rounds` forward passes of every model, each on its own stream.

    Streams are visited round-robin within each round so that every stream
    receives work early instead of all of stream 0's rounds being queued
    first. Autograd is disabled: the LSTM weights require grad, so without
    no_grad() a graph would be recorded for every pass.
    """
    with torch.no_grad():
        for _ in range(rounds):
            for stream, lstm, data in zip(streams, lstms, datas):
                with torch.cuda.stream(stream):
                    lstm(data)


# The weight/input conversions above were queued on the default stream; make
# sure they have finished before the side streams read those tensors.
torch.cuda.synchronize()

# Warmup on the same streams as the timed run so one-time initialisation
# costs are excluded from the measurement.
_run_all_streams(WARMUP_ROUNDS)

# Benchmark. perf_counter is monotonic and high-resolution, unlike time.time.
torch.cuda.synchronize()
t0 = time.perf_counter()
_run_all_streams(BENCHMARK_ROUNDS)
torch.cuda.synchronize()  # wait for all streams before stopping the clock
tf = time.perf_counter()
t_total = tf - t0

num_samples = TIME_SERIES_LEN * BATCH_SIZE * BENCHMARK_ROUNDS * NUM_STREAMS  # at ONT, we use "sample" to mean "time point"
# FLOPs of one forward pass (4 gates per LSTM cell), scaled to every timed pass.
flops = calculate_flops(BATCH_SIZE, LAYER_SIZE, NUM_LAYERS, TIME_SERIES_LEN, 4)
TFLOPS = (flops * BENCHMARK_ROUNDS * NUM_STREAMS) / t_total / 1e12
# Reference peak in TFLOPS; 125 matches NVIDIA's quoted V100 (SXM2) tensor-core figure.
v100_peak_tflops = 125
print("Took", t_total, "seconds")
print("MSample/s =", num_samples / t_total / 1e6)
print("TFLOPS = ", TFLOPS)
print("% peak theoretical (V100) = ", TFLOPS / v100_peak_tflops * 100)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment