@visionscaper
Last active May 18, 2021 17:11
Shared code for GPT-2 demo chatbot
# This gist contains shared code for the Colab demo-chatbot.ipynb and demo-chatbot-inference.ipynb
# Copyright Globescope and Freddy Snijder.
# License "GNU General Public License v3.0"
# Also see https://choosealicense.com/licenses/gpl-3.0/
from enum import Enum
import re
from html import unescape
from html.parser import HTMLParser


class MLStripper(HTMLParser):
    """
    From https://stackoverflow.com/a/925630/889617
    """
    def __init__(self):
        super().__init__()
        self.reset()
        self.strict = False
        self.convert_charrefs = True
        self.fed = []

    def handle_data(self, d):
        self.fed.append(d)

    def get_data(self):
        return ''.join(self.fed)


def strip_tags(line):
    line = unescape(line)
    s = MLStripper()
    s.feed(line)
    stripped_line = s.get_data()
    return stripped_line


def filter_html_tags(conversations):
    filtered_conversations = []
    for conversation in conversations:
        filtered_conversation = []
        for line in conversation:
            filtered_conversation.append(strip_tags(line))
        filtered_conversations.append(filtered_conversation)
    return filtered_conversations


def replace_multiple_spaces(conversations):
    filtered_conversations = []
    for conversation in conversations:
        filtered_conversation = []
        for line in conversation:
            line = re.sub(r"[\s]{2,}", " ", line)
            filtered_conversation.append(line)
        filtered_conversations.append(filtered_conversation)
    return filtered_conversations


def remove_new_lines(conversations):
    filtered_conversations = []
    for conversation in conversations:
        filtered_conversation = []
        for line in conversation:
            line = re.sub(r"[\n]+", "", line)
            filtered_conversation.append(line)
        filtered_conversations.append(filtered_conversation)
    return filtered_conversations


def capitalize(actor_name):
    name_parts = actor_name.split(" ")
    name_parts = [name_part.capitalize() for name_part in name_parts]
    return " ".join(name_parts)


class InferenceMode(Enum):
    NONE = -1
    START = 0
    CONTINUATION = 1


class SequenceBuilder:

    def __init__(self,
                 bos_token_id, eos_token_id,
                 boa_token_id, eoa_token_id,
                 boc_token_id, eoc_token_id,
                 actor_segment_id, chat_segment_id,
                 ignore_token_id=-100,
                 **kwargs):
        super().__init__(**kwargs)

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.boa_token_id = boa_token_id
        self.eoa_token_id = eoa_token_id

        self.boc_token_id = boc_token_id
        self.eoc_token_id = eoc_token_id

        self.actor_segment_id = actor_segment_id
        self.chat_segment_id = chat_segment_id

        self.ignore_token_id = ignore_token_id

    def add_separators_to_actor(self, name_ids):
        return [self.boa_token_id] + name_ids + [self.eoa_token_id]

    def add_separators_to_chat(self, chat_ids):
        return [self.boc_token_id] + chat_ids + [self.eoc_token_id]
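
    # Example (illustrative): with boa_token_id=3 and eoa_token_id=4,
    # add_separators_to_actor([50, 51]) returns [3, 50, 51, 4].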

    def __call__(self,
                 conversation_ids,
                 actor_ids,
                 inference=InferenceMode.NONE):
        """
        Build model input sequences.

        :param conversation_ids: List of chats; each chat is a list of token ids.
                                 NOTE: when doing inference we do have the last actor name
                                 (the bot name), but not the last chat. In that case the
                                 last chat should be set to None.
        :param actor_ids: List of actor ids: for each chat in conversation_ids an
                          actor name as a list of token ids.
        :param inference: InferenceMode Enum
                InferenceMode.NONE (default)  We are in training mode
                    * Sequence starts with <bos> token
                    * Sequence is terminated with <eos> token
                    * target_labels are generated
                InferenceMode.START           Conversation start in inference mode
                    * Sequence starts with <bos> token
                    * Sequence is NOT terminated with <eos> token
                    * target_labels are NOT generated
                InferenceMode.CONTINUATION    Conversation continuation in inference mode
                    * Sequence does NOT start with <bos> token
                    * Sequence is NOT terminated with <eos> token
                    * target_labels are NOT generated

        :return: dict:
                {
                    "input_ids": <List of lists>,
                    "segment_ids": <List of lists>,
                    "target_labels": <None or List of lists>
                }

                To flatten the sequences you can do, e.g., list(chain(*input_ids))

                input_ids:
                [Begin   ] <bos>    # InferenceMode.NONE or InferenceMode.START
                [History ] <boa><actor name 0><eoa><boc><chat 0><eoc> ... <boa><actor name N-1><eoa><boc><chat N-1><eoc>
                           <boa><actor name N><eoa>
                [Response] <boc><chat N><eoc>
                [End     ] <eos>    # InferenceMode.NONE

                segment_ids:
                [Begin   ] <actor>  # InferenceMode.NONE or InferenceMode.START
                [History ] <actor> ... <actor><chat> ... <chat> ... <actor> ... <actor><chat> ... <chat><actor> ... <actor>
                [Response] <chat> ... <chat>
                [End     ] <chat>   # InferenceMode.NONE

                target_labels:      # Only when InferenceMode.NONE
                [Begin   ] -100
                [History ] -100 ... -100
                [Response] -100 <chat N> <eoc>
                [End     ] -100
        """
        num_chats = len(conversation_ids)

        input_ids = []
        segment_ids = []

        if inference in (InferenceMode.NONE, InferenceMode.START):
            input_ids += [[self.bos_token_id]]
            segment_ids += [[self.actor_segment_id]]

        if inference is InferenceMode.NONE:
            target_labels = [[self.ignore_token_id]]
        else:
            target_labels = None

        for chat_idx, (chat_ids, chat_actor_ids) in enumerate(zip(conversation_ids, actor_ids)):
            # ################# ADD ACTOR #################
            # <boa><actor name i><eoa>
            input_ids += [self.add_separators_to_actor(chat_actor_ids)]

            num_actor_tokens = len(chat_actor_ids)

            # <actor> ... <actor>
            segment_ids += [[self.actor_segment_id] * (num_actor_tokens + 2)]

            if inference is InferenceMode.NONE:
                # -100 ... -100
                target_labels += [[self.ignore_token_id] * (num_actor_tokens + 2)]
            # #############################################

            if chat_ids is not None:
                # ################## ADD CHAT #################
                # <boc><chat i><eoc>
                input_ids += [self.add_separators_to_chat(chat_ids)]

                num_chat_tokens = len(chat_ids)

                # <chat> ... <chat>
                segment_ids += [[self.chat_segment_id] * (num_chat_tokens + 2)]

                if inference is InferenceMode.NONE:
                    if chat_idx < num_chats - 1:
                        # -100 ... -100
                        target_labels += [[self.ignore_token_id] * (num_chat_tokens + 2)]
                    else:
                        # -100 <chat N> <eoc>
                        target_labels += [[self.ignore_token_id] + chat_ids + [self.eoc_token_id]]
                # #############################################
            elif chat_idx < num_chats - 1:
                raise ValueError("None found for a chat before the last chat. "
                                 "Only the last chat can be None (when doing inference).")

        if inference is InferenceMode.NONE:
            # Terminate sequence
            input_ids += [[self.eos_token_id]]
            segment_ids += [[self.chat_segment_id]]
            target_labels += [[self.ignore_token_id]]

        return {
            "input_ids": input_ids,
            "segment_ids": segment_ids,
            "target_labels": target_labels
        }
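

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original notebooks).
# All token and segment ids below are made up; in practice they would come
# from the tokenizer vocabulary entries for the special tokens <bos>, <eos>,
# <boa>, <eoa>, <boc> and <eoc>.
if __name__ == "__main__":
    from itertools import chain

    builder = SequenceBuilder(bos_token_id=1, eos_token_id=2,
                              boa_token_id=3, eoa_token_id=4,
                              boc_token_id=5, eoc_token_id=6,
                              actor_segment_id=7, chat_segment_id=8)

    # Training mode: two chats, each with (made-up) actor name and chat token ids
    sequences = builder(conversation_ids=[[100, 101], [200, 201, 202]],
                        actor_ids=[[50], [51]])

    print(list(chain(*sequences["input_ids"])))
    # [1, 3, 50, 4, 5, 100, 101, 6, 3, 51, 4, 5, 200, 201, 202, 6, 2]

    print(list(chain(*sequences["target_labels"])))
    # Only the last chat and its <eoc> contribute to the loss; every other
    # position holds the ignore id -100.

    # Inference mode: the bot's name (actor N) is known but its reply is not,
    # so the last chat is None and no <eos> or target_labels are produced.
    prompt = builder(conversation_ids=[[100, 101], None],
                     actor_ids=[[50], [51]],
                     inference=InferenceMode.START)

    print(list(chain(*prompt["input_ids"])))
    # [1, 3, 50, 4, 5, 100, 101, 6, 3, 51, 4]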