Last active
May 18, 2021 17:11
-
-
Save visionscaper/3321bbd54b59e4d11ee489a074a7e1aa to your computer and use it in GitHub Desktop.
Shared code for GPT-2 demo chatbot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This gist contains shared code for the Colab demo-chatbot.ipynb and demo-chatbot-inference.ipynb | |
# Copyright Globescope and Freddy Snijder. | |
# License "GNU General Public License v3.0" | |
# Also see https://choosealicense.com/licenses/gpl-3.0/ | |
from enum import Enum | |
import re | |
from html import unescape | |
from html.parser import HTMLParser | |
class MLStripper(HTMLParser):
    """
    HTML parser that discards tags and accumulates only the text data.

    Adapted from https://stackoverflow.com/a/925630/889617

    Fix over the original gist: the base class initializer is invoked via
    super().__init__() instead of only calling self.reset(). Skipping
    HTMLParser.__init__ leaves the parser depending on undocumented
    internals of reset(), which is fragile across Python versions.
    """

    def __init__(self):
        # convert_charrefs=True makes the parser resolve character
        # references (e.g. &amp;) into text before handle_data is called.
        super().__init__(convert_charrefs=True)
        # Kept for compatibility with the original gist; the 'strict'
        # attribute is obsolete and unused in Python 3's HTMLParser.
        self.strict = False
        # Accumulates the text fragments seen between tags.
        self.fed = []

    def handle_data(self, d):
        # Called by HTMLParser for each run of text outside of tags.
        self.fed.append(d)

    def get_data(self):
        """Return all collected text fragments joined into one string."""
        return ''.join(self.fed)
def strip_tags(line):
    """
    Return ``line`` with HTML entities decoded and all HTML tags removed.

    The line is first unescaped (so escaped markup like ``&lt;b&gt;`` becomes
    real tags) and then fed through MLStripper, which keeps only text data.
    """
    stripper = MLStripper()
    stripper.feed(unescape(line))
    return stripper.get_data()
def filter_html_tags(conversations):
    """
    Strip HTML tags from every line of every conversation.

    :param conversations: list of conversations, each a list of line strings
    :return: new list of conversations with strip_tags applied to each line
    """
    return [
        [strip_tags(line) for line in conversation]
        for conversation in conversations
    ]
def replace_multiple_spaces(conversations):
    """
    Collapse every run of two or more whitespace characters into one space,
    in every line of every conversation.

    :param conversations: list of conversations, each a list of line strings
    :return: new list of conversations with normalized whitespace
    """
    return [
        [re.sub(r"[\s]{2,}", " ", line) for line in conversation]
        for conversation in conversations
    ]
def remove_new_lines(conversations):
    """
    Delete all newline characters from every line of every conversation.

    :param conversations: list of conversations, each a list of line strings
    :return: new list of conversations with newlines removed
    """
    return [
        [re.sub(r"[\n]+", "", line) for line in conversation]
        for conversation in conversations
    ]
def capitalize(actor_name):
    """
    Capitalize each space-separated part of an actor name.

    Uses str.capitalize per part, so "john DOE" becomes "John Doe".

    :param actor_name: actor name string, possibly multi-word
    :return: the name with each part capitalized
    """
    return " ".join(part.capitalize() for part in actor_name.split(" "))
class InferenceMode(Enum):
    """Mode in which SequenceBuilder assembles a sequence."""
    NONE = -1          # Training mode: full sequence, target labels generated
    START = 0          # Inference, conversation start: <bos>, no labels
    CONTINUATION = 1   # Inference, continuation: no <bos>, no labels


class SequenceBuilder:
    """
    Builds model input sequences (token ids, segment ids and, in training
    mode, target labels) from tokenized conversations for the GPT-2 chatbot.
    """

    def __init__(self,
                 bos_token_id, eos_token_id,
                 boa_token_id, eoa_token_id,
                 boc_token_id, eoc_token_id,
                 actor_segment_id, chat_segment_id,
                 ignore_token_id=-100,
                 **kwargs):
        """
        :param bos_token_id: begin-of-sequence token id
        :param eos_token_id: end-of-sequence token id
        :param boa_token_id: begin-of-actor-name token id
        :param eoa_token_id: end-of-actor-name token id
        :param boc_token_id: begin-of-chat token id
        :param eoc_token_id: end-of-chat token id
        :param actor_segment_id: segment id marking actor-name positions
        :param chat_segment_id: segment id marking chat positions
        :param ignore_token_id: label value ignored by the loss (default -100,
                                the PyTorch CrossEntropyLoss ignore_index)
        """
        super().__init__(**kwargs)

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.boa_token_id = boa_token_id
        self.eoa_token_id = eoa_token_id

        self.boc_token_id = boc_token_id
        self.eoc_token_id = eoc_token_id

        self.actor_segment_id = actor_segment_id
        self.chat_segment_id = chat_segment_id

        self.ignore_token_id = ignore_token_id

    def add_separators_to_actor(self, name_ids):
        """Wrap actor-name token ids in <boa> ... <eoa>."""
        return [self.boa_token_id] + name_ids + [self.eoa_token_id]

    def add_separators_to_chat(self, chat_ids):
        """Wrap chat token ids in <boc> ... <eoc>."""
        return [self.boc_token_id] + chat_ids + [self.eoc_token_id]

    def __call__(self,
                 conversation_ids,
                 actor_ids,
                 inference=InferenceMode.NONE):
        """
        Build model input sequences.

        :param conversation_ids: List of chats, each chat is a list of token ids
                                 NOTE: when doing inference, we do have the last
                                 actor name (bot name), but not the last chat.
                                 In that case the last chat should be set to None.
        :param actor_ids: List of actor_ids, for each chat in conversation_ids an
                          actor name as a list of token ids. Must have the same
                          length as conversation_ids.
        :param inference: InferenceMode Enum
            InferenceMode.NONE (default)  We are in training mode
                * Sequence starts with <bos> token
                * Sequence is terminated with <eos> token
                * target_labels are generated
            InferenceMode.START           Conversation started in inference mode
                * Sequence starts with <bos> token
                * Sequence is NOT terminated with <eos> token
                * target_labels are NOT generated
            InferenceMode.CONTINUATION    Conversation continuation in inference mode
                * Sequence does NOT start with <bos> token
                * Sequence is NOT terminated with <eos> token
                * target_labels are NOT generated

        :raises ValueError: if conversation_ids and actor_ids differ in length,
                            or if a chat is None anywhere other than the last
                            position during inference.

        :return: dict:
            {
                "input_ids": <List of lists>,
                "segment_ids": <List of lists>,
                "target_labels": <None or List of lists>
            }
            To flatten the sequences you can do, e.g., list(chain(*input_ids))

            input_ids:
                [Begin   ] <bos>   # InferenceMode.NONE or InferenceMode.START
                [History ] <boa><actor name 0><eoa><boc><chat 0><eoc> ...
                           <boa><actor name N-1><eoa><boc><chat N-1><eoc>
                           <boa><actor name N><eoa>
                [Response] <boc><chat N><eoc>
                [End     ] <eos>   # InferenceMode.NONE

            segment_ids:
                [Begin   ] <actor> # InferenceMode.NONE or InferenceMode.START
                [History ] <actor> ... <actor><chat> ... <chat> ...
                           <actor> ... <actor><chat> ... <chat><actor> ... <actor>
                [Response] <chat> ... <chat>
                [End     ] <chat>  # InferenceMode.NONE

            target_labels:         # Only when InferenceMode.NONE
                [Begin   ] -100
                [History ] -100 ... -100
                [Response] -100 <chat N> <eoc>
                [End     ] -100
        """
        num_chats = len(conversation_ids)
        # zip() would silently truncate on a length mismatch, producing a
        # corrupted sequence; fail loudly instead.
        if len(actor_ids) != num_chats:
            raise ValueError(
                f"conversation_ids has {num_chats} chats but actor_ids has "
                f"{len(actor_ids)} actor names; they must match one-to-one")

        training = inference is InferenceMode.NONE

        input_ids = []
        segment_ids = []

        if inference in (InferenceMode.NONE, InferenceMode.START):
            input_ids += [[self.bos_token_id]]
            segment_ids += [[self.actor_segment_id]]

        # target labels are only produced in training mode
        target_labels = [[self.ignore_token_id]] if training else None

        for chat_idx, (chat_ids, chat_actor_ids) in enumerate(zip(conversation_ids, actor_ids)):
            # ################# ADD ACTOR #################
            # <boa><actor name i><eoa>
            input_ids += [self.add_separators_to_actor(chat_actor_ids)]

            # +2 accounts for the <boa> and <eoa> separator tokens
            num_actor_tokens = len(chat_actor_ids)

            # <actor> ... <actor>
            segment_ids += [[self.actor_segment_id] * (num_actor_tokens + 2)]

            if training:
                # -100 ... -100 : actor names never contribute to the loss
                target_labels += [[self.ignore_token_id] * (num_actor_tokens + 2)]
            # #############################################

            if chat_ids is None:
                # Only allowed for the last chat, and only during inference
                # (the model is asked to generate the response). In training
                # mode a None chat would silently yield a sequence without
                # any response labels, so it is rejected as well.
                if training or chat_idx < num_chats - 1:
                    raise ValueError(
                        f"Chat is None at index {chat_idx}: only the last chat "
                        f"may be None, and only when doing inference")
                continue

            # ################## ADD CHAT #################
            # <boc><chat i><eoc>
            input_ids += [self.add_separators_to_chat(chat_ids)]

            num_chat_tokens = len(chat_ids)

            # <chat> ... <chat>
            segment_ids += [[self.chat_segment_id] * (num_chat_tokens + 2)]

            if training:
                if chat_idx < num_chats - 1:
                    # -100 ... -100 : history chats are masked out of the loss
                    target_labels += [[self.ignore_token_id] * (num_chat_tokens + 2)]
                else:
                    # -100 <chat N> <eoc> : the model learns to produce the
                    # final response followed by the end-of-chat token
                    target_labels += [[self.ignore_token_id] + chat_ids + [self.eoc_token_id]]
            # #############################################

        if training:
            # Terminate sequence
            input_ids += [[self.eos_token_id]]
            segment_ids += [[self.chat_segment_id]]
            target_labels += [[self.ignore_token_id]]

        return {
            "input_ids": input_ids,
            "segment_ids": segment_ids,
            "target_labels": target_labels
        }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment