Skip to content

Instantly share code, notes, and snippets.

@proger
Last active November 7, 2023 08:52
Show Gist options
  • Select an option

  • Save proger/746ceb4aec2e61cfac4a34f1787be976 to your computer and use it in GitHub Desktop.

Select an option

Save proger/746ceb4aec2e61cfac4a34f1787be976 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
def forward_rnn(forget, input, output, hidden, T, x):
    """Unroll a ReLU recurrence over the first ``T`` time steps of ``x``.

    At each step ``t`` the state is updated as
    ``hidden = relu(forget(hidden) + input(x[:, t, :]))`` and then
    ``output(hidden)`` is recorded.

    Args:
        forget: callable applied to the previous hidden state.
        input: callable applied to the step-``t`` slice ``x[:, t, :]``.
        output: callable mapping the updated hidden state to that step's output.
        hidden: initial hidden state; rebound locally, never modified in place.
        T: number of steps to unroll; indexes axis 1 of ``x``.
        x: input tensor indexed as ``x[:, t, :]`` (at least 3-D, time on axis 1).

    Returns:
        The ``T`` per-step outputs stacked along dim ``-2``.
    """
    per_step = []
    # This is a plain Python loop: under torch.compile it is traced fully
    # unrolled. The compile logs kept in this file report ~6k graph ops for
    # T=1024, so compile time grows with T.
    for step in range(T):
        pre_activation = forget(hidden) + input(x[:, step, :])
        hidden = torch.relu(pre_activation)
        per_step.append(output(hidden))
    return torch.stack(per_step, dim=-2)
class RNN(nn.Module):
    """Single-layer ReLU RNN whose recurrence is delegated to ``forward_rnn``.

    All three projections are bias-free ``dim -> dim`` linear maps:

    * ``forget`` is initialised to the identity, so at initialisation
      ``forget(hidden) == hidden``.
    * ``input`` and ``output`` are initialised from a normal distribution with
      mean 0 and std 0.001.
    """

    def __init__(self, dim):
        super().__init__()
        self.forget = nn.Linear(dim, dim, bias=False)
        nn.init.eye_(self.forget.weight)
        self.input = nn.Linear(dim, dim, bias=False)
        nn.init.normal_(self.input.weight, 0, 0.001)
        self.output = nn.Linear(dim, dim, bias=False)
        nn.init.normal_(self.output.weight, 0, 0.001)

    def forward(self, x, hidden=None):
        """Run the recurrence over every time step of ``x``.

        Args:
            x: input indexed as ``x[:, t, :]`` by ``forward_rnn``; presumably
                (batch, time, dim). The step count is taken from ``x.shape[1]``.
            hidden: optional initial state. Defaults to zeros of shape
                (batch, dim) with ``x``'s dtype and device.

        Returns:
            Per-step outputs stacked along dim -2, as produced by ``forward_rnn``.
        """
        if hidden is None:
            hidden = x.new_zeros(x.shape[0], self.forget.in_features)
        # forward_rnn needs an explicit initial state and step count. It is
        # looked up as a global at call time, so any module-level
        # torch.compile rebinding of it is picked up here.
        return forward_rnn(self.forget, self.input, self.output, hidden, x.shape[1], x)
# Use the first GPU when available; otherwise fall back to CPU so the script
# still runs (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# batch size, sequence length (steps unrolled per call), feature/hidden width
N, T, C = 8, 1024, 64

# Compile logs from an earlier run of this script, kept for reference.
# They show six separate compilations of forward_rnn (frame_key 1-6). That
# matches the six distinct RNN layers built below (1 + 2 + 3). The first five
# each took ~130-155 s for a graph of ~6k ops; the sixth logged no metrics.
# NOTE(review): the recompiles are presumably triggered by Dynamo guarding on
# the identity of the nn.Linear modules passed as arguments; confirm with
# TORCH_LOGS=recompiles.
# [2023-11-06 21:54:52,138] [0/0] torch._utils_internal: [INFO] CompilationMetrics(frame_key='1', co_name='forward_rnn', co_filename='/home/proger/lru/compile_rnns2.py', co_firstlineno=5, cache_size=0, accumulated_cache_size=0, guard_count=13, graph_op_count=6146, graph_node_count=6148, graph_input_count=1, entire_frame_compile_time_s=132.50377583503723, backend_compile_time_s=120.79096937179565, fail_reason=None, non_compliant_ops=set())
# [2023-11-06 21:57:04,859] [0/1] torch._utils_internal: [INFO] CompilationMetrics(frame_key='2', co_name='forward_rnn', co_filename='/home/proger/lru/compile_rnns2.py', co_firstlineno=5, cache_size=0, accumulated_cache_size=1, guard_count=13, graph_op_count=6146, graph_node_count=6148, graph_input_count=1, entire_frame_compile_time_s=132.62526607513428, backend_compile_time_s=120.74623084068298, fail_reason=None, non_compliant_ops=set())
# [2023-11-06 21:59:40,323] [0/2] torch._utils_internal: [INFO] CompilationMetrics(frame_key='3', co_name='forward_rnn', co_filename='/home/proger/lru/compile_rnns2.py', co_firstlineno=5, cache_size=0, accumulated_cache_size=2, guard_count=13, graph_op_count=6146, graph_node_count=6148, graph_input_count=1, entire_frame_compile_time_s=155.36020350456238, backend_compile_time_s=143.30837988853455, fail_reason=None, non_compliant_ops=set())
# [2023-11-06 22:01:53,890] [0/3] torch._utils_internal: [INFO] CompilationMetrics(frame_key='4', co_name='forward_rnn', co_filename='/home/proger/lru/compile_rnns2.py', co_firstlineno=5, cache_size=0, accumulated_cache_size=3, guard_count=13, graph_op_count=6146, graph_node_count=6148, graph_input_count=1, entire_frame_compile_time_s=133.46099162101746, backend_compile_time_s=121.40886402130127, fail_reason=None, non_compliant_ops=set())
# [2023-11-06 22:04:27,993] [0/4] torch._utils_internal: [INFO] CompilationMetrics(frame_key='5', co_name='forward_rnn', co_filename='/home/proger/lru/compile_rnns2.py', co_firstlineno=5, cache_size=0, accumulated_cache_size=4, guard_count=13, graph_op_count=6146, graph_node_count=6148, graph_input_count=1, entire_frame_compile_time_s=153.99701261520386, backend_compile_time_s=141.85603070259094, fail_reason=None, non_compliant_ops=set())
# [2023-11-06 22:05:24,167] [0/5] torch._utils_internal: [INFO] CompilationMetrics(frame_key='6', co_name='forward_rnn', co_filename='/home/proger/lru/compile_rnns2.py', co_firstlineno=5, cache_size=0, accumulated_cache_size=5, guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, fail_reason=None, non_compliant_ops=set())
# Rebind at module level (not inside the main guard) so RNN.forward, which
# looks up forward_rnn as a global, also uses the compiled version.
forward_rnn = torch.compile(forward_rnn)

if __name__ == '__main__':
    # Benchmark stacks of 1, 2 and 3 RNN layers.
    for num_layers in range(1, 4):
        rnns = nn.ModuleList([RNN(C) for _ in range(num_layers)]).to(device)
        # Allocate directly on the target device rather than building on CPU
        # and copying.
        dummy_x = torch.randn(N, T, C, device=device)
        # Zero initial state, reused by every layer; forward_rnn only rebinds
        # its local `hidden` and never modifies this tensor in place.
        hidden = dummy_x.new_zeros(N, C)
        x = dummy_x
        for rnn in rnns:
            x = forward_rnn(rnn.forget, rnn.input, rnn.output, hidden, T, x)
            print(rnn)
        print(x.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment