Torch Sequence Funnel Layer
from torch import Tensor
import torch.nn as nn
class FunnelLayer(nn.Module):
    def __init__(self, input_length: int, output_length: int, hidden_size: int, conv_size: int, num_heads: int = 8, deconv: bool = False):
        """
        Initialize the FunnelLayer.

        Args:
            input_length (int): The length of the input sequence.
            output_length (int): The desired length of the output sequence.
            hidden_size (int): The dimensionality of the hidden states.
            conv_size (int): The kernel size for the convolutional layers.
            num_heads (int): The number of attention heads in the transformer encoder. Default is 8.
            deconv (bool): Whether the funnel expands (sparse to dense) or compresses (dense to sparse). Default is False (sparse to dense).
        """
        super().__init__()
        self.input_length = input_length
        self.output_length = output_length
        self.hidden_size = hidden_size
        self.conv_size = conv_size
        self.num_heads = num_heads
        self.deconv = deconv
        # The layer computes, for input x[i - 1]:
        #   m[i] = Convolution(x[i - 1])
        #   n[i] = Convolution(TransformerEncoder(x[i - 1]))
        #   y[i] = LayerNorm(m[i] + n[i])
        #   x[i] = LayerNorm(Linear(y[i]) + y[i])
        self.stride = 1 if (self.conv_size % 2 != 0) else 2
        self.padding = (self.conv_size - self.stride) // 2
        # The convolutions treat the sequence dimension as channels, mapping
        # input_length channels to output_length channels.
        if not self.deconv:
            # Sparse-to-dense funnel: transposed convolutions.
            self.conv_m = nn.ConvTranspose1d(self.input_length, self.output_length, kernel_size=self.conv_size, padding=self.padding, stride=self.stride)
            self.conv_n = nn.ConvTranspose1d(self.input_length, self.output_length, kernel_size=self.conv_size, padding=self.padding, stride=self.stride)
        else:
            # Dense-to-sparse funnel: regular convolutions.
            self.conv_m = nn.Conv1d(self.input_length, self.output_length, kernel_size=self.conv_size, padding=self.padding, stride=self.stride)
            self.conv_n = nn.Conv1d(self.input_length, self.output_length, kernel_size=self.conv_size, padding=self.padding, stride=self.stride)
        try:
            self.trns_n = nn.TransformerEncoderLayer(d_model=self.hidden_size, nhead=self.num_heads, dim_feedforward=self.hidden_size, activation="gelu", batch_first=True)
        except AssertionError as ass:
            if "embed_dim must be divisible by num_heads" in str(ass):
                raise ValueError(f"embed_dim (d_model={self.hidden_size}) must be divisible by nhead ({self.num_heads})") from ass
            raise  # re-raise anything unexpected instead of swallowing it
        self.norm_y = nn.LayerNorm(self.hidden_size)
        self.norm_x = nn.LayerNorm(self.hidden_size)
        self.line_x = nn.Linear(self.hidden_size, self.hidden_size)
    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the FunnelLayer.

        Args:
            x (Tensor): Input tensor of shape (batch_size, input_length, hidden_size).

        Returns:
            Tensor: Output tensor of shape (batch_size, output_length, hidden_size).
        """
        x_m = self.conv_m(x)            # m[i] = Convolution(x[i - 1])
        x_t = self.trns_n(x)            # TransformerEncoder(x[i - 1])
        x_n = self.conv_n(x_t)          # n[i] = Convolution(TransformerEncoder(x[i - 1]))
        y_no = self.norm_y(x_m + x_n)   # y[i] = LayerNorm(m[i] + n[i])
        y_ff = self.line_x(y_no)        # Linear(y[i])
        x_i = self.norm_x(y_ff + y_no)  # x[i] = LayerNorm(Linear(y[i]) + y[i])
        return x_i
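

# --- Usage sketch (added for illustration, not part of the original gist) ---
# A minimal smoke test, assuming an odd conv_size (so stride is 1 and the hidden
# dimension is preserved) and a hidden_size divisible by num_heads. All shape
# values below are arbitrary example choices.
if __name__ == "__main__":
    import torch

    layer = FunnelLayer(input_length=128, output_length=64, hidden_size=256, conv_size=3, num_heads=8)
    x = torch.randn(4, 128, 256)  # (batch_size, input_length, hidden_size)
    y = layer(x)
    print(y.shape)  # expected: torch.Size([4, 64, 256])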