Wrapper around a PyTorch sparse tensor that allows quick access to the values along the first dimension.
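
For illustration, a minimal usage sketch (a sketch only: it assumes the module below is saved as sparse_pytorch_tensor.py, the name the accompanying tests import, and it reuses the fixture data from those tests):

import torch as th

from sparse_pytorch_tensor import QuickAccessSparseTensor

# A 3 x 4 boolean sparse tensor: one entry in row 0, none in row 1, three in row 2.
sparse = th.sparse_coo_tensor(
    indices=[[0, 2, 2, 2], [0, 1, 2, 3]],
    values=th.ones(4, dtype=th.bool),
    size=(3, 4),
)
qa = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse)

# Dispatches through __torch_function__, so the plain torch API works on the wrapper.
rows = th.index_select(qa, 0, th.tensor([2, 0]))
print(rows.to_dense())
# tensor([[False,  True,  True,  True],
#         [ True, False, False, False]])

The wrapper module itself follows.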
import functools
from typing import Any, Optional, Sequence, Tuple, Union

import torch as th

HANDLED_FUNCTIONS = {}


# TODO: Properly type annotate.
def implements(torch_function: Any) -> Any:
    """Register a torch function override for QuickAccessSparseTensor."""

    # TODO: Properly type annotate.
    @functools.wraps(torch_function)
    def decorator(func: Any) -> Any:
        HANDLED_FUNCTIONS[torch_function] = func
        return func

    return decorator


@th.jit.script
def _index_select_torchscript(
    indices: th.Tensor, values: th.Tensor, index: th.Tensor, bounds: th.Tensor
) -> Tuple[th.Tensor, th.Tensor]:
    indices_list = []
    values_list = []
    for output_idx, input_idx in enumerate(index):
        start, end = bounds[input_idx], bounds[input_idx + 1]
        # Clone the slice: it is a view into `indices`, and writing the new row
        # index below must not mutate the source tensor's indices.
        indices_ = indices[:, start:end].clone()
        indices_[0, :] = output_idx
        indices_list.append(indices_)
        values_list.append(values[start:end])
    return (
        th.cat(indices_list, dim=1) if indices_list else th.empty((2, 0), dtype=th.long).to(indices),
        th.cat(values_list, dim=0) if values_list else th.empty((0,), dtype=values.dtype).to(values),
    )


class QuickAccessSparseTensor:
    """Wrapper around a sparse tensor that allows quick access to the values along the first dimension.

    ASSUMPTIONS:
    1. The indices are coalesced. If you create it from `torch.sparse_coo_tensor` and then only use
       operations implemented here, it will remain coalesced.
    2. Indexing is supported only along the first dimension.
    """

    def __init__(self, indices: th.LongTensor, values: th.Tensor, shape: th.Size):
        self.indices = indices
        self.values = values
        self.shape = shape
        # Check if the indices are sorted by the first dimension and sort them if not.
        if not self.indices[0, 1:].ge(self.indices[0, :-1]).all():
            self._sort_indices_by_the_first_dimension()
        self._bounds = self._calculate_bounds()

    def _sort_indices_by_the_first_dimension(self) -> None:
        sorting_by_first_dim = self.indices[0, :].argsort()
        self.indices = self.indices[:, sorting_by_first_dim]
        self.values = self.values[sorting_by_first_dim]

    def _calculate_bounds(self) -> th.LongTensor:
        # CSR-style row pointers: row i's entries live in values[bounds[i]:bounds[i + 1]].
        bounds = th.zeros(self.shape[0] + 1, dtype=th.long).to(self.indices)
        bounds[1:] = th.bincount(self.indices[0], minlength=self.shape[0])
        bounds.cumsum_(dim=0)
        return bounds

    def index_select(self, dim: int, index: th.LongTensor) -> "QuickAccessSparseTensor":
        assert dim == 0, "Only dim = 0 is supported"
        indices, values = _index_select_torchscript(self.indices, self.values, index, self._bounds)
        return QuickAccessSparseTensor(
            indices=indices,
            values=values,
            shape=th.Size([len(index), *self.shape[1:]]),
        )

    def size(self, dim: Optional[int] = None) -> Union[th.Size, int]:
        if dim is None:
            return self.shape
        return self.shape[dim]

    def to(self, *args: Any, **kwargs: Any) -> "QuickAccessSparseTensor":
        return QuickAccessSparseTensor(
            indices=self.indices.to(*args, **kwargs),
            values=self.values.to(*args, **kwargs),
            shape=self.shape,
        )

    def to_dense(self) -> th.Tensor:
        tensor = th.zeros(self.shape, dtype=self.values.dtype).to(self.values)
        tensor[tuple(self.indices)] = self.values
        return tensor

    def to_sparse(self) -> th.Tensor:
        return th.sparse_coo_tensor(self.indices, self.values, self.shape)

    @classmethod
    def from_sparse_coo_tensor(cls, sparse_tensor: th.Tensor) -> "QuickAccessSparseTensor":
        sparse_tensor = sparse_tensor.coalesce()
        return cls(
            sparse_tensor.indices(),
            sparse_tensor.values(),
            sparse_tensor.size(),
        )

    # TODO: Properly type annotate.
    @classmethod
    def __torch_function__(cls, func: Any, types: Any, args: Any = (), kwargs: Any = None) -> Any:
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(issubclass(t, QuickAccessSparseTensor) for t in types):
            # Return the NotImplemented singleton (not a NotImplementedError instance)
            # so PyTorch can fall back to other handlers.
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(th.cat)
def cat(
    tensors: Sequence[QuickAccessSparseTensor], dim: int = 0, *, out: Optional[th.Tensor] = None
) -> QuickAccessSparseTensor:
    assert out is None, "Output tensor is not supported."
    assert all(
        tensor.shape[:dim] == tensors[0].shape[:dim]
        and tensor.shape[dim + 1 :] == tensors[0].shape[dim + 1 :]  # noqa: E203
        for tensor in tensors
    ), "All tensors must have the same shape except for the dimension to concatenate."
    # Concatenate indices, shifting each tensor's indices at the concatenating
    # dimension into place in the new tensor (a running offset avoids re-summing
    # the preceding sizes on every iteration).
    indices_list = []
    offset = 0
    for tensor in tensors:
        indices_ = tensor.indices.clone()
        indices_[dim, :] += offset
        indices_list.append(indices_)
        offset += tensor.shape[dim]
    indices = th.cat(indices_list, dim=1)
    # Concatenate values.
    values = th.cat([tensor.values for tensor in tensors], dim=0)
    # Calculate the new shape.
    shape = list(tensors[0].shape)
    shape[dim] = sum(tensor.shape[dim] for tensor in tensors)
    shape = th.Size(shape)
    return QuickAccessSparseTensor(indices, values, shape)


@implements(th.index_select)
def index_select(tensor: QuickAccessSparseTensor, dim: int, index: th.LongTensor) -> QuickAccessSparseTensor:
    return tensor.index_select(dim, index)
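
To make the quick-access layout concrete, here is a small worked sketch (using the same fixture data as the tests below) of the _bounds array that _calculate_bounds builds. It acts as CSR-style row pointers: row i's entries occupy values[bounds[i]:bounds[i + 1]], so selecting a row is a constant-time slice instead of a scan over all indices.

import torch as th

row_indices = th.tensor([0, 2, 2, 2])  # first row of the fixture's indices
num_rows = 3

bounds = th.zeros(num_rows + 1, dtype=th.long)
bounds[1:] = th.bincount(row_indices, minlength=num_rows)  # per-row counts: [1, 0, 3]
bounds.cumsum_(dim=0)
print(bounds)  # tensor([0, 1, 1, 4])
# Row 2's entries sit in the slice [bounds[2], bounds[3]) == [1, 4).

The tests below compare index_select and th.cat on the wrapper against PyTorch's native sparse tensor operations.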
import pytest
import torch as th

from sparse_pytorch_tensor import QuickAccessSparseTensor


@pytest.fixture
def sparse_tensor() -> th.Tensor:
    return th.sparse_coo_tensor(
        indices=[[0, 2, 2, 2], [0, 1, 2, 3]],
        values=th.ones(4, dtype=th.bool),
        size=(3, 4),
        dtype=th.bool,
    )


@pytest.mark.parametrize("index", [th.tensor([0]), th.tensor([1]), th.tensor([2]), th.tensor([0, 1, 2])])
def test_quick_access_sparse_tensor_index_select(sparse_tensor: th.Tensor, index: th.LongTensor) -> None:
    # given
    qa_sparse_tensor = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse_tensor)

    # then
    assert th.equal(
        th.index_select(qa_sparse_tensor, 0, index).to_dense(), th.index_select(sparse_tensor, 0, index).to_dense()
    )


def test_quick_access_sparse_tensor_index_select_out_of_bounds(sparse_tensor: th.Tensor) -> None:
    # given
    qa_sparse_tensor = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse_tensor)

    # then
    with pytest.raises(RuntimeError):
        qa_sparse_tensor.index_select(0, th.tensor([3]))


@pytest.mark.parametrize("dim", [0, 1])
def test_quick_access_sparse_tensors_concatenation(sparse_tensor: th.Tensor, dim: int) -> None:
    # given
    qa_sparse_tensor = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse_tensor)

    # when
    qa_sparse_tensors_concat = th.cat([qa_sparse_tensor, qa_sparse_tensor, qa_sparse_tensor], dim=dim)
    sparse_tensors_concat = th.cat([sparse_tensor, sparse_tensor, sparse_tensor], dim=dim)

    # then
    assert qa_sparse_tensors_concat.shape == sparse_tensors_concat.size()
    assert th.equal(qa_sparse_tensors_concat.to_dense(), sparse_tensors_concat.to_dense())
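
Finally, a rough micro-benchmark sketch of the intended speed-up for repeated row lookups. The workload size, nnz, and repetition count below are illustrative assumptions, not measurements from the gist:

import timeit

import torch as th

from sparse_pytorch_tensor import QuickAccessSparseTensor

# Hypothetical workload: a large sparse matrix and repeated single-row selection.
n_rows, n_cols, nnz = 100_000, 1_000, 1_000_000
indices = th.stack([th.randint(n_rows, (nnz,)), th.randint(n_cols, (nnz,))])
values = th.rand(nnz)
sparse = th.sparse_coo_tensor(indices, values, (n_rows, n_cols)).coalesce()
qa = QuickAccessSparseTensor.from_sparse_coo_tensor(sparse)

index = th.tensor([42])
print("native sparse:", timeit.timeit(lambda: sparse.index_select(0, index), number=100))
print("quick access :", timeit.timeit(lambda: th.index_select(qa, 0, index), number=100))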