Created
April 10, 2022 12:10
-
-
Save ayaka14732/273a11ff77e5e4ead98f6a2a07d86306 to your computer and use it in GitHub Desktop.
Restricted text generation with Hugging Face GPT-2 model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate text by selecting only words that start with the letter s
import jax | |
import jax.nn as nn | |
import jax.numpy as np | |
import jax.random as rand | |
import string | |
from transformers import GPT2Tokenizer, FlaxGPT2LMHeadModel | |
# filters | |
def only_top_p(logits: np.ndarray, p: float=0.9) -> np.ndarray:
    '''Mask every token that falls outside the top-p set.

    Tokens are ranked by descending logit along the last axis and kept while
    their running (inclusive) softmax probability is still below ``p``. The
    highest-ranked token is always kept, so a row can never end up fully
    masked when a single token already carries probability >= ``p``.

    Args:
        logits: Array whose last axis is the vocabulary; any number of
            leading batch axes (including none) is accepted.
        p: Cumulative-probability threshold.

    Returns:
        An array shaped like ``logits``, in the original token order, with
        the masked positions set to ``-inf``.
    '''
    *batch_sizes, vocab_size = logits.shape

    # Carry each token's original position through the sort so that the
    # original ordering can be restored afterwards.
    token_ids = np.arange(0, vocab_size, dtype=np.int32)
    token_ids = jax.lax.broadcast(token_ids, batch_sizes)

    sorted_logits, sorted_ids = jax.lax.sort_key_val(logits, token_ids, dimension=-1, is_stable=False)
    # The sort is ascending; flip the last axis to get descending order.
    # `...` (rather than `:`) keeps this valid for any number of batch axes.
    sorted_logits = sorted_logits[..., ::-1]
    sorted_ids = sorted_ids[..., ::-1]

    cumulative = np.cumsum(nn.softmax(sorted_logits, axis=-1), axis=-1)
    keep = cumulative < p
    # Without this, a top token with probability >= p would mask the whole row.
    keep = keep.at[..., 0].set(True)
    masked = np.where(keep, sorted_logits, -np.inf)

    # Sorting by the carried token ids puts the logits back in vocabulary order.
    _, restored = jax.lax.sort_key_val(sorted_ids, masked, dimension=-1, is_stable=False)
    return restored
@jax.vmap
def mask_repetition(logits: np.ndarray, x: np.ndarray) -> np.ndarray:
    '''Mask all tokens that have already occurred in the sentence.

    The function is wrapped in ``jax.vmap``, so callers pass batched arrays
    and the body below sees one example at a time.

    Args:
        logits: Per-example logits, indexed by token id.
        x: Per-example token ids to mask.

    Returns:
        A copy of ``logits`` with every position listed in ``x`` set to
        ``-inf``.
    '''
    # `-np.inf` instead of `np.NINF`: the latter alias is no longer available
    # in recent JAX releases.
    return logits.at[x].set(-np.inf)
# Generation | |
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = FlaxGPT2LMHeadModel.from_pretrained(model_name)


def _is_allowed_token(token: str) -> bool:
    '''Return True if `token` may be sampled under the "starts with s" rule.'''
    # In GPT-2's byte-level BPE vocabulary a leading 'Ġ' stands for a
    # preceding space, i.e. the token begins a new word; such a token must
    # start with 's' or 'S'.
    if token.startswith(('Ġs', 'ĠS')):
        return True
    # Tokens that begin directly with an ASCII letter are allowed as well
    # (presumably because they continue the current word).
    return token[0] in string.ascii_letters


# Boolean vector with one entry per token id: True where the token is allowed.
mask_starts_with_s = np.array(
    [_is_allowed_token(token) for token in tokenizer.convert_ids_to_tokens(range(tokenizer.vocab_size))],
    dtype=np.bool_,
)
def generation_step(x: np.ndarray, key: np.ndarray) -> np.ndarray:
    '''Sample the next token for every sequence in the batch.

    The logits at the last position are filtered three times (top-p, tokens
    already present in ``x``, tokens not allowed by ``mask_starts_with_s``)
    and one token per sequence is then drawn from what remains.

    Args:
        x: Token ids passed to the model as ``input_ids``; the leading axis
            is treated as the batch.
        key: JAX PRNG key used for the categorical draw. (Annotated as a
            plain array because the ``jax.random.KeyArray`` alias is not
            available in recent JAX releases.)

    Returns:
        The sampled token id for each sequence in the batch.

    Raises:
        AssertionError: If the filters masked every token of some sequence.
    '''
    outputs = model(input_ids=x)
    # Only the prediction for the last position is needed.
    logits = outputs.logits[:, -1]

    logits = only_top_p(logits)
    logits = mask_repetition(logits, x)
    logits = np.where(mask_starts_with_s, logits, -np.inf)

    # Sampling from a row that is entirely -inf would be meaningless.
    assert not np.any(np.all(np.isneginf(logits), axis=-1)), 'all tokens have been masked'

    y = rand.categorical(key, logits, axis=-1)
    return y
# Fixed seed for the sampling below.
key = rand.PRNGKey(42)

# Prompt that the generated words are appended to.
sentences = ['some']
inputs = tokenizer(sentences, return_tensors='jax')
x = inputs.input_ids

# Draw six more tokens, using a fresh subkey for each draw and appending the
# sampled token as a new last column of `x`.
for _ in range(6):
    key, subkey = rand.split(key)
    y = generation_step(x, subkey)
    x = np.concatenate((x, y[:, None]), axis=1)

print(tokenizer.batch_decode(x)[0])  # some said she should stop selling seeds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment