Skip to content

Instantly share code, notes, and snippets.

@kaniblu
Created February 14, 2019 08:03
Show Gist options
  • Save kaniblu/6b7b820da83e565d35c58acc44a9a94d to your computer and use it in GitHub Desktop.
Some sparse operators for 2D `torch.sparse` Tensors
import torch
import torch.sparse as sp
def sparse_2d_densesum(x, dim=None):
    """Sum a 2D sparse tensor, producing a dense result.

    Args:
        x: a 2D ``torch.sparse`` tensor; assumed coalesced, since
            ``x.values()``/``x.indices()`` are read directly — TODO confirm
            callers always pass coalesced tensors.
        dim: ``None`` for the scalar total of all stored values, or 0/1 to
            reduce along that dimension.

    Returns:
        A 0-dim tensor when ``dim is None``; otherwise a dense 1D tensor of
        length ``x.size(1 - dim)``.
    """
    # NOTE: the original line was missing the trailing ':' on the def —
    # fixed here; logic is otherwise unchanged.
    assert len(x.size()) == 2
    if dim is None:
        # Unstored entries are zero, so summing the stored values suffices.
        return x.values().sum()
    values = x.values()
    # Accumulate each stored value into its coordinate along the kept
    # dimension (1 - dim), starting from a zeroed dense vector.
    return values.new(x.size(1 - dim)).zero_() \
        .scatter_add(0, x.indices()[1 - dim], values)
def sparse_new(x, *args):
    """Build a new sparse tensor of the same type as ``x``.

    ``args`` are forwarded verbatim to the matching legacy
    ``torch.sparse.*Tensor`` constructor (typically indices, values, size).
    Raises ``TypeError`` for any type other than Float/Long/Byte sparse
    CPU tensors.
    """
    constructors = {
        "torch.sparse.FloatTensor": sp.FloatTensor,
        "torch.sparse.LongTensor": sp.LongTensor,
        "torch.sparse.ByteTensor": sp.ByteTensor,
    }
    torch_type = x.type()
    make = constructors.get(torch_type)
    if make is None:
        raise TypeError(f"unrecognized torch tensor type: {torch_type}")
    return make(*args)
def sparse_slice_index(x, dim, indices):
    """Keep only the stored entries of ``x`` whose coordinate along ``dim``
    occurs in ``indices``.

    Returns an uncoalesced sparse tensor of the same shape and type as
    ``x`` (callers are expected to ``.coalesce()`` if needed). Assumes
    ``x`` is coalesced — ``indices()``/``values()`` are read directly.
    """
    idx = x.indices()
    val = x.values()
    size = x.size()
    # hits[j]: how many of the requested indices entry j's coordinate
    # along `dim` matches.
    hits = (indices.unsqueeze(-1) == idx[dim].unsqueeze(0)).sum(0)
    # Descending sort moves the matching entries to the front; the total
    # match count tells how many to keep.
    _, order = hits.sort(0, True)
    keep = order[:hits.sum().item()]
    kept_idx = idx.index_select(1, keep)
    kept_val = val.index_select(0, keep)
    return sparse_new(x, kept_idx, kept_val, size)
def sparse_slice_indices(x, *dim_indices):
    """Apply ``sparse_slice_index`` for each (dim, indices) pair in turn.

    Args:
        x: a coalesced sparse tensor.
        *dim_indices: a flat sequence ``dim0, idx0, dim1, idx1, ...``.

    Returns:
        ``x`` filtered and coalesced once per pair; ``x`` itself when no
        pairs are given.

    Raises:
        AssertionError: if an odd number of pair arguments is passed.
    """
    assert len(dim_indices) % 2 == 0, \
        f"number of dimension-index arguments must be even: {len(dim_indices)}"
    # Fix: the original `dims, indices = zip(*[...])` raised a confusing
    # ValueError for zero pairs; slicing handles the empty case cleanly
    # and returns x unchanged.
    for dim, idx in zip(dim_indices[0::2], dim_indices[1::2]):
        x = sparse_slice_index(x, dim, idx).coalesce()
    return x
def sparse_2d_getitem(x, a, b):
    """Read the single entry ``(a, b)`` of a coalesced 2D sparse tensor.

    Returns the stored value when present, otherwise a zero scalar of the
    same dtype as the values.
    """
    coords, vals = x.indices(), x.values()
    assert len(coords) == 2
    # Positions of stored entries located exactly at (a, b).
    hit = ((coords[0] == a) & (coords[1] == b)).nonzero()
    if hit.size(0) < 1:
        # Absent entry: materialize a zero of the value dtype.
        return vals.new(1).zero_()[0]
    return vals[hit[0, 0]]
def sparse_2d_denseselect(x, dim, idx):
    """Per-position gather from a 2D sparse tensor into a dense vector.

    For every position ``p`` along the non-selected dimension
    ``sdim = 1 - dim``, the result's entry ``p`` accumulates the stored
    value whose coordinate along ``dim`` equals ``idx[p]`` (zero when no
    such entry exists).

    Args:
        x: coalesced 2D sparse tensor — ``indices()``/``values()`` are
            read directly.
        dim: dimension ``idx`` indexes into (0 or 1).
        idx: 1D long tensor of length ``x.size(1 - dim)``.

    Returns:
        Dense 1D tensor of length ``x.size(1 - dim)``.
    """
    x_idx, x_val = x.indices(), x.values()
    sdim = 1 - dim
    assert len(x_idx) == 2 and dim in {0, 1} and len(idx) == x.size(sdim)
    # Renamed from `range`, which shadowed the builtin; also requires a
    # file-level `import torch` (the original only imported torch.sparse).
    positions = torch.arange(x.size(sdim))
    # mask[j] is True iff stored entry j lies at (p, idx[p]) for some p.
    mask = ((x_idx[sdim] == positions.unsqueeze(-1)) &
            (x_idx[dim] == idx.unsqueeze(-1))).sum(0) > 0
    # Zero out unselected values, then accumulate per position along sdim.
    return x_val.new(x.size(sdim)).zero_()\
        .scatter_add(0, x_idx[sdim], mask.float() * x_val)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment