# Source: GitHub gist alex-petrenko/06f12dfe2b590fa5c776a7573f340f3d
# Author: @alex-petrenko. Created August 9, 2020 10:29.
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, invert_permutation
def _build_pack_info_from_dones(
    dones: torch.Tensor, T: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create the indexing info needed to make the PackedSequence based on the dones.

    PackedSequences are PyTorch's way of supporting a single RNN forward
    call where each input in the batch can have an arbitrary sequence length.

    They work as follows: given the sequences [c], [x, y, z], [a, b],
    we generate data [x, a, c, y, b, z] and batch_sizes [3, 2, 1]. The
    data is a flattened out version of the input sequences (the ordering in
    data is determined by sequence length). batch_sizes tells you, for each
    index, how many sequences have a length of (index + 1) or greater.

    This method will generate the new index ordering such that you can
    construct the data for a PackedSequence from a (N*T, ...) tensor
    via x.index_select(0, select_inds).

    :param dones: flat 1-D tensor of length N*T; a nonzero entry marks the last
        step of an episode. Rollouts are laid out back to back, T steps each.
    :param T: the length of a single rollout
    :return: tuple(rollout_starts_orig, is_new_episode, select_inds, batch_sizes, sorted_indices)
        WHERE
        rollout_starts_orig are the flat start indices of every sequence, in original order
        is_new_episode is a flat tensor (same dtype as dones) that is nonzero at steps
            directly following a done within the same rollout
        select_inds are the int64 indices that reorder the flat data into packed order
        batch_sizes is the CPU int64 tensor of per-timestep batch sizes
        sorted_indices is the permutation that sorts the sequences by descending length
    """
    num_samples = len(dones)

    rollout_boundaries = dones.clone().detach()
    rollout_boundaries[T - 1 :: T] = 1  # end of each rollout is the boundary

    # Squeeze only the trailing dimension of the (K, 1) nonzero() result. A bare
    # squeeze() would also drop the K dimension when K == 1 (a single rollout
    # without any dones), leaving a 0-dim tensor that cannot be sliced below.
    rollout_boundaries = rollout_boundaries.nonzero().squeeze(1) + 1

    # Length of every sequence: distance between consecutive boundaries, with the
    # first sequence starting at index 0.
    rollout_lengths = rollout_boundaries[1:] - rollout_boundaries[:-1]
    first_len = rollout_boundaries[0]
    rollout_lengths = torch.cat([first_len.unsqueeze(0), rollout_lengths])

    rollout_starts_orig = rollout_boundaries - rollout_lengths

    # done=True for the last step in the episode, so done flags rolled 1 step to the right will indicate
    # first frames in the episodes
    is_new_episode = dones.clone().detach().view((-1, T))
    is_new_episode = is_new_episode.roll(1, 1)

    # roll() is cyclical, so done=True in the last position in the rollout will roll to 0th position
    # we want to avoid it here.
    is_new_episode[:, 0] = 0
    is_new_episode = is_new_episode.view((-1,))

    lengths, sorted_indices = torch.sort(rollout_lengths, descending=True)

    # We will want these on the CPU for torch.unique_consecutive,
    # so move now.
    cpu_lengths = lengths.to(device="cpu", non_blocking=True)

    # We need to keep the original unpermuted rollout_starts, because the permutation is later applied
    # internally in the RNN implementation.
    # From modules/rnn.py:
    #       Each batch of the hidden state should match the input sequence that
    #       the user believes he/she is passing in.
    #       hx = self.permute_hidden(hx, sorted_indices)
    rollout_starts_sorted = rollout_starts_orig.index_select(0, sorted_indices)

    select_inds = torch.empty(num_samples, device=dones.device, dtype=torch.int64)

    max_length = int(cpu_lengths[0].item())
    # batch_sizes is *always* on the CPU
    batch_sizes = torch.empty((max_length,), device="cpu", dtype=torch.int64)

    offset = 0
    prev_len = 0
    num_valid_for_length = lengths.size(0)

    unique_lengths = torch.unique_consecutive(cpu_lengths)
    # Iterate over all unique lengths in reverse as they are sorted
    # in decreasing order
    for i in range(len(unique_lengths) - 1, -1, -1):
        # Sequences are sorted by descending length, so those longer than
        # prev_len always form a prefix of `lengths`.
        valids = lengths[0:num_valid_for_length] > prev_len
        num_valid_for_length = int(valids.sum().item())

        next_len = int(unique_lengths[i])

        batch_sizes[prev_len:next_len] = num_valid_for_length

        # Rows are timesteps prev_len..next_len-1, columns are the still-active
        # sequences; flattening row-major yields the packed (time-major) order.
        new_inds = (
            rollout_starts_sorted[0:num_valid_for_length].view(1, num_valid_for_length)
            + torch.arange(prev_len, next_len, device=rollout_starts_sorted.device).view(next_len - prev_len, 1)
        ).view(-1)

        # for a set of sequences [1, 2, 3], [4, 5], [6, 7], [8]
        # these indices will be 1,4,6,8,2,5,7,3
        # (all first steps in all trajectories, then all second steps, etc.)
        select_inds[offset : offset + new_inds.numel()] = new_inds

        offset += new_inds.numel()

        prev_len = next_len

    # Make sure we have an index for all elements
    assert offset == num_samples
    assert is_new_episode.shape[0] == num_samples

    return rollout_starts_orig, is_new_episode, select_inds, batch_sizes, sorted_indices
def build_rnn_inputs(x, dones, dones_cpu, rnn_states, T: int):
    r"""Create a PackedSequence input for an RNN such that each
    set of steps that are part of the same episode are all part of
    a batch in the PackedSequence.

    Use the returned inverted_select_inds and build_core_out_from_seq to invert this.

    :param x: A (N*T, -1) tensor of the data to build the PackedSequence out of
    :param dones: A (N*T) tensor where dones[i] == 1.0 indicates an episode is done
        (not read by this function; only dones_cpu is used)
    :param dones_cpu: same but a CPU-bound tensor
    :param rnn_states: A (N*T, -1) tensor of the rnn_hidden_states
    :param T: The length of the rollout
    :return: tuple(x_seq, rnn_states, inverted_select_inds)
        WHERE
        x_seq is the PackedSequence version of x to pass to the RNN
        rnn_states are the corresponding rnn state
        inverted_select_inds can be passed to build_core_out_from_seq so the RNN output can be retrieved
    """
    (
        rollout_starts,
        is_new_episode,
        select_inds,
        batch_sizes,
        sorted_indices,
    ) = _build_pack_info_from_dones(dones_cpu, T)

    # Invert before moving so the inversion runs on the device the indices were built on.
    inverted_select_inds = invert_permutation(select_inds)

    # Everything used to index x (or a mask derived from the dones) must live on
    # x's device; batch_sizes is the exception and stays on the CPU.
    device = x.device
    select_inds = select_inds.to(device=device)
    inverted_select_inds = inverted_select_inds.to(device=device)
    sorted_indices = sorted_indices.to(device=device)
    rollout_starts = rollout_starts.to(device=device)
    # is_new_episode is created on the device of dones_cpu; without this move the
    # index_select below would mix devices whenever x is not on the CPU.
    is_new_episode = is_new_episode.to(device=device)

    x_seq = PackedSequence(x.index_select(0, select_inds), batch_sizes, sorted_indices)

    # We zero-out rnn states for timesteps at the beginning of the episode.
    # rollout_starts are indices of all starts of sequences
    # (which can be due to episode boundary or just boundary of a rollout)
    # (1 - is_new_episode.view(-1, 1)).index_select(0, rollout_starts) gives us a zero for every beginning of
    # the sequence that is actually also a start of a new episode, and by multiplying this RNN state by zero
    # we ensure no information transfer across episode boundaries.
    # NOTE(review): assumes rnn_states is already on x.device - confirm against callers.
    keep_state = (1 - is_new_episode.view(-1, 1)).index_select(0, rollout_starts)
    rnn_states = rnn_states.index_select(0, rollout_starts) * keep_state

    return x_seq, rnn_states, inverted_select_inds
def build_core_out_from_seq(x_seq: PackedSequence, inverted_select_inds):
    """Undo the packing order of an RNN output.

    :param x_seq: PackedSequence whose ``data`` rows are in packed order
    :param inverted_select_inds: index tensor giving, for each output row, the row of
        ``x_seq.data`` to take (the inverse of the permutation used for packing)
    :return: tensor with the rows of ``x_seq.data`` gathered along dim 0
    """
    packed_rows = x_seq.data
    return torch.index_select(packed_rows, 0, inverted_select_inds)
# Sanity check: run a GRU over the same data twice, once through the packed-sequence
# path and once step by step with explicit state resets, and print the norm of
# the difference between the two outputs.
T = 97  # rollout length
N = 64  # number of rollouts
D = 128  # input size and hidden size of the GRU

rnn = nn.GRU(D, D, 1)
total_frames = 0

for _ in range(100):
    rnn_hidden_states_random = torch.rand(T * N, D)

    # Episode ends at flat indices 1, 8, 15, ...
    dones = torch.zeros((N * T,))
    dones[1::7] = 1.0

    x = torch.randn(T * N, D)

    # Packed path: one RNN call over variable-length sequences.
    packed_in, packed_states, unpack_inds = build_rnn_inputs(
        x, dones, dones, rnn_hidden_states_random.clone().detach(), T
    )
    packed_out, _ = rnn(packed_in, packed_states.unsqueeze(0))
    new_out = build_core_out_from_seq(packed_out, unpack_inds)

    # Reference path: one timestep at a time, starting from the state stored at
    # the first step of every rollout and zeroing the state after each done.
    step_state = rnn_hidden_states_random.clone().detach()[::T].unsqueeze(0)
    per_step_outputs = []
    for t in range(T):
        step_out, step_state = rnn(x[t::T].view(1, N, -1), step_state)
        per_step_outputs.append(step_out.view(N, -1))
        not_done = 1 - dones[t::T].view(1, N, 1)
        step_state = step_state * not_done

    old_outputs = torch.stack(per_step_outputs, dim=1).view(N * T, -1)

    print(torch.norm(new_out - old_outputs))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment