Created
May 15, 2021 15:20
-
-
Save yzhangcs/2e1be43864e3937b824c6a8c9151ab53 to your computer and use it in GitHub Desktop.
Semirings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
from functools import reduce | |
import torch | |
import torch.autograd as autograd | |
from supar.utils.common import MIN | |
from torch.autograd import Function | |
class Semiring(object):
    r"""
    Base semiring ``<K, +, ×, 0, 1>`` :cite:`goodman-1999-semiring`, instantiated here as the
    ordinary (real-valued) sum-product semiring.

    ``K`` is a set of values; ``+`` is commutative and associative with identity ``0``;
    ``×`` is associative with identity ``1`` and distributes over ``+``.
    Every operation is a classmethod: the class itself is passed around, not an instance.
    Subclasses obtain other semirings by overriding ``zero``/``one`` and these classmethods.
    """

    # Identity element of the semiring sum.
    zero = 0
    # Identity element of the semiring product.
    one = 1

    @classmethod
    def sum(cls, x, dim=-1):
        """Semiring-sum ``x`` along ``dim`` (ordinary summation here)."""
        return x.sum(dim)

    @classmethod
    def mul(cls, x, y):
        """Semiring-multiply ``x`` and ``y`` elementwise (ordinary product here)."""
        return x * y

    @classmethod
    def dot(cls, x, y, dim=-1):
        """Semiring inner product: ``mul`` then ``sum`` along ``dim``."""
        products = cls.mul(x, y)
        return cls.sum(products, dim)

    @classmethod
    def prod(cls, x, dim=-1):
        """Semiring-multiply all entries of ``x`` along ``dim``."""
        return x.prod(dim)

    @classmethod
    def times(cls, *x):
        """Fold ``mul`` left-to-right over all positional arguments."""
        return reduce(cls.mul, x)

    @classmethod
    def zero_(cls, x):
        """Fill ``x`` in place with ``cls.zero`` and return it."""
        return x.fill_(cls.zero)

    @classmethod
    def one_(cls, x):
        """Fill ``x`` in place with ``cls.one`` and return it."""
        return x.fill_(cls.one)

    @classmethod
    def zero_mask(cls, x, mask):
        """Return a copy of ``x`` with the positions selected by ``mask`` set to ``cls.zero``."""
        return x.masked_fill(mask, cls.zero)

    @classmethod
    def zero_mask_(cls, x, mask):
        """In-place variant of :meth:`zero_mask`; returns ``x``."""
        return x.masked_fill_(mask, cls.zero)

    @classmethod
    def one_mask(cls, x, mask):
        """Return a copy of ``x`` with the positions selected by ``mask`` set to ``cls.one``."""
        return x.masked_fill(mask, cls.one)

    @classmethod
    def one_mask_(cls, x, mask):
        """In-place variant of :meth:`one_mask`; returns ``x``."""
        return x.masked_fill_(mask, cls.one)

    @classmethod
    def convert(cls, x):
        """Map a raw value into this semiring's representation (identity here)."""
        return x

    @classmethod
    def unconvert(cls, x):
        """Map a semiring value back to the raw result (identity here)."""
        return x
class LogSemiring(Semiring):
    r"""
    Log semiring: values are log-scores, ``+`` is ``logsumexp`` and ``×`` is addition.

    ``zero`` is ``MIN`` (imported from ``supar.utils.common``; presumably a very negative
    stand-in for ``-inf`` — confirm there) and ``one`` is ``0``.
    """

    zero = MIN
    one = 0

    @classmethod
    def sum(cls, x, dim=-1):
        """Log-space addition: ``logsumexp`` of ``x`` along ``dim``."""
        return x.logsumexp(dim)

    @classmethod
    def mul(cls, x, y):
        """Log-space multiplication is plain addition of the log-scores."""
        return x + y

    @classmethod
    def prod(cls, x, dim=-1):
        """Log-space product along ``dim``: plain summation of the log-scores."""
        return x.sum(dim)
class MaxSemiring(LogSemiring):
    r"""
    Max semiring over log-scores: ``+`` takes the maximum (values only, no argmax indices);
    ``×``, ``zero`` and ``one`` are inherited from :class:`LogSemiring`.
    """

    @classmethod
    def sum(cls, x, dim=-1):
        """Return the maximum of ``x`` along ``dim`` (the indices are discarded)."""
        best, _ = x.max(dim)
        return best
def KMaxSemiring(k):
    r"""
    Build a log semiring that keeps the ``k`` largest scores instead of a single one.

    Converted values carry an extra leading axis of size ``k`` holding the retained scores;
    ``convert`` adds it and pads the extra ``k - 1`` slots with ``zero``.

    NOTE(review): ``prod``, ``one_mask`` and ``one_mask_`` are inherited unchanged from
    LogSemiring/Semiring and do not treat the k-axis specially (e.g. ``one_mask`` writes
    ``one`` into every slot) — verify before using them with this semiring.

    Args:
        k (int): number of top scores to keep.

    Returns:
        A :class:`LogSemiring` subclass specialised to ``k``.
    """

    class KMaxSemiring(LogSemiring):

        @classmethod
        def convert(cls, x):
            # The raw scores become slot 0; the remaining k - 1 slots start at ``zero``.
            padding = cls.zero_(x.new_empty(k - 1, *x.shape))
            return torch.cat((x.unsqueeze(0), padding))

        @classmethod
        def sum(cls, x, dim=-1):
            # ``dim`` indexes ``x`` including its leading k-axis (so it is expected to be >= 1).
            if dim < 0:
                dim += x.dim()
            # Bring the reduced axis to the front, merge it with the k-axis, keep the best k.
            order = (dim, *range(dim), *range(dim + 1, x.dim()))
            moved = x.permute(*order)
            merged = moved.reshape(-1, *moved.shape[2:])
            return merged.topk(k, 0).values

        @classmethod
        def mul(cls, x, y):
            # Form all k * k pairwise sums of the two top-k lists, then keep the best k.
            pairwise = x.unsqueeze(0) + y.unsqueeze(1)
            return pairwise.reshape(-1, *x.shape[1:]).topk(k, 0).values

        @classmethod
        def one_(cls, x):
            # Only the best slot holds ``one``; the other k - 1 slots hold ``zero``.
            x[:1].fill_(cls.one)
            x[1:].fill_(cls.zero)
            return x

    return KMaxSemiring
class EntropySemiring(LogSemiring):
    r"""
    Log semiring extended with an entropy accumulator.

    A converted value stacks, along a new leading axis, the log-scores ``x[0]`` and the
    running entropy ``x[1]`` (initialised to zeros). ``unconvert`` returns the entropy slice.
    """

    @classmethod
    def convert(cls, x):
        """Stack the log-scores with a zero-initialised entropy slice."""
        return torch.stack([x, torch.zeros_like(x)])

    @classmethod
    def unconvert(cls, x):
        """Return the accumulated entropy (last slice of the leading axis)."""
        return x[-1]

    @classmethod
    def sum(cls, x, dim=-1):
        """
        Reduce along ``dim``, which indexes the unstacked slices ``x[0]`` / ``x[1]``.

        Scores are combined with ``logsumexp``; entropies as ``sum_i p_i * (H_i - log p_i)``
        with ``p`` the scores normalised along ``dim``.
        """
        scores, entropy = x[0], x[1]
        log_z = scores.logsumexp(dim)
        log_p = scores - log_z.unsqueeze(dim)
        total = (log_p.exp() * (entropy - log_p)).sum(dim)
        return torch.stack([log_z, total])

    @classmethod
    def mul(cls, x, y):
        """Product: scores and entropies both add."""
        return x + y

    @classmethod
    def zero_(cls, x):
        """Set score slices to ``zero`` and the entropy slice to ``one`` (0), in place."""
        x[:-1].fill_(cls.zero)
        x[-1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x):
        """Fill every slice with ``one`` (0), in place."""
        return x.fill_(cls.one)
class CrossEntropySemiring(LogSemiring):
    r"""
    Log semiring extended with a cross-entropy accumulator.

    A converted value has a leading axis laid out as ``[log p, log q, acc]``: the log-scores
    of two distributions followed by the running cross-entropy ``acc`` (initialised to
    ``one``, i.e. 0). ``unconvert`` returns ``acc``.
    """

    @classmethod
    def convert(cls, x):
        """Append a ``one``-initialised accumulator slice to ``x`` along axis 0."""
        # Presumably x is already stacked as [log p, log q] along axis 0 — TODO confirm at call sites.
        return torch.cat((x, cls.one_(torch.empty_like(x[:1]))))

    @classmethod
    def unconvert(cls, x):
        """Return the accumulated cross-entropy (last slice of the leading axis)."""
        return x[-1]

    @classmethod
    def sum(cls, x, dim=-1):
        r"""
        Reduce ``x`` along ``dim``.

        Score slices are reduced with ``logsumexp``; the accumulator is combined as
        ``sum_i p_i * (acc_i - log q_i)`` with ``p`` and ``q`` normalised along ``dim``.

        Args:
            x: converted value with leading axis ``[log p, log q, acc]``.
            dim (int): axis of ``x`` to reduce. Non-negative values are counted on ``x``
                itself (including the leading axis), so ``dim`` must not designate axis 0.

        Raises:
            ValueError: if ``dim`` designates the leading (stacked) axis.
        """
        # Use a negative index: it designates the same axis on ``x`` and on the slices
        # ``x[-1]`` / ``r[0]`` (which lack the leading axis), whereas a non-negative index
        # would be off by one between them.
        if dim >= 0:
            dim -= x.dim()
        if dim == -x.dim():
            raise ValueError("dim must not refer to the leading (stacked) axis")
        p = x[:-1].logsumexp(dim)
        r = x[:-1] - p.unsqueeze(dim)
        r = r[0].exp().mul(x[-1] - r[1]).sum(dim)
        return torch.cat((p, r.unsqueeze(0)))

    @classmethod
    def mul(cls, x, y):
        """Product: scores and accumulators both add."""
        return x + y

    @classmethod
    def zero_(cls, x):
        """Set score slices to ``zero`` and the accumulator slice to ``one`` (0), in place."""
        x[:-1].fill_(cls.zero)
        x[-1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x):
        """Fill every slice with ``one`` (0), in place."""
        return x.fill_(cls.one)
class KLDivergenceSemiring(LogSemiring):
    r"""
    Log semiring extended with a KL-divergence accumulator.

    A converted value has a leading axis laid out as ``[log p, log q, acc]``: the log-scores
    of two distributions followed by the running divergence ``acc`` (initialised to ``one``,
    i.e. 0). ``unconvert`` returns ``acc``.
    """

    @classmethod
    def convert(cls, x):
        """Append a ``one``-initialised accumulator slice to ``x`` along axis 0."""
        # Presumably x is already stacked as [log p, log q] along axis 0 — TODO confirm at call sites.
        return torch.cat((x, cls.one_(torch.empty_like(x[:1]))))

    @classmethod
    def unconvert(cls, x):
        """Return the accumulated divergence (last slice of the leading axis)."""
        return x[-1]

    @classmethod
    def sum(cls, x, dim=-1):
        r"""
        Reduce ``x`` along ``dim``.

        Score slices are reduced with ``logsumexp``; the accumulator is combined as
        ``sum_i p_i * (acc_i - log q_i + log p_i)`` with ``p`` and ``q`` normalised along ``dim``.

        Args:
            x: converted value with leading axis ``[log p, log q, acc]``.
            dim (int): axis of ``x`` to reduce. Non-negative values are counted on ``x``
                itself (including the leading axis), so ``dim`` must not designate axis 0.

        Raises:
            ValueError: if ``dim`` designates the leading (stacked) axis.
        """
        # Use a negative index: it designates the same axis on ``x`` and on the slices
        # ``x[-1]`` / ``r[0]`` (which lack the leading axis), whereas a non-negative index
        # would be off by one between them.
        if dim >= 0:
            dim -= x.dim()
        if dim == -x.dim():
            raise ValueError("dim must not refer to the leading (stacked) axis")
        p = x[:-1].logsumexp(dim)
        r = x[:-1] - p.unsqueeze(dim)
        r = r[0].exp().mul(x[-1] - r[1] + r[0]).sum(dim)
        return torch.cat((p, r.unsqueeze(0)))

    @classmethod
    def mul(cls, x, y):
        """Product: scores and accumulators both add."""
        return x + y

    @classmethod
    def zero_(cls, x):
        """Set score slices to ``zero`` and the accumulator slice to ``one`` (0), in place."""
        x[:-1].fill_(cls.zero)
        x[-1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x):
        """Fill every slice with ``one`` (0), in place."""
        return x.fill_(cls.one)
class VarianceSemiring(LogSemiring):
    r"""
    Log semiring extended with a third accumulator slice.

    A converted value has a leading axis laid out as ``[s0, s1, acc]``: two score slices
    followed by an accumulator ``acc`` (initialised to ``one``, i.e. 0). ``unconvert``
    returns ``acc``.

    NOTE(review): ``sum`` combines the accumulator as ``sum_i p_i * (acc_i - log q_i + log p_i)``,
    the same formula as ``KLDivergenceSemiring.sum``. Despite the name, it does not compute a
    variance; it looks like an unfinished copy. Confirm the intended semantics before relying
    on this class.
    """

    @classmethod
    def convert(cls, x):
        """Append a ``one``-initialised accumulator slice to ``x`` along axis 0."""
        return torch.cat((x, cls.one_(torch.empty_like(x[:1]))))

    @classmethod
    def unconvert(cls, x):
        """Return the accumulator (last slice of the leading axis)."""
        return x[-1]

    @classmethod
    def sum(cls, x, dim=-1):
        r"""
        Reduce ``x`` along ``dim``.

        Args:
            x: converted value with leading axis ``[s0, s1, acc]``.
            dim (int): axis of ``x`` to reduce. Non-negative values are counted on ``x``
                itself (including the leading axis), so ``dim`` must not designate axis 0.

        Raises:
            ValueError: if ``dim`` designates the leading (stacked) axis.
        """
        # Use a negative index: it designates the same axis on ``x`` and on the slices
        # ``x[-1]`` / ``r[0]`` (which lack the leading axis), whereas a non-negative index
        # would be off by one between them.
        if dim >= 0:
            dim -= x.dim()
        if dim == -x.dim():
            raise ValueError("dim must not refer to the leading (stacked) axis")
        p = x[:-1].logsumexp(dim)
        r = x[:-1] - p.unsqueeze(dim)
        r = r[0].exp().mul(x[-1] - r[1] + r[0]).sum(dim)
        return torch.cat((p, r.unsqueeze(0)))

    @classmethod
    def mul(cls, x, y):
        """Product: all slices add."""
        return x + y

    @classmethod
    def zero_(cls, x):
        """Set score slices to ``zero`` and the accumulator slice to ``one`` (0), in place."""
        x[:-1].fill_(cls.zero)
        x[-1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x):
        """Fill every slice with ``one`` (0), in place."""
        return x.fill_(cls.one)
def CheckpointSemiring(semiring):
    r"""
    Wrap ``semiring`` so that its ``sum`` is gradient-checkpointed.

    The returned class subclasses ``semiring`` and overrides only ``sum``. The forward pass
    runs ``semiring.sum`` under :func:`torch.no_grad`, so no autograd graph is recorded.
    The backward pass recomputes it with grad enabled to obtain the input gradients.

    Args:
        semiring: a semiring class exposing a ``sum(x, dim)`` classmethod.

    Returns:
        A subclass of ``semiring``.
    """

    class CheckpointFunction(Function):

        @staticmethod
        def forward(ctx, func, *x):
            ctx.func = func
            # Remember which positional arguments are tensors so that backward can rebuild
            # the call, and the gradient tuple, in the original argument order.
            ctx.is_tensor = [isinstance(i, torch.Tensor) for i in x]
            ctx.save_for_backward(*[i for i in x if isinstance(i, torch.Tensor)])
            ctx.args = [i for i in x if not isinstance(i, torch.Tensor)]
            with torch.no_grad():
                return func(*x)

        @staticmethod
        def backward(ctx, y):
            tensors, args = iter(ctx.saved_tensors), iter(ctx.args)
            x = [next(tensors) if t else next(args) for t in ctx.is_tensor]
            needs_grad = [isinstance(i, torch.Tensor) and i.requires_grad for i in x]
            with torch.enable_grad():
                # Recompute the forward with grad enabled, then differentiate with respect to
                # the tensor inputs that require grad.
                inputs = [i for i, g in zip(x, needs_grad) if g]
                grads = iter(autograd.grad(ctx.func(*x), inputs, y))
            # One entry per forward input: None for ``func``, then one per element of ``x``.
            return (None, *(next(grads) if g else None for g in needs_grad))

    class CheckpointSemiring(semiring):

        @classmethod
        def sum(cls, x, dim=-1):
            return CheckpointFunction.apply(semiring.sum, x, dim)

    return CheckpointSemiring
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment