# authors: jheuristic, qwicen
import torch
import torch.nn as nn
import threading
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class TrainableSparseMatrix(nn.Module):
    def __init__(self, indices, values, **kwargs):
        """
        Sparse matrix as a module with trainable nonzero values
        and a significantly more memory-efficient backward pass
        (see the usage sketch after this class).
        """
        super().__init__()
        with torch.no_grad():
            template = torch.sparse_coo_tensor(indices, values, requires_grad=False, **kwargs).coalesce()
        self.register_buffer('template', template)
        self.values = nn.Parameter(self.template.values().clone(), requires_grad=True)
        # poison the buffer's values so stale data is never read by accident
        self.template.values().fill_(float('nan'))
        # serializes writes into the shared template across concurrent calls
        self.lock = threading.Lock()

    def forward(self, other):
        return _TrainableSparseMatrixMul.apply(self.lock, self.template, self.values, other)

    def mm(self, other):
        return self(other)

    @property
    def shape(self):
        return self.template.shape

    @property
    def dtype(self):
        return self.template.dtype

    @property
    def indices(self):
        return self.template.indices()
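# Usage sketch (an illustrative assumption; the shapes, sizes, and names below
# are not from the original gist's tests): build a small trainable sparse
# matrix, multiply by a dense matrix, and backpropagate into the nonzero values.
def _usage_example():
    indices = torch.randint(0, 1000, (2, 5000), device=device)
    values = torch.randn(5000, device=device)
    matrix = TrainableSparseMatrix(indices, values, size=(1000, 1000))
    out = matrix.mm(torch.randn(1000, 16, device=device))  # same as matrix(...)
    out.sum().backward()  # gradients land in matrix.values.grad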
class _TrainableSparseMatrixMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, lock, template, values, other_matrix):
        with lock, torch.no_grad():
            # write the current trainable values into the shared template buffer
            template.values()[:] = values
            ctx.save_for_backward(values, other_matrix)
            ctx.__TrainableSparseMatrixMul_metadata = lock, template
            return torch.sparse.mm(template, other_matrix)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        values, other_matrix = ctx.saved_tensors
        lock, template = ctx.__TrainableSparseMatrixMul_metadata
        with lock, torch.no_grad():
            template.values()[:] = values
            # d(loss)/d(other_matrix) = template^T @ grad_output
            grad_other_matrix = torch.sparse.mm(template.t(), grad_output)
            # for each nonzero at (i, j): d(loss)/d(value) = <grad_output[i], other_matrix[j]>,
            # computed without ever materializing a dense (dim1 x dim2) gradient
            ii, jj = template.indices()
            grad_values = torch.sum(grad_output[ii] * other_matrix[jj], dim=-1)
            # Note: if you run out of memory, you can accumulate this sum in chunks
            # (see the sketch after this class)
        return None, None, grad_values, grad_other_matrix
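# A minimal sketch of the chunked accumulation mentioned in backward() above
# (an assumption, not part of the original gist): chunk_size is a hypothetical
# knob that trades a few extra kernel launches for bounded peak memory.
def _chunked_grad_values(grad_output, other_matrix, ii, jj, chunk_size=2 ** 18):
    grad_values = torch.empty(len(ii), dtype=grad_output.dtype, device=grad_output.device)
    for start in range(0, len(ii), chunk_size):
        chunk = slice(start, start + chunk_size)
        grad_values[chunk] = torch.sum(grad_output[ii[chunk]] * other_matrix[jj[chunk]], dim=-1)
    return grad_values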
if __name__ == '__main__':
    print('running tests: ', end='')
    for i in range(100):
        print('.', end='', flush=True)
        dim1, dim2, nnz = 10 ** 4, 10 ** 4, 10 ** 6
        values = torch.randn(nnz, device=device)  # TrainableSparseMatrix clones these into its own Parameter
        indices = torch.stack([torch.randint(0, dim1, [nnz]),
                               torch.randint(0, dim2, [nnz])], dim=0).to(device)
        our = TrainableSparseMatrix(indices, values, size=(dim1, dim2))

        for step in range(5):
            # reference: a plain autograd-tracked sparse tensor with the same indices and values
            ref_values = our.values.detach().requires_grad_(True)
            ref = torch.sparse_coo_tensor(our.template.indices().clone(), ref_values,
                                          size=(dim1, dim2), requires_grad=True)

            if our.values.grad is not None:
                our.values.grad.zero_()
            x = torch.randn(our.template.shape[1], 3, device=device, requires_grad=True)
            y = our.mm(x) + our.mm(x + 1)
            z = torch.rand_like(y)
            (y * z).sum().backward()

            if ref_values.grad is not None:
                ref_values.grad.zero_()
            x_clone = x.detach().requires_grad_(True)
            y = torch.sparse.mm(ref, x_clone) + torch.sparse.mm(ref, x_clone + 1)
            (y * z).sum().backward()

            assert torch.allclose(our.values.grad, ref_values.grad, atol=1e-5)
            assert torch.allclose(x.grad, x_clone.grad, atol=1e-5)
    print(" passed!")