import os
from pprint import pprint
import tensorflow as tf
tf_path = os.path.abspath('./models/117M/model.ckpt')  # Path to our TensorFlow checkpoint
tf_vars = tf.train.list_variables(tf_path)
pprint(tf_vars)
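The checkpoint can also be read variable by variable, which is handy when porting the weights to PyTorch. A small, optional sketch using tf.train.load_variable on the names returned above (the `arrays` dict is just an illustration, not part of the original snippet):

arrays = {}
for name, shape in tf_vars:
    arrays[name] = tf.train.load_variable(tf_path, name)  # numpy array of the weights
    print(name, shape, arrays[name].dtype)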
git clone https://github.com/openai/gpt-2.git
cd gpt-2
python download_model.py 117M
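The files land under models/117M, which is the path the inspection snippet above reads from. A quick, optional sanity check from Python:

import os
print(sorted(os.listdir('./models/117M')))  # should list the downloaded checkpoint and vocab files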
# Copyright (c) 2019-present, Thomas Wolf.
# All rights reserved. This source code is licensed under the MIT-style license.
""" A very small and self-contained gist to train a GPT-2 transformer model on wikitext-103 """
import os
from collections import namedtuple
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
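The rest of that training gist is not reproduced here. As a rough, hedged sketch of how these imports typically fit together, an ignite Engine wraps a single causal language-modeling step and is run over the DataLoader; `model`, `optimizer` and `dataloader` below are placeholders, not names taken from the gist, and `batch` is assumed to be a LongTensor of token ids.

def update(engine, batch):
    # One causal LM step: tokens < n predict token n
    model.train()
    logits = model(batch)
    loss = nn.functional.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                                       batch[:, 1:].reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update)

@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch(engine):
    print("epoch", engine.state.epoch, "last loss", engine.state.output)

# trainer.run(dataloader, max_epochs=3)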
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
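    # The preview of this gist cuts off here; the remainder is sketched below,
    # following the behaviour described in the docstring (assumes `import torch`
    # and `import torch.nn.functional as F` at the top of the file).
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort logits and take the cumulative probability of the sorted distribution
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold, shifting
        # the mask right so the first token above the threshold is kept
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    return logits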
import json
from pytorch_pretrained_bert import cached_path
url = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"
# Download and load JSON dataset
personachat_file = cached_path(url)
with open(personachat_file, "r", encoding="utf-8") as f:
    dataset = json.loads(f.read())
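A quick, optional peek at what was just loaded. The split and field names below are what the PERSONA-CHAT release typically uses, so treat them as assumptions and check against your copy:

print(list(dataset.keys()))               # e.g. ['train', 'valid']
print(list(dataset["train"][0].keys()))   # e.g. ['personality', 'utterances']
print(len(dataset["train"]), "training dialogs")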
# Forward pass
lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids)
# Total loss as a weighted sum
lm_coef = 2.0
mc_coef = 1.0
total_loss = lm_loss * lm_coef + mc_loss * mc_coef
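A sketch of the optimization step that would follow the weighted loss; `optimizer` and the clipping value are placeholders rather than settings taken from the gist, and `torch` is assumed imported as in the other snippets:

optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # optional gradient clipping
optimizer.step()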
import torch
# Let's add a distractor to our previously defined persona, history and reply
distractor = ["sorry", "to", "hear", "that"]
# Build & tokenize inputs ending with our distractor like we did with the gold reply
words_distractor, segments_distractor, _, _ = build_inputs(persona, history, distractor)
words_distractor = tokenizer.convert_tokens_to_ids(words_distractor)
segments_distractor = tokenizer.convert_tokens_to_ids(segments_distractor)
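From here, the gold reply and the distractor can be padded to the same length and batched into the tensors the double-heads forward pass above expects. This is a hedged sketch: `words` and `segments` are assumed to hold the already-tokenized gold-reply inputs built in the step that is not shown here, and the label conventions follow pytorch_pretrained_bert.

pad_id = tokenizer.convert_tokens_to_ids('<pad>')
max_len = max(len(words), len(words_distractor))

def pad(seq, value):
    return seq + [value] * (max_len - len(seq))

# Batch of 1 example with 2 candidates: [gold reply, distractor]
input_ids = torch.tensor([[pad(words, pad_id), pad(words_distractor, pad_id)]], dtype=torch.long)
token_type_ids = torch.tensor([[pad(segments, pad_id), pad(segments_distractor, pad_id)]], dtype=torch.long)
# Index of the last token of each candidate, used by the next-sentence (mc) head
mc_token_ids = torch.tensor([[len(words) - 1, len(words_distractor) - 1]], dtype=torch.long)
# The gold reply is candidate 0; lm_labels would likewise be built with -1 on
# every position except the gold reply tokens
mc_labels = torch.tensor([0], dtype=torch.long)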
# We will use 5 special tokens:
# - <bos> to indicate the start of the sequence
# - <eos> to indicate the end of the sequence
# - <speaker1> to indicate the beginning and the tokens of an utterance from the user
# - <speaker2> to indicate the beginning and the tokens of an utterance from the bot
# - <pad> as a padding token to build batches of sequences
SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]
# We can add these special tokens to the vocabulary and the embeddings of the model:
tokenizer.set_special_tokens(SPECIAL_TOKENS)
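The comment above also mentions the model embeddings; in pytorch_pretrained_bert that resize is a one-liner (assuming `model` is the OpenAIGPTDoubleHeadsModel instance loaded in the last snippet on this page):

model.set_num_special_tokens(len(SPECIAL_TOKENS))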
from itertools import chain
# Let's define our contexts and special tokens
persona = [["i", "like", "playing", "football", "."],
           ["i", "am", "from", "NYC", "."]]
history = [["hello", "how", "are", "you", "?"],
           ["i", "am", "fine", "thanks", "."]]
reply = ["great", "to", "hear"]
bos, eos, speaker1, speaker2 = "<bos>", "<eos>", "<speaker1>", "<speaker2>"
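The build_inputs helper used earlier with the distractor is not shown in these excerpts; below is a plausible sketch of it, consistent with the contexts and special tokens just defined. The exact speaker-token alternation is an assumption, not a guaranteed reproduction of the gist.

def build_inputs(persona, history, reply):
    # Persona first, then the dialogue history, then the reply, framed by <bos>/<eos>
    sequence = [[bos] + list(chain(*persona))] + history + [reply + [eos]]
    # Prefix each utterance after the persona with an alternating speaker token
    sequence = [sequence[0]] + [[speaker2 if (len(sequence) - i) % 2 else speaker1] + s
                                for i, s in enumerate(sequence[1:])]
    words = list(chain(*sequence))                   # word tokens
    segments = [speaker2 if i % 2 else speaker1      # segment (speaker) tokens
                for i, s in enumerate(sequence) for _ in s]
    position = list(range(len(words)))               # position indices
    return words, segments, position, sequence

words, segments, position, sequence = build_inputs(persona, history, reply)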
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer
model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
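A quick sanity check that the pretrained pair loads and round-trips text (any short string will do); this model and tokenizer are what the special-token and double-heads snippets above build on:

tokens = tokenizer.tokenize("hello how are you ?")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens, ids)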