Skip to content

Instantly share code, notes, and snippets.

@ayaka14732
Created April 10, 2022 12:10
Show Gist options
  • Save ayaka14732/273a11ff77e5e4ead98f6a2a07d86306 to your computer and use it in GitHub Desktop.
Save ayaka14732/273a11ff77e5e4ead98f6a2a07d86306 to your computer and use it in GitHub Desktop.
Restricted text generation with Hugging Face GPT-2 model
# Generate text by selecting only words that start with the letter s
import jax
import jax.nn as nn
import jax.numpy as np
import jax.random as rand
import string
from transformers import GPT2Tokenizer, FlaxGPT2LMHeadModel
# filters
def only_top_p(logits: np.ndarray, p: float=0.9):
    """Mask every token that lies outside the top-p (nucleus) set.

    Tokens are ranked by descending logit along the last axis. A token is
    kept when the probability mass of all strictly higher-ranked tokens is
    still below ``p``; all other tokens have their logit replaced by ``-inf``.
    The highest-ranked token is always kept, so the result never has a row
    that is fully masked by this filter.

    Args:
        logits: Array whose last axis is the vocabulary axis. Any number of
            leading batch axes (including none) is accepted.
        p: Cumulative probability threshold.

    Returns:
        An array with the same shape as ``logits``, in the original token
        order, where filtered-out positions are ``-inf``.
    """
    vocab_size = logits.shape[-1]
    token_ids = np.broadcast_to(np.arange(vocab_size, dtype=np.int32), logits.shape)

    # sort_key_val sorts ascending; flip the last axis to get descending order.
    sorted_logits, sorted_ids = jax.lax.sort_key_val(
        logits, token_ids, dimension=-1, is_stable=False)
    sorted_logits = sorted_logits[..., ::-1]
    sorted_ids = sorted_ids[..., ::-1]

    probs = nn.softmax(sorted_logits, axis=-1)
    # Probability mass accumulated *before* each token. Using the exclusive
    # sum keeps the token that crosses the threshold instead of dropping it.
    mass_before = np.cumsum(probs, axis=-1) - probs
    # Rank 0 is kept unconditionally so that p <= 0 cannot mask everything.
    keep = (mass_before < p) | (np.arange(vocab_size) == 0)
    kept_logits = np.where(keep, sorted_logits, -np.inf)

    # Sorting by the carried token ids restores the original vocabulary order.
    _, restored = jax.lax.sort_key_val(
        sorted_ids, kept_logits, dimension=-1, is_stable=False)
    return restored
@jax.vmap
def mask_repetition(logits: np.ndarray, x: np.ndarray):
    """Mask all tokens that have already occurred in the sentence.

    The function is wrapped in ``jax.vmap``, so it is written for a single
    example and applied along the leading axis of both arguments.

    Args:
        logits: Per-example logits, indexed by token id.
        x: Per-example token ids that have already been used.

    Returns:
        A copy of ``logits`` where every position listed in ``x`` is ``-inf``.
    """
    # -np.inf instead of np.NINF: the NINF alias is gone from recent jax.numpy.
    return logits.at[x].set(-np.inf)
# Generation
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = FlaxGPT2LMHeadModel.from_pretrained(model_name)


def _token_is_allowed(token):
    """Return True if a vocabulary token may be sampled.

    Allowed tokens are either 'Ġ' followed by an upper- or lower-case 's',
    or tokens whose first character is an ASCII letter.
    """
    if token[0] == 'Ġ':
        # A bare 'Ġ' has no second character, so the slice is '' and this is False.
        return token[1:2].lower() == 's'
    return token[0] in string.ascii_letters


# Boolean mask over the whole vocabulary, indexed by token id.
mask_starts_with_s = np.array(
    [_token_is_allowed(token)
     for token in tokenizer.convert_ids_to_tokens(range(tokenizer.vocab_size))],
    dtype=np.bool_,
)
def generation_step(x: np.ndarray, key: np.ndarray):
    """Sample one next token for each sequence in ``x``.

    The logits at the last position are filtered by ``only_top_p``, then by
    ``mask_repetition`` (tokens already in ``x``), then by the module-level
    ``mask_starts_with_s`` vocabulary mask, and a token is drawn from what
    remains.

    Args:
        x: Token ids passed to the model as ``input_ids``; presumably shaped
            (batch, sequence) -- confirm against the caller.
        key: JAX PRNG key used for the categorical draw. Annotated as a plain
            array because ``jax.random.KeyArray`` no longer exists in recent
            JAX releases and the annotation is evaluated at definition time.

    Returns:
        The sampled token ids, one per row of the filtered logits.

    Raises:
        AssertionError: If the combined filters leave any row with no
            candidate token.
    """
    outputs = model(input_ids=x)
    # Only the prediction for the final position is needed.
    logits = outputs.logits[:, -1]
    logits = only_top_p(logits)
    logits = mask_repetition(logits, x)
    logits = np.where(mask_starts_with_s, logits, -np.inf)
    # Sampling from a row of all -inf is meaningless; fail loudly instead.
    assert not np.any(np.all(np.isneginf(logits), axis=-1)), 'all tokens have been masked'
    y = rand.categorical(key, logits, axis=-1)
    return y
# Fixed seed for the sampling key.
key = rand.PRNGKey(42)

# Prompt to continue.
sentences = ['some']
inputs = tokenizer(sentences, return_tensors='jax')
x = inputs.input_ids

# Append six sampled tokens, one per iteration.
for _ in range(6):
    # Split so that every step samples with a fresh subkey.
    key, subkey = rand.split(key)
    y = generation_step(x, subkey)
    next_column = y[:, None]
    x = np.concatenate((x, next_column), axis=1)

print(tokenizer.batch_decode(x)[0]) # some said she should stop selling seeds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment