# Gist lebedov/e8f932e3f6bc129adcfcae43d5229d8b, created February 23, 2017 15:10:
# Algorithm for concatenating half precision PyTorch tensors.
#!/usr/bin/env python | |
""" | |
Algorithm for concatenating half precision tensors by allocating new output matrix | |
of appropriate size and copying each of the constituent tensors into it with | |
appropriate offsets. | |
""" | |
import numpy as np | |
import torch | |
from torch.autograd import Variable | |
from torch.cuda import HalfTensor | |
def cat_half(inputs, dimension=0):
    """
    Concatenate half precision tensors along specified dimension.

    A new output tensor of the appropriate size is allocated and each input
    is copied into it at the appropriate offset along `dimension`.

    Parameters
    ----------
    inputs : iterable of torch.cuda.HalfTensor, or iterable of Variable
        Tensors to concatenate. Either every element is a ``HalfTensor`` or
        every element is a ``Variable`` whose ``.data`` is a ``HalfTensor``.
        The iterable is materialized into a list once, so generators work.
    dimension : int
        Dimension along which to concatenate. Negative values count from the
        last dimension.

    Returns
    -------
    out : torch.cuda.HalfTensor or Variable
        Newly allocated tensor holding the inputs laid end to end along
        `dimension`. It is wrapped in a ``Variable`` when the inputs were
        Variables.

    Raises
    ------
    ValueError
        If `inputs` is empty, the inputs have different numbers of
        dimensions, or their sizes differ along any dimension other than
        `dimension`.
    TypeError
        If the inputs are not all HalfTensors or not all Variables wrapping
        HalfTensors.
    IndexError
        If `dimension` is out of range for the inputs.
    """
    inputs = list(inputs)
    if not inputs:
        raise ValueError('cannot concatenate an empty sequence of tensors')

    # Validate input types. Raise rather than assert so the check still runs
    # under `python -O`. If the inputs are Variables, the output is one too.
    if all(isinstance(x, HalfTensor) for x in inputs):
        out_variable = False
        tensors = inputs
    elif all(isinstance(x, Variable) and isinstance(x.data, HalfTensor)
             for x in inputs):
        out_variable = True
        tensors = [x.data for x in inputs]
    else:
        raise TypeError('inputs must all be torch.cuda.HalfTensor instances '
                        'or all be Variables wrapping them')

    # Sizes of every input as plain Python ints:
    sizes = [[int(n) for n in t.size()] for t in tensors]
    ndim = len(sizes[0])
    if any(len(s) != ndim for s in sizes):
        raise ValueError('cannot concatenate tensors with different numbers '
                         'of dimensions')
    if not -ndim <= dimension < ndim:
        raise IndexError('dimension %d out of range for %d-dimensional '
                         'tensors' % (dimension, ndim))
    if dimension < 0:
        dimension += ndim

    # Every dimension except the concatenation one must match across inputs:
    for i in range(ndim):
        if i != dimension and len(set(s[i] for s in sizes)) > 1:
            raise ValueError('cannot concatenate')

    # The output's concatenation dimension is the sum of the inputs' sizes
    # along it. Build a fresh list so that `sizes` is not mutated.
    out_size = list(sizes[0])
    out_size[dimension] = sum(s[dimension] for s in sizes)
    # Allocate via .new() so the output has the inputs' type and device.
    # HalfTensor(*size) would allocate on the current CUDA device instead.
    out = tensors[0].new(*out_size)

    # Copy each input into its slice of the output:
    offset = 0
    for t, s in zip(tensors, sizes):
        length = s[dimension]
        index = [slice(None)] * ndim
        index[dimension] = slice(offset, offset + length)
        out[tuple(index)] = t
        offset += length

    return Variable(out) if out_variable else out
if __name__ == '__main__':
    # Demo / smoke test. Concatenate half precision copies of random tensors
    # with cat_half() and check each result against torch.cat() of the same
    # half values converted to float on the CPU. Concatenation only copies
    # values, so the two must match exactly.
    if not torch.cuda.is_available():
        print('CUDA is not available; skipping cat_half() demo')
    else:
        # (input shapes, concatenation dimension) pairs to exercise:
        cases = [
            ([(3, 2), (2, 2)], 0),
            ([(3, 1, 2), (3, 2, 2)], 1),
        ]
        for shapes, dimension in cases:
            halves = [torch.rand(*shape).cuda(0).half() for shape in shapes]
            expected = torch.cat([h.float().cpu() for h in halves], dimension)

            # Plain HalfTensor inputs:
            result = cat_half(halves, dimension)
            assert result.float().cpu().equal(expected), \
                'cat_half mismatch for shapes %s' % (shapes,)

            # Variable-wrapped inputs should give the same data:
            result_var = cat_half([Variable(h) for h in halves], dimension)
            assert result_var.data.float().cpu().equal(expected), \
                'cat_half mismatch for Variable inputs with shapes %s' % (shapes,)

            print('cat_half OK for shapes %s along dimension %d'
                  % (shapes, dimension))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment