Created February 14, 2019 08:03
-
-
Save kaniblu/6b7b820da83e565d35c58acc44a9a94d to your computer and use it in GitHub Desktop.
Some sparse operators for 2D `torch.sparse` tensors.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.sparse as sp
def sparse_2d_densesum(x, dim=None):
    """Sum a 2D sparse COO tensor, returning a dense result.

    Args:
        x: 2D sparse COO tensor (coalesced or not).
        dim: dimension to reduce (0, 1, -1 or -2). ``None`` sums every
            stored value.

    Returns:
        A 0-dim tensor when ``dim`` is ``None``. Otherwise a dense 1D tensor
        whose length is the size of the other dimension.

    Raises:
        IndexError: if ``dim`` is not a valid dimension of a 2D tensor.
    """
    assert len(x.size()) == 2
    # indices()/values() refuse uncoalesced tensors; coalescing sums
    # duplicate coordinates, so the totals are unchanged.
    x = x.coalesce()
    values = x.values()
    if dim is None:
        return values.sum()
    if dim < 0:
        dim += 2
    if dim not in (0, 1):
        raise IndexError(f"dim out of range for a 2D tensor: {dim}")
    keep = 1 - dim  # the dimension that survives the reduction
    out = values.new_zeros(x.size(keep))
    return out.scatter_add(0, x.indices()[keep], values)
def sparse_new(x, *args):
    """Build a new sparse COO tensor with the same dtype and device as ``x``.

    Args:
        x: template sparse tensor; only its dtype and device are used.
        *args: either ``(indices, values[, size])``, which builds a tensor
            from the given entries, or a size (ints or a ``torch.Size``),
            which builds an empty sparse tensor. With no args, the result
            is an empty tensor of size ``(0,)``.

    Returns:
        A sparse COO tensor with ``x.dtype`` on ``x.device``.
    """
    # Using dtype/device instead of matching x.type() strings supports every
    # dtype and CUDA tensors, and avoids the deprecated torch.sparse.*Tensor
    # constructors.
    if args and torch.is_tensor(args[0]):
        return torch.sparse_coo_tensor(*args, dtype=x.dtype, device=x.device)
    size = args if args else (0,)
    return torch.zeros(*size, dtype=x.dtype, device=x.device,
                       layout=torch.sparse_coo)
def sparse_slice_index(x, dim, indices):
    """Keep only the stored entries of ``x`` whose coordinate along ``dim``
    appears in ``indices``.

    The result has the same size as ``x``; every other entry is dropped.

    Args:
        x: sparse COO tensor (coalesced or not).
        dim: sparse dimension to filter on.
        indices: tensor of coordinate values to keep. Duplicates are
            allowed.

    Returns:
        A sparse tensor (built via ``sparse_new``), not necessarily coalesced.
    """
    # indices()/values() refuse uncoalesced tensors.
    x = x.coalesce()
    x_idx, x_val, x_size = x.indices(), x.values(), x.size()
    # (len(indices), nnz) equality table. any() rather than sum() ensures a
    # duplicated entry in ``indices`` cannot select an entry twice.
    keep = (indices.unsqueeze(-1) == x_idx[dim].unsqueeze(0)).any(0)
    keep_pos = keep.nonzero().squeeze(1)
    new_idx = x_idx.index_select(1, keep_pos)
    new_val = x_val.index_select(0, keep_pos)
    return sparse_new(x, new_idx, new_val, x_size)
def sparse_slice_indices(x, *dim_indices):
    """Apply ``sparse_slice_index`` for several ``(dim, indices)`` pairs in turn.

    Args:
        x: sparse COO tensor.
        *dim_indices: flat argument list ``dim0, indices0, dim1, indices1, ...``.

    Returns:
        The filtered tensor, coalesced after every slice. When no pairs are
        given, ``x`` is returned unchanged.

    Raises:
        AssertionError: if an odd number of arguments is passed.
    """
    assert len(dim_indices) % 2 == 0, \
        f"number of dimension-index arguments must be even: {len(dim_indices)}"
    # Even positions are dims and odd positions the matching index tensors.
    # Pairing by slicing handles the zero-pair case; ``zip(*[])`` unpacking
    # would raise ValueError there.
    for dim, idx in zip(dim_indices[0::2], dim_indices[1::2]):
        x = sparse_slice_index(x, dim, idx).coalesce()
    return x
def sparse_2d_getitem(x, a, b):
    """Return the element ``x[a, b]`` of a 2D sparse COO tensor.

    Args:
        x: sparse COO tensor with two sparse dimensions (coalesced or not).
        a: row coordinate.
        b: column coordinate.

    Returns:
        The stored value at ``(a, b)``, or a zero of the same dtype/shape
        when nothing is stored there.
    """
    # Coalescing sums duplicate coordinates, so the single match found below
    # is the full value; indices()/values() also refuse uncoalesced tensors.
    x = x.coalesce()
    x_idx, x_val = x.indices(), x.values()
    assert len(x_idx) == 2
    hits = ((x_idx[0] == a) & (x_idx[1] == b)).nonzero()
    if hits.size(0) < 1:
        # values.shape[1:] is () for a plain 2D tensor, so this is a 0-dim zero.
        return x_val.new_zeros(x_val.shape[1:])
    return x_val[hits[0, 0]]
def sparse_2d_denseselect(x, dim, idx):
    """Gather one element per position of the other dimension of a 2D sparse
    tensor, returning a dense vector.

    With ``sdim = 1 - dim``, the result is ``out[j] = x[idx[j], j]`` for
    ``dim == 0`` and ``out[i] = x[i, idx[i]]`` for ``dim == 1``. Unstored
    positions give zero.

    Args:
        x: sparse COO tensor with two sparse dimensions (coalesced or not).
        dim: 0 or 1, the dimension ``idx`` indexes into.
        idx: tensor of length ``x.size(1 - dim)`` with one coordinate along
            ``dim`` per position of the other dimension.

    Returns:
        A dense 1D tensor of length ``x.size(1 - dim)`` with ``x``'s dtype.
    """
    # indices()/values() refuse uncoalesced tensors.
    x = x.coalesce()
    x_idx, x_val = x.indices(), x.values()
    sdim = 1 - dim
    assert len(x_idx) == 2 and dim in {0, 1} and len(idx) == x.size(sdim)
    pos = x_idx[sdim]
    # An entry is selected iff its coordinate along ``dim`` equals the index
    # requested for its position along ``sdim``. This is O(nnz), with no
    # (size, nnz) table.
    keep = x_idx[dim] == idx[pos]
    out = x_val.new_zeros(x.size(sdim))
    # Filtering instead of multiplying by a float mask keeps the source dtype
    # equal to ``out``'s dtype, which scatter_add requires.
    return out.scatter_add(0, pos[keep], x_val[keep])
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment