Skip to content

Instantly share code, notes, and snippets.

@yzhangcs
Created August 9, 2020 15:54
Show Gist options
  • Select an option

  • Save yzhangcs/752d1dad6a6a0a2a02081015e8eda952 to your computer and use it in GitHub Desktop.

Select an option

Save yzhangcs/752d1dad6a6a0a2a02081015e8eda952 to your computer and use it in GitHub Desktop.
Affine layers
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class Biaffine(nn.Module):
    r"""
    Biaffine layer for first-order scoring.

    Holds a weight tensor ``W`` and scores a vector pair ``(x, y)`` as
    :math:`x^T W y`. When ``bias_x`` / ``bias_y`` are set, a constant ``1``
    is appended to the corresponding input so the weight also captures
    linear and constant terms.

    References:
        - Timothy Dozat and Christopher D. Manning. 2017.
          `Deep Biaffine Attention for Neural Dependency Parsing`_.

    Args:
        n_in (int):
            Size of the input feature.
        bias_x (bool):
            If ``True``, add a bias term for tensor x. Default: ``True``.
        bias_y (bool):
            If ``True``, add a bias term for tensor y. Default: ``True``.

    .. _Deep Biaffine Attention for Neural Dependency Parsing:
        https://openreview.net/pdf?id=Hk95PK9le
    """

    def __init__(self, n_in, bias_x=True, bias_y=True):
        super().__init__()
        self.n_in = n_in
        self.bias_x = bias_x
        self.bias_y = bias_y
        # bool acts as int here: each enabled bias widens its dim by one,
        # matching the extra ones-column concatenated in forward().
        self.weight = nn.Parameter(torch.Tensor(n_in + bias_x, n_in + bias_y))
        self.reset_parameters()

    def extra_repr(self):
        parts = [f"n_in={self.n_in}"]
        if self.bias_x:
            parts.append(f"bias_x={self.bias_x}")
        if self.bias_y:
            parts.append(f"bias_y={self.bias_y}")
        return ", ".join(parts)

    def reset_parameters(self):
        # Zero init: every pair score starts at exactly 0.
        nn.init.zeros_(self.weight)

    def forward(self, x, y):
        """
        Args:
            x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
            y (torch.Tensor): ``[batch_size, seq_len, n_in]``.

        Returns:
            s (torch.Tensor): ``[batch_size, seq_len, seq_len]``.
        """
        if self.bias_x:
            ones = torch.ones_like(x[..., :1])
            x = torch.cat((x, ones), dim=-1)
        if self.bias_y:
            ones = torch.ones_like(y[..., :1])
            y = torch.cat((y, ones), dim=-1)
        # Contract feature dims against W -> [batch_size, seq_len, seq_len]
        return torch.einsum('bxi,ij,byj->bxy', x, self.weight, y)
class Bilinear(nn.Module):
    r"""
    Applies a bilinear transformation to the incoming data:
    :math:`y = x_1^T W x_2`.

    When ``bias1`` / ``bias2`` are set, a constant ``1`` is appended to the
    corresponding input, so the weight tensor also covers linear terms.

    Args:
        n_in (int):
            Size of each input feature.
        n_out (int):
            size of each output sample.
        bias1 (bool):
            If ``True``, add a bias term for tensor x. Default: ``True``.
        bias2 (bool):
            If ``True``, add a bias term for tensor y. Default: ``True``.
    """

    def __init__(self, n_in, n_out, bias1=True, bias2=True):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.bias1 = bias1
        self.bias2 = bias2
        # F.bilinear expects weight of shape (n_out, in1, in2); each enabled
        # bias widens its input dim by one (bool used as int).
        self.weight = nn.Parameter(torch.Tensor(n_out, n_in + bias1, n_in + bias2))
        self.reset_parameters()

    def extra_repr(self):
        parts = [f"n_in={self.n_in}", f"n_out={self.n_out}"]
        if self.bias1:
            parts.append(f"bias1={self.bias1}")
        if self.bias2:
            parts.append(f"bias2={self.bias2}")
        return ", ".join(parts)

    def reset_parameters(self):
        # Zero init: all outputs start at exactly 0.
        nn.init.zeros_(self.weight)

    def forward(self, x1, x2):
        if self.bias1:
            ones = torch.ones_like(x1[..., :1])
            x1 = torch.cat((x1, ones), dim=-1)
        if self.bias2:
            ones = torch.ones_like(x2[..., :1])
            x2 = torch.cat((x2, ones), dim=-1)
        return F.bilinear(x1, x2, self.weight)
class Triaffine(nn.Module):
    r"""
    Triaffine layer for second-order scoring.

    Holds a weight tensor ``W`` and scores a vector triple ``(x, y, z)`` as
    :math:`x^T z^T W y`. Bias columns may be appended to ``x`` and ``y``
    (but not ``z``) via ``bias_x`` / ``bias_y``.

    References:
        - Yu Zhang, Zhenghua Li and Min Zhang. 2020.
          `Efficient Second-Order TreeCRF for Neural Dependency Parsing`_.
        - Xinyu Wang, Jingxian Huang, and Kewei Tu. 2019.
          `Second-Order Semantic Dependency Parsing with End-to-End Neural Networks`_.

    Args:
        n_in (int):
            Size of the input feature.
        bias_x (bool):
            If ``True``, add a bias term for tensor x. Default: ``False``.
        bias_y (bool):
            If ``True``, add a bias term for tensor y. Default: ``False``.

    .. _Efficient Second-Order TreeCRF for Neural Dependency Parsing:
        https://www.aclweb.org/anthology/2020.acl-main.302/
    .. _Second-Order Semantic Dependency Parsing with End-to-End Neural Networks:
        https://www.aclweb.org/anthology/P19-1454/
    """

    def __init__(self, n_in, bias_x=False, bias_y=False):
        super().__init__()
        self.n_in = n_in
        self.bias_x = bias_x
        self.bias_y = bias_y
        # Middle dim stays n_in because z never receives a bias column.
        self.weight = nn.Parameter(
            torch.Tensor(n_in + bias_x, n_in, n_in + bias_y))
        self.reset_parameters()

    def extra_repr(self):
        parts = [f"n_in={self.n_in}"]
        if self.bias_x:
            parts.append(f"bias_x={self.bias_x}")
        if self.bias_y:
            parts.append(f"bias_y={self.bias_y}")
        return ", ".join(parts)

    def reset_parameters(self):
        # Zero init: every triple score starts at exactly 0.
        nn.init.zeros_(self.weight)

    def forward(self, x, y, z):
        """
        Args:
            x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
            y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
            z (torch.Tensor): ``[batch_size, seq_len, n_in]``.

        Returns:
            s (torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
        """
        if self.bias_x:
            ones = torch.ones_like(x[..., :1])
            x = torch.cat((x, ones), dim=-1)
        if self.bias_y:
            ones = torch.ones_like(y[..., :1])
            y = torch.cat((y, ones), dim=-1)
        # First contract z against the middle dim of W, then contract the
        # result with x and y -> [batch_size, seq_len, seq_len, seq_len]
        wz = torch.einsum('bzk,ikj->bzij', z, self.weight)
        return torch.einsum('bxi,bzij,byj->bzxy', x, wz, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment