Skip to content

Instantly share code, notes, and snippets.

@lhr0909
Last active August 26, 2022 03:33
Show Gist options
  • Save lhr0909/48972ad1c776e51e4ca14aa2dc72e5a2 to your computer and use it in GitHub Desktop.
DIET Classifier PyTorch Snippets
import torch
from torch import nn, Tensor
from .config import DIETClassifierConfig
class IntentClassifier(nn.Module):
    """Score sentence features against intent label features by dot product.

    Sentence features and label features are each projected into a shared
    embedding space with their own linear layer. ``forward`` returns the
    matrix of dot-product similarities between the two sets of embeddings.
    """

    def __init__(self, config: "DIETClassifierConfig"):
        """Build the sentence and label projection layers.

        Args:
            config: Object exposing ``sentence_feature_dimension``,
                ``num_intents`` and ``embedding_dimension``.
        """
        super().__init__()
        # Rasa's embedding layer is actually a "dense embedding layer", which is
        # just a Keras dense layer, equivalent to a PyTorch Linear layer.
        self.sentence_embed = nn.Linear(
            config.sentence_feature_dimension, config.embedding_dimension
        )
        # Label features have ``num_intents`` columns (presumably one-hot intent
        # encodings -- TODO confirm against the caller).
        self.label_embed = nn.Linear(config.num_intents, config.embedding_dimension)

    def forward(self, sentence_features: Tensor, label_features: Tensor) -> Tensor:
        """Return dot-product similarities between sentence and label embeddings.

        Args:
            sentence_features: Tensor of shape ``(..., N, sentence_feature_dimension)``.
            label_features: Tensor of shape ``(..., M, num_intents)``.

        Returns:
            Tensor of shape ``(..., N, M)`` (leading dims broadcast per
            ``torch.matmul``). Entry ``[..., i, j]`` is the dot product of
            sentence embedding ``i`` with label embedding ``j``. For 2-D inputs
            this equals ``torch.mm(sentence_embedding, label_embedding.t())``.
        """
        sentence_embedding = self.sentence_embed(sentence_features)
        label_embedding = self.label_embed(label_features)
        # matmul + transpose of the last two dims is identical to mm/.t() for
        # 2-D inputs, but also works with leading batch dimensions
        # (torch.mm and Tensor.t() only accept tensors with at most 2 dims).
        return torch.matmul(sentence_embedding, label_embedding.transpose(-2, -1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment