# authors: jheuristic, qwicen
import threading

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
class TrainableSparseMatrix(nn.Module):
    def __init__(self, indices, values, **kwargs):
        """
        A sparse matrix as a module with trainable nonzero values
        and a significantly more memory-efficient backward pass.
        """
        super().__init__()
        with torch.no_grad():
            template = torch.sparse_coo_tensor(indices, values, requires_grad=False, **kwargs).coalesce()
        self.register_buffer('template', template)
        # keep a trainable copy of the nonzero values; the template's own values
        # are poisoned with NaN so any accidental use of them is immediately visible
        self.values = nn.Parameter(self.template.values().clone(), requires_grad=True)
        self.template.values().fill_(float('nan'))
        self.lock = threading.Lock()  # guards in-place writes into the template
    def forward(self, other):
        return _TrainableSparseMatrixMul.apply(self.lock, self.template, self.values, other)

    def mm(self, other):
        return self(other)

    @property
    def shape(self):
        return self.template.shape

    @property
    def dtype(self):
        return self.template.dtype

    @property
    def indices(self):
        return self.template.indices()
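

# A hedged usage sketch (not part of the original gist): the nonzero values
# behave like any other nn.Parameter, so they can be trained with a standard
# optimizer. The squared-sum loss below is a placeholder chosen only for
# illustration, and the helper name is hypothetical.
def _example_training_step(sparse_matrix, x, lr=1e-2):
    opt = torch.optim.SGD(sparse_matrix.parameters(), lr=lr)
    loss = sparse_matrix.mm(x).pow(2).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    return loss.item()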
class _TrainableSparseMatrixMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, lock, template, values, other_matrix):
        with lock, torch.no_grad():
            # temporarily write the trainable values into the shared template buffer
            template.values()[:] = values
            ctx.save_for_backward(values, other_matrix)
            ctx._metadata = lock, template
            return torch.sparse.mm(template, other_matrix)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        values, other_matrix = ctx.saved_tensors
        lock, template = ctx._metadata
        with lock, torch.no_grad():
            template.values()[:] = values
            # d(loss)/d(other_matrix) = A^T @ d(loss)/d(output)
            grad_other_matrix = torch.sparse.mm(template.t(), grad_output)
            # for a nonzero at position (i, j), d(loss)/d(value) is the dot product
            # <grad_output[i], other_matrix[j]>; gather both rows for every nonzero
            ii, jj = template.indices()
            grad_values = torch.sum(grad_output[ii] * other_matrix[jj], dim=-1)
            # note: if this runs out of memory, the sum can be accumulated in chunks
        return None, None, grad_values, grad_other_matrix
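

# A minimal sketch of the chunked accumulation mentioned in the note above.
# This helper is an assumption (not part of the original gist) and `chunk_size`
# is a hypothetical parameter: each chunk materializes at most (chunk_size, k)
# intermediate values instead of (nnz, k) at once.
def _grad_values_in_chunks(grad_output, other_matrix, ii, jj, chunk_size=2 ** 20):
    grad_values = torch.empty(len(ii), dtype=grad_output.dtype, device=grad_output.device)
    for start in range(0, len(ii), chunk_size):
        chunk = slice(start, start + chunk_size)
        grad_values[chunk] = torch.sum(grad_output[ii[chunk]] * other_matrix[jj[chunk]], dim=-1)
    return grad_values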
if __name__ == '__main__':
    print(end='running tests: ')
    for trial in range(100):
        print(end='.')
        dim1, dim2, nnz = 10 ** 4, 10 ** 4, 10 ** 6
        values = torch.randn(nnz, device=device)
        indices = torch.stack([torch.randint(0, dim1, [nnz], device=device),
                               torch.randint(0, dim2, [nnz], device=device)], dim=0)
        our = TrainableSparseMatrix(indices, values, size=(dim1, dim2))

        for step in range(5):
            # reference: a plain sparse tensor differentiated through torch.sparse.mm
            ref_values = our.values.detach().requires_grad_(True)
            ref = torch.sparse_coo_tensor(our.template.indices().clone(), ref_values,
                                          size=(dim1, dim2), requires_grad=True)

            if our.values.grad is not None:
                our.values.grad.zero_()
            x = torch.randn(our.template.shape[1], 3, device=device, requires_grad=True)
            y = our.mm(x) + our.mm(x + 1)
            z = torch.rand_like(y)
            (y * z).sum().backward()

            if ref_values.grad is not None:
                ref_values.grad.zero_()
            x_clone = x.detach().requires_grad_(True)
            y = torch.sparse.mm(ref, x_clone) + torch.sparse.mm(ref, x_clone + 1)
            (y * z).sum().backward()

            assert torch.allclose(our.values.grad, ref_values.grad, atol=1e-5)
            assert torch.allclose(x.grad, x_clone.grad, atol=1e-5)
    print(" passed!")