Created
August 9, 2020 15:54
-
-
Save yzhangcs/752d1dad6a6a0a2a02081015e8eda952 to your computer and use it in GitHub Desktop.
Affine layers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class Biaffine(nn.Module):
    """
    Biaffine layer for first-order scoring.

    Holds a single weight matrix ``W`` and scores a vector pair ``(x, y)``
    as :math:`x^T W y`; when the bias flags are set, ``x`` and/or ``y`` are
    augmented with a constant 1 feature so the score also absorbs linear
    and constant terms.

    References:
        - Timothy Dozat and Christopher D. Manning. 2017.
          `Deep Biaffine Attention for Neural Dependency Parsing`_.

    Args:
        n_in (int):
            Size of the input feature.
        bias_x (bool):
            If ``True``, add a bias term for tensor x. Default: ``True``.
        bias_y (bool):
            If ``True``, add a bias term for tensor y. Default: ``True``.

    .. _Deep Biaffine Attention for Neural Dependency Parsing:
        https://openreview.net/pdf?id=Hk95PK9le
    """

    def __init__(self, n_in, bias_x=True, bias_y=True):
        super().__init__()
        self.n_in = n_in
        self.bias_x = bias_x
        self.bias_y = bias_y
        # bool flags promote to 0/1, growing the matrix for the bias columns
        self.weight = nn.Parameter(torch.Tensor(n_in + bias_x, n_in + bias_y))
        self.reset_parameters()

    def extra_repr(self):
        parts = [f"n_in={self.n_in}"]
        if self.bias_x:
            parts.append(f"bias_x={self.bias_x}")
        if self.bias_y:
            parts.append(f"bias_y={self.bias_y}")
        return ", ".join(parts)

    def reset_parameters(self):
        # zero init: all pair scores start at 0
        nn.init.zeros_(self.weight)

    def forward(self, x, y):
        """
        Args:
            x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
            y (torch.Tensor): ``[batch_size, seq_len, n_in]``.

        Returns:
            s (torch.Tensor): ``[batch_size, seq_len, seq_len]``.
        """
        if self.bias_x:
            x = torch.cat((x, x.new_ones(*x.shape[:-1], 1)), dim=-1)
        if self.bias_y:
            y = torch.cat((y, y.new_ones(*y.shape[:-1], 1)), dim=-1)
        # (x W) y^T via two batched matmuls -> [batch_size, seq_len, seq_len]
        return torch.matmul(torch.matmul(x, self.weight), y.transpose(-1, -2))
class Bilinear(nn.Module):
    """
    Applies a bilinear transformation to the incoming data:
    :math:`y = x_1^T W x_2`, producing ``n_out`` scores per input pair.

    Args:
        n_in (int):
            Size of each input feature.
        n_out (int):
            size of each output sample.
        bias1 (bool):
            If ``True``, add a bias term for tensor x. Default: ``True``.
        bias2 (bool):
            If ``True``, add a bias term for tensor y. Default: ``True``.
    """

    def __init__(self, n_in, n_out, bias1=True, bias2=True):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.bias1 = bias1
        self.bias2 = bias2
        # bool flags promote to 0/1, widening each side for its bias column
        self.weight = nn.Parameter(torch.Tensor(n_out, n_in + bias1, n_in + bias2))
        self.reset_parameters()

    def extra_repr(self):
        info = f"n_in={self.n_in}, n_out={self.n_out}"
        for flag in ("bias1", "bias2"):
            if getattr(self, flag):
                info += f", {flag}={getattr(self, flag)}"
        return info

    def reset_parameters(self):
        # zero init: every output score starts at 0
        nn.init.zeros_(self.weight)

    def forward(self, x1, x2):
        """
        Score each pair as ``x1^T W x2`` per output channel.

        Args:
            x1 (torch.Tensor): ``[batch_size, ..., n_in]``.
            x2 (torch.Tensor): ``[batch_size, ..., n_in]``.

        Returns:
            torch.Tensor: ``[batch_size, ..., n_out]``.
        """
        if self.bias1:
            x1 = torch.cat((x1, x1.new_ones(*x1.shape[:-1], 1)), dim=-1)
        if self.bias2:
            x2 = torch.cat((x2, x2.new_ones(*x2.shape[:-1], 1)), dim=-1)
        return F.bilinear(x1, x2, self.weight)
class Triaffine(nn.Module):
    """
    Triaffine layer for second-order scoring.

    Holds a rank-3 weight tensor ``W`` and scores a vector triple
    ``(x, y, z)`` as :math:`x^T z^T W y`; the bias flags optionally augment
    ``x`` and ``y`` (but never ``z``) with a constant 1 feature.

    References:
        - Yu Zhang, Zhenghua Li and Min Zhang. 2020.
          `Efficient Second-Order TreeCRF for Neural Dependency Parsing`_.
        - Xinyu Wang, Jingxian Huang, and Kewei Tu. 2019.
          `Second-Order Semantic Dependency Parsing with End-to-End Neural Networks`_.

    Args:
        n_in (int):
            Size of the input feature.
        bias_x (bool):
            If ``True``, add a bias term for tensor x. Default: ``False``.
        bias_y (bool):
            If ``True``, add a bias term for tensor y. Default: ``False``.

    .. _Efficient Second-Order TreeCRF for Neural Dependency Parsing:
        https://www.aclweb.org/anthology/2020.acl-main.302/
    .. _Second-Order Semantic Dependency Parsing with End-to-End Neural Networks:
        https://www.aclweb.org/anthology/P19-1454/
    """

    def __init__(self, n_in, bias_x=False, bias_y=False):
        super().__init__()
        self.n_in = n_in
        self.bias_x = bias_x
        self.bias_y = bias_y
        # axes: (x-side + bias, z-side, y-side + bias)
        self.weight = nn.Parameter(torch.Tensor(n_in + bias_x, n_in, n_in + bias_y))
        self.reset_parameters()

    def extra_repr(self):
        parts = [f"n_in={self.n_in}"]
        if self.bias_x:
            parts.append(f"bias_x={self.bias_x}")
        if self.bias_y:
            parts.append(f"bias_y={self.bias_y}")
        return ", ".join(parts)

    def reset_parameters(self):
        # zero init: all triple scores start at 0
        nn.init.zeros_(self.weight)

    def forward(self, x, y, z):
        """
        Args:
            x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
            y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
            z (torch.Tensor): ``[batch_size, seq_len, n_in]``.

        Returns:
            s (torch.Tensor): ``[batch_size, seq_len, seq_len, seq_len]``.
        """
        if self.bias_x:
            x = torch.cat((x, x.new_ones(*x.shape[:-1], 1)), dim=-1)
        if self.bias_y:
            y = torch.cat((y, y.new_ones(*y.shape[:-1], 1)), dim=-1)
        # one fused contraction over i, k, j
        # -> [batch_size, seq_len, seq_len, seq_len]
        return torch.einsum('bxi,ikj,bzk,byj->bzxy', x, self.weight, z, y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment