@yzhangcs
Created May 15, 2021 15:20
Semirings
# -*- coding: utf-8 -*-

from functools import reduce

import torch
import torch.autograd as autograd
from torch.autograd import Function

from supar.utils.common import MIN

class Semiring(object):
    r"""
    A semiring is defined by a tuple `<K, +, ×, 0, 1>` :cite:`goodman-1999-semiring`.
    `K` is a set of values;
    `+` is commutative, associative and has an identity element `0`;
    `×` is associative, has an identity element `1` and distributes over `+`.
    """

    zero = 0
    one = 1

    @classmethod
    def sum(cls, x, dim=-1):
        return x.sum(dim)

    @classmethod
    def mul(cls, x, y):
        return x * y

    @classmethod
    def dot(cls, x, y, dim=-1):
        return cls.sum(cls.mul(x, y), dim)

    @classmethod
    def prod(cls, x, dim=-1):
        return x.prod(dim)

    @classmethod
    def times(cls, *x):
        return reduce(lambda i, j: cls.mul(i, j), x)

    @classmethod
    def zero_(cls, x):
        return x.fill_(cls.zero)

    @classmethod
    def one_(cls, x):
        return x.fill_(cls.one)

    @classmethod
    def zero_mask(cls, x, mask):
        return x.masked_fill(mask, cls.zero)

    @classmethod
    def zero_mask_(cls, x, mask):
        return x.masked_fill_(mask, cls.zero)

    @classmethod
    def one_mask(cls, x, mask):
        return x.masked_fill(mask, cls.one)

    @classmethod
    def one_mask_(cls, x, mask):
        return x.masked_fill_(mask, cls.one)

    @classmethod
    def convert(cls, x):
        return x

    @classmethod
    def unconvert(cls, x):
        return x
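
# A minimal sketch (not part of the original gist) of how the base class
# behaves: under the default `<R, +, ×, 0, 1>` semiring, `dot` is the ordinary
# inner product. The tensors below are illustrative placeholders.
#
#   >>> x, y = torch.ones(3), torch.full((3,), 2.)
#   >>> Semiring.dot(x, y)  # 1*2 + 1*2 + 1*2
#   tensor(6.)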

class LogSemiring(Semiring):
    zero = MIN
    one = 0

    @classmethod
    def sum(cls, x, dim=-1):
        return x.logsumexp(dim)

    @classmethod
    def mul(cls, x, y):
        return x + y

    @classmethod
    def prod(cls, x, dim=-1):
        return x.sum(dim)
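
# In the log semiring, values are log-scores: `mul` is addition and `sum` is
# logsumexp, so `dot` computes log(sum_i exp(x_i + y_i)) without leaving log
# space. A small sketch (illustrative, not from the original gist):
#
#   >>> LogSemiring.sum(torch.zeros(2))  # log(e^0 + e^0) = log 2
#   tensor(0.6931)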

class MaxSemiring(LogSemiring):

    @classmethod
    def sum(cls, x, dim=-1):
        return x.max(dim)[0]
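
# Replacing logsumexp with max turns inside-style summation into Viterbi-style
# maximization while keeping the same interface. Illustrative sketch:
#
#   >>> MaxSemiring.sum(torch.tensor([1., 3., 2.]))
#   tensor(3.)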

def KMaxSemiring(k):
    class KMaxSemiring(LogSemiring):

        @classmethod
        def convert(cls, x):
            return torch.cat((x.unsqueeze(0), cls.zero_(x.new_empty(k - 1, *x.shape))))

        @classmethod
        def sum(cls, x, dim=-1):
            dim = dim if dim >= 0 else x.dim() + dim
            x = x.permute(dim, *range(dim), *range(dim + 1, x.dim()))
            return x.reshape(-1, *x.shape[2:]).topk(k, 0)[0]

        @classmethod
        def mul(cls, x, y):
            return (x.unsqueeze(0) + y.unsqueeze(1)).reshape(-1, *x.shape[1:]).topk(k, 0)[0]

        @classmethod
        def one_(cls, x):
            x[:1].fill_(cls.one)
            x[1:].fill_(cls.zero)
            return x

    return KMaxSemiring
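
# `KMaxSemiring(k)` builds a semiring class whose values carry the k best
# log-scores along a leading dimension, yielding k-best rather than 1-best
# decoding. A sketch of the intended usage (illustrative):
#
#   >>> KMax = KMaxSemiring(2)
#   >>> x = KMax.convert(torch.tensor([1., 3., 2.]))  # shape [2, 3], k-dim first
#   >>> KMax.sum(x)                                   # two best scores
#   tensor([3., 2.])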

class EntropySemiring(LogSemiring):

    @classmethod
    def convert(cls, x):
        return torch.stack((x, torch.zeros_like(x)))

    @classmethod
    def unconvert(cls, x):
        return x[-1]

    @classmethod
    def sum(cls, x, dim=-1):
        p = x[0].logsumexp(dim)
        r = x[0] - p.unsqueeze(dim)
        r = r.exp().mul(x[1] - r).sum(dim)
        return torch.stack((p, r))

    @classmethod
    def mul(cls, x, y):
        return x + y

    @classmethod
    def zero_(cls, x):
        x[:-1].fill_(cls.zero)
        x[-1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x):
        return x.fill_(cls.one)
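
# `EntropySemiring` propagates pairs stacked on a leading channel: channel 0
# accumulates log-scores exactly as `LogSemiring`, while channel 1 accumulates
# the expected entropy of the induced distribution; `unconvert` returns the
# entropy channel. Illustrative sketch (not from the original gist):
#
#   >>> x = EntropySemiring.convert(torch.zeros(2))  # two equal log-scores
#   >>> EntropySemiring.unconvert(EntropySemiring.sum(x))
#   tensor(0.6931)  # entropy of a uniform distribution over 2 outcomes, log 2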

class CrossEntropySemiring(LogSemiring):

    @classmethod
    def convert(cls, x):
        return torch.cat((x, cls.one_(torch.empty_like(x[:1]))))

    @classmethod
    def unconvert(cls, x):
        return x[-1]

    @classmethod
    def sum(cls, x, dim=-1):
        p = x[:-1].logsumexp(dim)
        r = x[:-1] - p.unsqueeze(dim)
        r = r[0].exp().mul(x[-1] - r[1]).sum(dim)
        return torch.cat((p, r.unsqueeze(0)))

    @classmethod
    def mul(cls, x, y):
        return x + y

    @classmethod
    def zero_(cls, x):
        x[:-1].fill_(cls.zero)
        x[-1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x):
        return x.fill_(cls.one)
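
# `CrossEntropySemiring` assumes the input stacks the log-scores of two
# distributions p and q along dim 0 and appends an accumulator channel; the
# recursion yields the cross entropy H(p, q). Illustrative sketch:
#
#   >>> x = CrossEntropySemiring.convert(torch.zeros(2, 2))  # p and q uniform
#   >>> CrossEntropySemiring.unconvert(CrossEntropySemiring.sum(x))
#   tensor(0.6931)  # H(p, q) = log 2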

class KLDivergenceSemiring(LogSemiring):

    @classmethod
    def convert(cls, x):
        return torch.cat((x, cls.one_(torch.empty_like(x[:1]))))

    @classmethod
    def unconvert(cls, x):
        return x[-1]

    @classmethod
    def sum(cls, x, dim=-1):
        p = x[:-1].logsumexp(dim)
        r = x[:-1] - p.unsqueeze(dim)
        r = r[0].exp().mul(x[-1] - r[1] + r[0]).sum(dim)
        return torch.cat((p, r.unsqueeze(0)))

    @classmethod
    def mul(cls, x, y):
        return x + y

    @classmethod
    def zero_(cls, x):
        x[:-1].fill_(cls.zero)
        x[-1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x):
        return x.fill_(cls.one)
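
# `KLDivergenceSemiring` differs from `CrossEntropySemiring` only by the extra
# `+ r[0]` term, i.e. it accumulates E_p[log p - log q] = KL(p || q).
# Illustrative sketch: for identical distributions the divergence is zero.
#
#   >>> x = KLDivergenceSemiring.convert(torch.zeros(2, 2))
#   >>> KLDivergenceSemiring.unconvert(KLDivergenceSemiring.sum(x))
#   tensor(0.)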

class VarianceSemiring(LogSemiring):

    @classmethod
    def convert(cls, x):
        return torch.cat((x, cls.one_(torch.empty_like(x[:1]))))

    @classmethod
    def unconvert(cls, x):
        return x[-1]

    @classmethod
    def sum(cls, x, dim=-1):
        p = x[:-1].logsumexp(dim)
        r = x[:-1] - p.unsqueeze(dim)
        r = r[0].exp().mul(x[-1] - r[1] + r[0]).sum(dim)
        return torch.cat((p, r.unsqueeze(0)))

    @classmethod
    def mul(cls, x, y):
        return x + y

    @classmethod
    def zero_(cls, x):
        x[:-1].fill_(cls.zero)
        x[-1].fill_(cls.one)
        return x

    @classmethod
    def one_(cls, x):
        return x.fill_(cls.one)
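
# Note: as written, `VarianceSemiring` shares its body verbatim with
# `KLDivergenceSemiring` above; the recursion is kept exactly as in the
# original gist.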

def CheckpointSemiring(semiring):
    class CheckpointFunction(Function):

        @staticmethod
        def forward(ctx, func, *x):
            ctx.func = func
            ctx.save_for_backward(*[i for i in x if isinstance(i, torch.Tensor)])
            ctx.args = [i for i in x if not isinstance(i, torch.Tensor)]
            with torch.no_grad():
                return func(*x)

        @staticmethod
        def backward(ctx, y):
            x = ctx.saved_tensors
            with torch.enable_grad():
                y = list(autograd.grad(ctx.func(*x, *ctx.args), [i for i in x if i.requires_grad], y))
            grads = [None]
            for i in x:
                grads.append(y.pop(0) if isinstance(i, torch.Tensor) and i.requires_grad else None)
            grads += [None] * len(ctx.args)
            return tuple(grads)

    class CheckpointSemiring(semiring):

        @classmethod
        def sum(cls, x, dim=-1):
            return CheckpointFunction.apply(semiring.sum, x, dim)

    return CheckpointSemiring
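
# `CheckpointSemiring` wraps a semiring's `sum` in gradient checkpointing: the
# forward pass runs under `torch.no_grad()` and the inner graph is rebuilt on
# the backward pass, trading compute for memory. A sketch of the intended
# usage (illustrative; the gradient of logsumexp is softmax):
#
#   >>> CheckpointedLog = CheckpointSemiring(LogSemiring)
#   >>> x = torch.randn(4, requires_grad=True)
#   >>> CheckpointedLog.sum(x).backward()
#   >>> x.grad.allclose(x.softmax(-1))
#   True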