Skip to content

Instantly share code, notes, and snippets.

@rgstephens
Last active May 18, 2021 12:47
Show Gist options
  • Save rgstephens/a32742a1fc47373f3c0f989c3061af3c to your computer and use it in GitHub Desktop.
Save rgstephens/a32742a1fc47373f3c0f989c3061af3c to your computer and use it in GitHub Desktop.
Rasa featurizer for Universal Sentence Encoder
from rasa_nlu.featurizers import Featurizer
import tensorflow_hub as hub
import tensorflow as tf
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class UniversalSentenceEncoderFeaturizer(Featurizer):
    """Appends a universal sentence encoding to the message's text_features.

    The encoder is the TensorFlow Hub module at ``TFHUB_URL``. It is loaded
    once in ``__init__`` and evaluated through a TensorFlow session that is
    kept on the instance for all later ``process`` calls.
    """

    TFHUB_URL = "https://tfhub.dev/google/universal-sentence-encoder/2"

    name = "universal_sentence_encoder_featurizer"

    # We don't require any previous pipeline step and return text_features
    requires = []
    provides = ["text_features"]

    def __init__(self, component_config=None):
        """Load the hub module and open a session that is ready for inference.

        Args:
            component_config: Component configuration, forwarded unchanged to
                the ``Featurizer`` constructor. Optional; passing ``None``
                assumes the base class accepts it -- TODO confirm against the
                installed rasa_nlu version.
        """
        import time

        super(UniversalSentenceEncoderFeaturizer, self).__init__(
            component_config)

        logger.info("loading sentence encoder")
        start_time = time.time()

        # Build everything in a dedicated graph so the initializers created
        # below only cover this component's variables and tables, and so that
        # creating a second instance does not add duplicate ops to the
        # process-wide default graph.
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Load the TensorFlow Hub Module with pre-trained weights
            sentence_encoder = hub.Module(self.TFHUB_URL)
            elapsed_time = time.time() - start_time
            logger.info("load complete: %.1f seconds, continue setup...",
                        elapsed_time)
            start_time = time.time()

            # Create a TensorFlow placeholder for the input string
            self.input_string = tf.placeholder(tf.string, shape=[None])
            # Invoke `sentence_encoder` in order to create the encoding tensor
            self.encoding = sentence_encoder(self.input_string)
            # The initializer ops must be created while this graph is the
            # default one; they are executed once the session exists.
            init_ops = [tf.global_variables_initializer(),
                        tf.tables_initializer()]

        # Make the graph read-only: any later attempt to add ops (for example
        # by accident on every `process` call) raises instead of silently
        # growing the graph.
        self.graph.finalize()

        # Create a TensorFlow Session bound to our graph and run initializers
        self.session = tf.Session(graph=self.graph)
        self.session.run(init_ops)
        elapsed_time = time.time() - start_time
        logger.info("tensorflow init complete: %.1f seconds", elapsed_time)

    def train(self, training_data, config, **kwargs):
        """Featurize every example in ``training_data.training_examples``.

        Nothing is trained here; each example is passed through ``process``
        so that the feature is set for future pipeline steps.

        Args:
            training_data: Object exposing a ``training_examples`` iterable.
            config: Unused by this component.
            **kwargs: Unused by this component.
        """
        for example in training_data.training_examples:
            self.process(example)

    def process(self, message, **kwargs):
        """Encode ``message.text`` and store it as ``text_features``.

        Args:
            message: Object with a ``text`` attribute and a ``set`` method.
            **kwargs: Unused by this component.
        """
        # Feed the text as a batch of one and take the first (only) row of
        # the evaluated encoding tensor.
        feature_vector = self.session.run(
            self.encoding, {self.input_string: [message.text]})[0]
        # Concatenate the feature vector with any existing text features
        # (helper inherited from the base class, not defined in this file).
        features = self._combine_with_existing_features(
            message, feature_vector)
        # Set the feature, overwriting any existing `text_features`
        message.set("text_features", features)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment