Created October 8, 2022 13:46
-
-
Save vatsalsaglani/b9d9e2608b8a0ca95ac7095e64cd447a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from requests import head | |
import torch as T | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from modules import Encoder, Decoder | |
class RecommendationTransformer(nn.Module):
    """Sequential recommendation model architecture.

    A ``modules.Encoder`` followed by a linear projection (``self.rec``) from
    ``emb_dim`` onto ``vocab_size`` scores. The encoder is configured so that
    its model, key and value widths all equal ``emb_dim`` and its inner
    (feed-forward) width is ``4 * emb_dim``.
    """

    def __init__(self,
                 vocab_size,
                 heads=4,
                 layers=6,
                 emb_dim=256,
                 pad_id=0,
                 num_pos=128):
        """Recommendation model initializer.

        Args:
            vocab_size (int): Number of unique tokens/items. Used both as the
                Encoder's source vocabulary size and as the output size of
                the final projection.
            heads (int, optional): Number of attention heads, forwarded to
                the Encoder. Defaults to 4.
            layers (int, optional): Number of layers, forwarded to the
                Encoder. Defaults to 6.
            emb_dim (int, optional): Embedding dimension; also used as the
                Encoder's model, key and value dimension. Defaults to 256.
            pad_id (int, optional): Token id used to pad sequences, forwarded
                to the Encoder. Defaults to 0.
            num_pos (int, optional): Number of positional embeddings,
                forwarded to the Encoder (presumably the maximum sequence
                length it supports — confirm against ``modules.Encoder``).
                Defaults to 128.
        """
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()

        # Hyper-parameters kept on the instance for later inspection.
        self.emb_dim = emb_dim
        self.pad_id = pad_id
        self.num_pos = num_pos
        self.vocab_size = vocab_size

        self.encoder = Encoder(source_vocab_size=vocab_size,
                               emb_dim=emb_dim,
                               layers=layers,
                               heads=heads,
                               dim_model=emb_dim,
                               dim_inner=4 * emb_dim,
                               dim_value=emb_dim,
                               dim_key=emb_dim,
                               pad_id=self.pad_id,
                               num_pos=num_pos)
        # Projects each encoder output vector onto one score per item.
        self.rec = nn.Linear(emb_dim, vocab_size)

    def forward(self, source, source_mask):
        """Score every vocabulary item at every position of ``source``.

        Args:
            source: Input passed unchanged to the Encoder (presumably a
                (batch, seq) tensor of item ids — TODO confirm).
            source_mask: Mask passed unchanged to the Encoder.

        Returns:
            Tensor: If the encoder output has shape ``(d0, d1, emb_dim)``,
            the result has shape ``(d0, vocab_size, d1)`` — presumably
            (batch, vocab_size, seq).
        """
        encoded = self.encoder(source, source_mask)
        logits = self.rec(encoded)
        # Move the vocabulary axis to dim 1. This is presumably so the output
        # can be fed directly to nn.CrossEntropyLoss, which expects class
        # scores on dim 1 for (N, C, d1) inputs.
        return logits.permute(0, 2, 1)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.