Last active
October 30, 2021 07:46
-
-
Save MarcSkovMadsen/eae998fbcb299fae9e92ab0089e7eff8 to your computer and use it in GitHub Desktop.
Hugging Face GPT2 Transformer Example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import tensorflow as tf | |
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer | |
from transformers import tf_top_k_top_p_filtering | |
import panel as pn | |
pn.extension() | |
import panel.widgets as pnw | |
from math import pi | |
from bokeh.plotting import figure | |
from bokeh.models import ColumnDataSource | |
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer | |
pn.extension(sizing_mode="stretch_width") | |
# Accent colour used for the bar chart and for the template header.
ACCENT_BASE_COLOR = "#f37736"

# ``session_args`` values are lists of bytes; fall back to b"default" when
# no "theme" argument was supplied with the session.
_theme_values = pn.state.session_args.get("theme", [b"default"])
THEME = _theme_values[0].decode()

# Background of the generated-text pane: dark grey for the dark theme,
# light grey otherwise.
GENERATED_TEXT_BACKGROUND = "#181818" if THEME == "dark" else "#f0f0f0"
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

logger.info("Setting Tensorflow random seed to 1234")
tf.random.set_seed(1234)

# The tokenizer and model used for word generation take ~5s to load, so they
# are loaded only once and kept in ``pn.state.cache`` to be shared between
# sessions.
logger.info("Loading gpt2 tokenizer ...")
if "gpt2-tokenizer" not in pn.state.cache:
    pn.state.cache["gpt2-tokenizer"] = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer = pn.state.cache["gpt2-tokenizer"]

logger.info("Loading gpt2 model ...")
if "gpt2-model" not in pn.state.cache:
    # The tokenizer's end-of-sequence token id is passed as the pad token id.
    pn.state.cache["gpt2-model"] = TFGPT2LMHeadModel.from_pretrained(
        "gpt2", pad_token_id=tokenizer.eos_token_id
    )
model = pn.state.cache["gpt2-model"]
def get_pred(
    sequence="Please input some text",
    model=model,
    tokenizer=tokenizer,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
):
    """Sample the next token for ``sequence`` and return it with its logits.

    Args:
        sequence: Text to continue.
        model: Language model called on the encoded input ids.
        tokenizer: Tokenizer used to encode ``sequence`` and decode the
            sampled token.
        temperature: Divisor applied to the logits before filtering.
            Values at or below zero are clamped to a tiny positive number,
            which makes the sampling effectively greedy instead of dividing
            by zero.
        top_k: Passed to ``tf_top_k_top_p_filtering`` as the top-k setting.
        top_p: Passed to ``tf_top_k_top_p_filtering`` as the top-p setting.

    Returns:
        A tuple ``(resulting_string, filtered_next_token_logits)``: the
        decoded sampled token and the filtered logits it was sampled from
        (used to derive the probabilities of each candidate).
    """
    # Guard against division by zero: a temperature of 0 would turn the
    # logits into +/-inf (or NaN) and break both sampling and the softmax
    # computed later for plotting.
    min_temperature = 1e-6
    temperature = max(temperature, min_temperature)

    # Re-seed on every call so that sampling is reproducible.
    tf.random.set_seed(1234)

    input_ids = tokenizer.encode(sequence, return_tensors="tf")

    # Logits of the last position only: these score the next token.
    next_token_logits = model(input_ids)[0][:, -1, :]

    # Apply the temperature coefficient, then the top-k / top-p filter.
    next_token_logits = next_token_logits / temperature
    filtered_next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k, top_p)

    # Sample one token id from the filtered distribution and decode it.
    next_token = tf.random.categorical(filtered_next_token_logits, dtype=tf.int32, num_samples=1)
    resulting_string = tokenizer.decode(next_token.numpy().tolist()[0])
    return resulting_string, filtered_next_token_logits
def get_plot_data(filtered_next_token_logits):
    """Turn filtered logits into (probabilities, words) ready for plotting.

    The logits are converted to probabilities with a softmax, and at most
    the 100 most probable tokens with a non-zero probability are kept.
    Each kept token id is decoded with the module-level ``tokenizer``.
    """
    probs = tf.nn.softmax(filtered_next_token_logits)

    # Keep only tokens that survived the filtering, capped at 100 bars.
    n_candidates = min(100, tf.math.count_nonzero(probs).numpy())
    top = tf.math.top_k(probs[0], n_candidates)

    decoded_words = [tokenizer.decode([token_id]) for token_id in top.indices.numpy()]
    return top.values.numpy(), decoded_words
def clean_plot_data(word_list, probability_list):
    """Aggregate and rank the plot data.

    Words that occur more than once have their probabilities summed. The
    result is ordered by descending total probability; words with equal
    totals keep their first-seen order.

    Returns a ``(words, probabilities)`` pair of lists.
    """
    totals = {}
    for word, prob in zip(word_list, probability_list):
        if word in totals:
            totals[word] += prob
        else:
            totals[word] = prob

    ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)
    words = [word for word, _ in ranked]
    probabilities = [total for _, total in ranked]
    return words, probabilities
def get_plot(word_list, probability_list):
    """Build a Bokeh bar chart of the probability of each candidate word.

    The data is first passed through ``clean_plot_data``; the words become
    the categorical x-axis and the probabilities the bar heights.
    """
    words, probs = clean_plot_data(word_list, probability_list)
    source = ColumnDataSource(data={"word_list": words, "probability_list": probs})

    plot = figure(
        x_range=source.data["word_list"],
        height=250,
        title="Probabilities",
        toolbar_location=None,
        tools="",
    )
    plot.vbar(
        x="word_list",
        top="probability_list",
        width=0.8,
        source=source,
        color=ACCENT_BASE_COLOR,
    )
    # Rotate the word labels by 90 degrees.
    plot.xaxis.major_label_orientation = pi / 2
    return plot
logger.info("Creating Widgets and Panes")

# --- Sampling settings -----------------------------------------------------
# Temperature and top-p share the same 0..1 range, step and initial value.
_unit_slider_options = dict(value=1.0, start=0.0, end=1.0, step=0.01)
temperature_pn = pnw.FloatSlider(name="Temperature", **_unit_slider_options)
top_k_pn = pnw.IntSlider(name="Top K", value=0, start=0, end=100)
top_p_pn = pnw.FloatSlider(name="Top p", **_unit_slider_options)
settings = pn.Column(temperature_pn, top_k_pn, top_p_pn)

# --- Text input and generated text -----------------------------------------
text_input = pn.widgets.TextInput(value="Enter a string here...")
generated_text = pn.pane.HTML(
    object=text_input.value,
    background=GENERATED_TEXT_BACKGROUND,
    min_height=200,
    sizing_mode="stretch_both",
)
# Typing a new prompt replaces whatever has been generated so far.
text_input.link(generated_text, value="object")

predict_button = pn.widgets.Button(name="▶ Predict", button_type="primary")
text_part = pn.Column(text_input, predict_button, generated_text)

# --- Probability plot (filled in by ``predict``) ----------------------------
bokeh_plot = pn.pane.Bokeh(sizing_mode="stretch_both")
def predict(event=None):
    """Predict the next token and refresh the generated text and the plot.

    ``event`` is accepted so the function can be used as a button or
    periodic callback; it is not used.
    """
    sampling_settings = dict(
        temperature=temperature_pn.value,
        top_k=top_k_pn.value,
        top_p=top_p_pn.value,
    )
    pred, filtered_next_token_logits = get_pred(
        generated_text.object, model, tokenizer, **sampling_settings
    )

    # Append the sampled token to the text shown (and used as next input).
    generated_text.object += pred

    # Redraw the probabilities of the candidates the token was drawn from.
    probabilities, word_list = get_plot_data(filtered_next_token_logits)
    bokeh_plot.object = get_plot(word_list, probabilities.tolist())
def text_change_cb(event):
    """Copy the new value of the text input into the generated-text pane."""
    # NOTE(review): ``text_input.link(generated_text, value="object")`` is
    # also set up earlier in the file; this watcher looks redundant - confirm.
    generated_text.object = event.new


# Run one prediction right away so the page opens with a populated plot.
predict()

# Manual prediction via the button, plus prompt changes from the text input.
predict_button.on_click(predict)
text_input.param.watch(text_change_cb, "value")

# Optional automatic prediction every 1000 ms; created stopped.
auto_predict_callback = pn.state.add_periodic_callback(predict, period=1000, start=False)
# Both logos are rendered with the same fixed size and margin.
_logo_options = dict(embed=False, height=115, margin=25, sizing_mode="fixed")

panel_logo_pane = pn.pane.PNG(
    "https://panel.holoviz.org/_static/logo_stacked.png",
    link_url="https://panel.holoviz.org",
    **_logo_options,
)
hugging_face_pane = pn.pane.PNG(
    "https://raw.githubusercontent.com/huggingface/transformers/master/docs/source/imgs/transformers_logo_name.png",
    link_url="https://huggingface.co/",
    **_logo_options,
)

# Centre the two logos side by side at the top of the main area.
image_component = pn.layout.FlexBox(
    panel_logo_pane,
    hugging_face_pane,
    justify_content="center",
    margin=25,
    sizing_mode="stretch_both",
)
# Markdown shown at the bottom of the sidebar.
_info_markdown = """
# 🎓 Info
**GPT-2** is a large *transformer-based* language model with 1.5 billion parameters, trained on a
dataset of 8 million web pages.
GPT-2 is trained with a simple objective: **predict the next word**, given all
of the previous words within some text."""

# Controls for the periodic callback: its period and whether it is running.
_auto_predict_controls = pn.Param(
    auto_predict_callback.param, parameters=["period", "running"], show_name=False
)

_sidebar = [
    "# ⚙️ Parameters",
    settings,
    "# 🏃 Auto Predict",
    _auto_predict_controls,
    _info_markdown,
]
_main = [image_component, text_part, bokeh_plot]

app = pn.template.FastListTemplate(
    site="Awesome Panel",
    title="Hugging Face Transformers",
    sidebar=_sidebar,
    main=_main,
    accent_base_color=ACCENT_BASE_COLOR,
    header_background=ACCENT_BASE_COLOR,
)

logger.info("Serving the App")
app.servable()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
A few resources. Feel free to share on social media if you like