Created
May 24, 2019 10:46
-
-
Save rish-16/7293dfa24169d1a051f28ecdac133767 to your computer and use it in GitHub Desktop.
Generation function for the GPT-2 model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# `raw_text` is a parameter (rather than an interactive prompt) so that a
# caller such as a Flask API handler can pass the incoming text straight in.
def interact_model(raw_text, model_name='117M', seed=None, nsamples=1, batch_size=1, length=None, temperature=1, top_k=40, models_dir='models/'):
    """Generate text continuing `raw_text` and return it as one string.

    The model's hyperparameters and latest checkpoint are loaded from
    `<models_dir>/<model_name>` on every call, `nsamples` continuations are
    sampled in batches of `batch_size`, and the decoded continuations
    (without the prompt tokens) are concatenated in generation order.

    Args:
        raw_text: Prompt text to condition on. Must be non-empty.
        model_name: Sub-directory of `models_dir` holding `hparams.json`
            and the checkpoint files.
        seed: Value passed to both `np.random.seed` and
            `tf.set_random_seed`; `None` leaves seeding to the libraries.
        nsamples: Total number of samples to generate. Must be a multiple
            of `batch_size`.
        batch_size: Number of samples generated per session run. `None` is
            treated as 1.
        length: Number of tokens to generate per sample. `None` selects
            half of the model's context window (`hparams.n_ctx // 2`).
        temperature: Forwarded unchanged to `sample.sample_sequence`.
        top_k: Forwarded unchanged to `sample.sample_sequence`.
        models_dir: Directory containing the model folders; `~` and
            environment variables are expanded.

    Returns:
        The decoded text of all generated samples joined together with no
        separator.

    Raises:
        ValueError: If `raw_text` is empty, if `nsamples` is not a multiple
            of `batch_size`, or if `length` exceeds the model's context
            window.
    """
    # Reject an empty prompt up front, before any model loading happens.
    # NOTE(review): an empty prompt encodes to zero context tokens, which the
    # sampling graph is not expected to handle -- confirm against `sample`.
    if not raw_text:
        raise ValueError("Prompt should not be empty!")

    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if nsamples % batch_size != 0:
        raise ValueError(
            "nsamples (%s) must be a multiple of batch_size (%s)"
            % (nsamples, batch_size)
        )

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    # Decoded continuations, joined once at the end.
    pieces = []

    # A fresh graph per call keeps repeated invocations from piling
    # operations onto a shared default graph.
    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        # The prompt is encoded directly; there is no interactive input step.
        context_tokens = enc.encode(raw_text)
        prompt_len = len(context_tokens)
        batch_feed = [context_tokens for _ in range(batch_size)]

        for _ in range(nsamples // batch_size):
            # Slice off the prompt so only newly generated tokens remain.
            out = sess.run(output, feed_dict={context: batch_feed})[:, prompt_len:]
            # Instead of printing each sample, collect it for the response.
            for i in range(batch_size):
                pieces.append(enc.decode(out[i]))

    # Finally return the accumulated text as the response.
    return "".join(pieces)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment