Created
August 16, 2018 15:56
-
-
Save CharlesJQuarra/40751a6301084db2bf35e35c0ccb3369 to your computer and use it in GitHub Desktop.
broadcastable version of `torch.cat`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
"""" | |
behavior: | |
fat_cat([torch.randn(1,7,20), torch.randn(5,1,13)], dim=-1).size() == torch.Size([5, 7, 33]) | |
"""" | |
def axis_repeat(t, dim, times):
    """Repeat tensor ``t`` ``times`` times along its singleton axis ``dim``.

    Args:
        t: tensor whose size along ``dim`` must be 1.
        dim: axis to repeat along; negative values count from the end.
        times: number of copies along ``dim``; 0 yields an empty axis.

    Returns:
        A new tensor (data is copied) with size ``times`` along ``dim`` and
        the same sizes as ``t`` on every other axis.

    Raises:
        ValueError: if ``t`` is not of size 1 along ``dim``.
        IndexError: if ``dim`` is out of range for ``t``.
    """
    if t.size(dim) != 1:
        raise ValueError("dimension {0} of tensor of shape {1} is non-singleton, cannot repeat".format(dim, t.size()))
    # Tensor.repeat takes one factor per axis; only ``dim`` is multiplied.
    # Unlike torch.cat over a list of copies, this also handles times == 0.
    repeats = [1] * t.dim()
    repeats[dim] = times
    return t.repeat(*repeats)
def _broadcast_sizes(tensors, cat_dim):
    """Return the common size of every axis, with ``None`` at ``cat_dim``.

    On each non-concatenation axis, size-1 tensors adopt the size of the
    first tensor that is not size 1 there. If every tensor is size 1, the
    common size is 1.

    Raises:
        ValueError: if two tensors have different non-singleton sizes on the
            same non-concatenation axis.
    """
    sizes = []
    for d in range(tensors[0].dim()):
        if d == cat_dim:
            # The concatenation axis keeps each tensor's own size.
            sizes.append(None)
            continue
        target = None
        owner = None  # index of the tensor that first fixed ``target``
        for i, t in enumerate(tensors):
            size = t.size(d)
            if size == 1:
                continue  # singleton axes broadcast to whatever the others use
            if target is None:
                target, owner = size, i
            elif size != target:
                raise ValueError(
                    "dimension {0} of {1}th tensor of shape {2} does not match "
                    "non-singleton dimension of {3}th tensor of shape {4}".format(
                        d, i, t.size(), owner, tensors[owner].size()))
        sizes.append(1 if target is None else target)
    return sizes


def fat_cat(tensor_list, dim=0):
    """Concatenate tensors along ``dim``, broadcasting every other axis.

    Behaves like ``torch.cat`` except that, on each axis other than ``dim``,
    a tensor of size 1 is expanded to the size shared by the other tensors.
    A size-1 axis may also broadcast against a size-0 axis.

    Args:
        tensor_list: iterable of tensors that all have the same number of
            dimensions. It is consumed exactly once, so generators work.
        dim: concatenation axis; negative values count from the end.

    Returns:
        The tensor produced by ``torch.cat`` on the broadcast operands.

    Raises:
        ValueError: if the tensors differ in number of dimensions, or two of
            them have different non-singleton sizes on a non-concatenation
            axis.
        IndexError: if ``dim`` is out of range for the tensors.
    """
    tensors = list(tensor_list)
    if not tensors or tensors[0].dim() == 0:
        # Nothing to broadcast: defer to torch.cat, which reports empty or
        # zero-dimensional input itself.
        return torch.cat(tensors, dim)

    nb_dims = tensors[0].dim()
    for i, t in enumerate(tensors):
        if t.dim() != nb_dims:
            raise ValueError("{0}th tensor of shape {1} has {2} dimensions, expected {3}".format(i, t.size(), t.dim(), nb_dims))
    if not -nb_dims <= dim < nb_dims:
        raise IndexError("dimension {0} out of range for {1}-dimensional tensors".format(dim, nb_dims))
    cat_dim = dim % nb_dims

    sizes = _broadcast_sizes(tensors, cat_dim)
    expanded = []
    for t in tensors:
        shape = list(sizes)
        shape[cat_dim] = t.size(cat_dim)
        # expand() creates a view without copying; torch.cat does the single copy.
        expanded.append(t.expand(shape))
    return torch.cat(expanded, dim)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment