Gist by @Luxter77, last active February 27, 2024
Torch Sequence Funnel Layer
import torch
import torch.nn as nn

class FunnelLayer(nn.Module):
    def __init__(self, input_length: int, output_length: int, hidden_size: int, conv_size: int, num_heads: int = 8, deconv: bool = False):
        """
        Initialize the FunnelLayer.

        Args:
            input_length (int): The length of the input sequence.
            output_length (int): The desired length of the output sequence.
            hidden_size (int): The dimensionality of the hidden states.
            conv_size (int): The kernel size for the convolutional layers.
            num_heads (int): The number of attention heads in the transformer encoder. Default is 8.
            deconv (bool): Whether the funnel goes in (sparse to dense) or out (dense to sparse). Default is False (sparse to dense).
        """
        super().__init__()
        self.input_length = input_length
        self.output_length = output_length
        self.hidden_size = hidden_size
        self.conv_size = conv_size
        self.num_heads = num_heads
        self.deconv = deconv
        # m[i] = Convolution(x[i - 1])
        # n[i] = Convolution(TransformerEncoder(x[i - 1]))
        # y[i] = LayerNorm(m[i] + n[i])
        # x[i] = LayerNorm(Linear(y[i]) + y[i])

        # Odd kernel sizes use stride 1 with "same" padding, preserving the hidden dimension;
        # even kernel sizes use stride 2, which rescales the hidden dimension.
        self.stride = 1 if (self.conv_size % 2 != 0) else 2
        self.padding = (self.conv_size - self.stride) // 2

        # The convolutions act over the channel axis, mapping input_length channels
        # (sequence positions) to output_length channels.
        if not self.deconv:
            self.conv_m = nn.ConvTranspose1d(self.input_length, self.output_length, kernel_size=self.conv_size, padding=self.padding, stride=self.stride)
            self.conv_n = nn.ConvTranspose1d(self.input_length, self.output_length, kernel_size=self.conv_size, padding=self.padding, stride=self.stride)
        else:
            self.conv_m = nn.Conv1d(self.input_length, self.output_length, kernel_size=self.conv_size, padding=self.padding, stride=self.stride)
            self.conv_n = nn.Conv1d(self.input_length, self.output_length, kernel_size=self.conv_size, padding=self.padding, stride=self.stride)

        try:
            self.trns_n = nn.TransformerEncoderLayer(d_model=self.hidden_size, nhead=self.num_heads, dim_feedforward=self.hidden_size, activation="gelu", batch_first=True)
        except AssertionError as ass:
            if "embed_dim must be divisible by num_heads" in str(ass):
                raise Exception(f"embed_dim (d_model={self.hidden_size}) must be divisible by nhead ({self.num_heads})") from ass
            raise  # re-raise any other assertion failure instead of silently swallowing it

        self.norm_y = nn.LayerNorm(self.hidden_size)
        self.norm_x = nn.LayerNorm(self.hidden_size)
        self.line_x = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the FunnelLayer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_length, hidden_size).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, output_length, hidden_size).
        """
        x_m = self.conv_m(x)            # m[i]: convolution over the raw input
        x_t = self.trns_n(x)            # transformer-encoded input
        x_n = self.conv_n(x_t)          # n[i]: convolution over the encoded input
        y_no = self.norm_y(x_m + x_n)   # y[i] = LayerNorm(m[i] + n[i])
        y_ff = self.line_x(y_no)        # feed-forward projection
        x_i = self.norm_x(y_ff + y_no)  # x[i] = LayerNorm(Linear(y[i]) + y[i])
        return x_i
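
A minimal usage sketch follows. The batch size and dimensions (4, 128, 64, 256, conv_size=3, num_heads=8) are illustrative assumptions, not values from the gist. With the default deconv=False and an odd conv_size, the layer maps (batch, input_length, hidden_size) to (batch, output_length, hidden_size).

# Usage sketch: all dimensions below are illustrative only.
if __name__ == "__main__":
    batch_size, input_length, output_length, hidden_size = 4, 128, 64, 256

    # conv_size=3 is odd, so stride is 1 and the hidden dimension is preserved.
    layer = FunnelLayer(input_length=input_length, output_length=output_length,
                        hidden_size=hidden_size, conv_size=3, num_heads=8)

    x = torch.randn(batch_size, input_length, hidden_size)
    y = layer(x)
    print(y.shape)  # torch.Size([4, 64, 256])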