Skip to content

Instantly share code, notes, and snippets.

@proger
Created November 8, 2023 15:56
Show Gist options
  • Select an option

  • Save proger/3543f2e1a369ba0a4aaf16a1625272fa to your computer and use it in GitHub Desktop.

Select an option

Save proger/3543f2e1a369ba0a4aaf16a1625272fa to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
def forward_rnn(forget, input, output, hidden, T, x):
    """Unroll a ReLU recurrence over the first ``T`` steps of ``x``.

    For each step ``t`` in ``range(T)`` the state is updated as
    ``hidden = relu(forget(hidden) + input(x[:, t, :]))`` and
    ``output(hidden)`` is recorded.

    Args:
        forget: Callable applied to the previous hidden state.
        input: Callable applied to the slice ``x[:, t, :]``.
        output: Callable applied to the freshly updated hidden state.
        hidden: Initial hidden state.
        T: Number of steps to run; ``x`` is indexed at ``0 .. T - 1`` on dim 1.
        x: Input tensor, read one slice ``x[:, t, :]`` per step.

    Returns:
        The per-step outputs stacked along a new second-to-last dimension.
    """
    outputs = []
    for t in range(T):
        u = input(x[:, t, :])
        # The new state feeds both this step's output and the next step.
        hidden = (forget(hidden) + u).relu()
        outputs.append(output(hidden))
    return torch.stack(outputs, dim=-2)
def forward_rnn1(forget, input, output, hidden, T, x):
    """Unroll a ReLU recurrence using raw weight tensors.

    Same recurrence shape as applying three bias-free linear maps: for each
    step ``t`` in ``range(T)``,
    ``hidden = relu(F.linear(hidden, forget) + F.linear(x[:, t, :], input))``
    and ``F.linear(hidden, output)`` is recorded.

    Args:
        forget: Weight passed to ``F.linear`` for the previous hidden state.
        input: Weight passed to ``F.linear`` for the slice ``x[:, t, :]``.
        output: Weight passed to ``F.linear`` for the updated hidden state.
        hidden: Initial hidden state.
        T: Number of steps to run; ``x`` is indexed at ``0 .. T - 1`` on dim 1.
        x: Input tensor, read one slice ``x[:, t, :]`` per step.

    Returns:
        The per-step outputs stacked along a new second-to-last dimension.
    """
    outputs = []
    for t in range(T):
        u = F.linear(x[:, t, :], input)
        # The new state feeds both this step's output and the next step.
        hidden = (F.linear(hidden, forget) + u).relu()
        outputs.append(F.linear(hidden, output))
    return torch.stack(outputs, dim=-2)
class RNN(nn.Module):
    """Recurrent layer built from three bias-free ``dim -> dim`` linear maps.

    ``forget`` starts as the identity matrix; ``input`` and ``output`` weights
    are drawn from a normal distribution with mean 0 and std 0.001.
    """

    def __init__(self, dim):
        """Create the three linear maps for a state of width ``dim``."""
        super().__init__()
        self.forget = nn.Linear(dim, dim, bias=False)
        nn.init.eye_(self.forget.weight)
        self.input = nn.Linear(dim, dim, bias=False)
        nn.init.normal_(self.input.weight, 0, 0.001)
        self.output = nn.Linear(dim, dim, bias=False)
        nn.init.normal_(self.output.weight, 0, 0.001)

    def forward(self, x, hidden=None):
        """Run the recurrence over every step of ``x``.

        Args:
            x: Input tensor; dim 0 is used as the batch size and dim 1 as the
                number of steps (``forward_rnn`` reads ``x[:, t, :]``).
            hidden: Optional initial state. Defaults to zeros of shape
                ``(x.shape[0], dim)`` with the dtype and device of ``x``.

        Returns:
            Whatever ``forward_rnn`` returns for these arguments.
        """
        if hidden is None:
            hidden = x.new_zeros(x.shape[0], self.forget.in_features)
        # forward_rnn needs the initial state and the step count explicitly;
        # the previous 4-argument call raised TypeError.
        return forward_rnn(
            self.forget, self.input, self.output, hidden, x.shape[1], x
        )
# Benchmark configuration: target device and the sizes handed to
# torch.randn(N, T, C) / new_zeros(N, C) below.
device = 'cuda:0'
N, T, C = 8, 128, 64
import time  # NOTE(review): not referenced anywhere in the visible code.

# --- Eager baseline ("slow") ---
# The CUDA events bracket everything from here to end.record(), so the
# reported time also covers module construction, the .to(device) copies and
# the prints, not only the recurrence.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
# Disabled experiment, kept for reference:
#forward_rnn = torch.compile(forward_rnn)
# NOTE(review): loop nesting was reconstructed from a whitespace-stripped
# source. With range(1, 2) and a single layer, every plausible nesting
# produces the same output.
for num_layers in range(1, 2):
    rnns = nn.ModuleList([
        RNN(C)
        for layer in range(num_layers)
    ]).to(device)
    dummy_x = torch.randn(N, T, C).to(device)
    hidden = dummy_x.new_zeros(N, C)
    x = dummy_x
    # Each layer consumes the previous layer's output; all layers start from
    # the same zero state.
    for rnn in rnns:
        x = forward_rnn(rnn.forget, rnn.input, rnn.output, hidden, T, x)
        print(rnn)
    print(x.shape)
end.record()
# elapsed_time needs both events to have completed.
torch.cuda.synchronize()
print('elapsed slow', start.elapsed_time(end))
# --- Compiled variant ("fast") ---
rnn = RNN(C).to(device)
# Disabled experiment, kept for reference:
#forward_rnn1 = torch.export.export(forward_rnn1, args=(rnn.forget.weight, rnn.input.weight, rnn.output.weight, torch.zeros(N,C).to(device), T, torch.randn(N, T, C).to(device), ))
forward_rnn1 = torch.compile(forward_rnn1)
print('compiled', forward_rnn1)
# One call before timing starts; presumably this is where torch.compile does
# its actual compilation, so that cost stays out of the measurement below.
forward_rnn1(
    rnn.forget.weight,
    rnn.input.weight,
    rnn.output.weight,
    torch.zeros(N,C).to(device),
    T,
    torch.randn(N, T, C).to(device),
)
print('warmed up', forward_rnn1)
# As in the eager run, the timed region also covers module construction, the
# .to(device) copies and the prints.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
# NOTE(review): loop nesting was reconstructed from a whitespace-stripped
# source. With range(1, 2) and a single layer, every plausible nesting
# produces the same output.
for num_layers in range(1, 2):
    rnns = nn.ModuleList([
        RNN(C)
        for layer in range(num_layers)
    ]).to(device)
    dummy_x = torch.randn(N, T, C).to(device)
    hidden = dummy_x.new_zeros(N, C)
    x = dummy_x
    # The compiled function takes raw weight tensors, not the nn.Linear modules.
    for rnn in rnns:
        x = forward_rnn1(rnn.forget.weight, rnn.input.weight, rnn.output.weight, hidden, T, x)
        print(rnn)
    print(x.shape)
end.record()
# elapsed_time needs both events to have completed.
torch.cuda.synchronize()
print('elapsed fast', start.elapsed_time(end))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment