Skip to content

Instantly share code, notes, and snippets.

@rish-16
Created May 24, 2019 10:46
Show Gist options
  • Select an option

  • Save rish-16/7293dfa24169d1a051f28ecdac133767 to your computer and use it in GitHub Desktop.

Select an option

Save rish-16/7293dfa24169d1a051f28ecdac133767 to your computer and use it in GitHub Desktop.
Generation function for the GPT-2 model
# Here, we have added a `raw_text` parameter
# Our Flask API passes the incoming request text into this function
def interact_model(
    raw_text,
    model_name='117M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=40,
    models_dir='models/',
):
    """Generate text from a GPT-2 checkpoint, conditioned on ``raw_text``.

    The prompt is encoded, fed to ``sample.sample_sequence`` as the context,
    and the decoded samples (with the prompt tokens sliced off) are
    concatenated into one string.

    Args:
        raw_text: Prompt text; encoded with the model's encoder and used as
            the sampling context for every batch row.
        model_name: Sub-directory of ``models_dir`` holding ``hparams.json``
            and the checkpoint files.
        seed: Passed to ``np.random.seed`` and ``tf.set_random_seed``.
        nsamples: Total number of samples to generate; must be a multiple
            of ``batch_size``.
        batch_size: Samples generated per session run. ``None`` is treated
            as 1.
        length: Number of tokens to generate per sample. Defaults to half
            of ``hparams.n_ctx``.
        temperature: Forwarded unchanged to ``sample.sample_sequence``.
        top_k: Forwarded unchanged to ``sample.sample_sequence``.
        models_dir: Directory containing the model folders; ``~`` and
            environment variables are expanded.

    Returns:
        The decoded text of all ``nsamples`` samples, concatenated in
        generation order with no separator.

    Raises:
        ValueError: If ``nsamples`` is not a multiple of ``batch_size``, or
            if ``length`` exceeds ``hparams.n_ctx``.
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would let a mismatched pair silently produce fewer samples than asked.
    if nsamples % batch_size != 0:
        raise ValueError(
            "nsamples (%s) must be a multiple of batch_size (%s)"
            % (nsamples, batch_size)
        )

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    hparams_path = os.path.join(models_dir, model_name, 'hparams.json')
    with open(hparams_path, encoding='utf-8') as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    # Decoded samples are collected here and joined once at the end.
    pieces = []
    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        # The prompt comes straight from the caller rather than from an
        # interactive input loop.
        context_tokens = enc.encode(raw_text)
        # Every row of the batch is conditioned on the same prompt.
        batch_context = [context_tokens for _ in range(batch_size)]
        for _ in range(nsamples // batch_size):
            # Slice off the leading prompt tokens so only newly generated
            # tokens are decoded.
            out = sess.run(
                output, feed_dict={context: batch_context}
            )[:, len(context_tokens):]
            for i in range(batch_size):
                pieces.append(enc.decode(out[i]))

    return "".join(pieces)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment