#!/usr/bin/env python3
# Usage:
#  PYTHONPATH=src ./train --dataset <file|directory|glob>
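#
# One possible multi-GPU launch (assumes Horovod's horovodrun launcher is
# available and this file is saved as ./train; adjust -np to your GPU count):
#  PYTHONPATH=src horovodrun -np 4 ./train --dataset <file|directory|glob>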
import argparse
import json
import os
import numpy as np
import tensorflow as tf
import time
import tqdm
from tensorflow.core.protobuf import rewriter_config_pb2

import horovod.tensorflow as hvd

import model, sample, encoder
from load_dataset import load_dataset, Sampler
from accumulate import AccumulatingOptimizer
import memory_saving_gradients

CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'


parser = argparse.ArgumentParser(
    description='Fine-tune GPT-2 on your custom dataset.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--dataset', metavar='PATH', type=str, required=True, help='Input file, directory, or glob pattern (utf-8 text, or preencoded .npz files).')
parser.add_argument('--model_name', metavar='MODEL', type=str, default='117M', help='Pretrained model name')
parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate input files with <|endoftext|> separator into chunks of this minimum size')
parser.add_argument('--batch_size', metavar='SIZE', type=int, default=1, help='Batch size')
parser.add_argument('--learning_rate', metavar='LR', type=float, default=0.0001, help='Learning rate for Adam')
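# Note: the learning rate is applied as-is on every worker; scaling it by the
# number of Horovod workers (hvd.size()) is a common convention but is not done here.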
parser.add_argument('--accumulate_gradients', metavar='N', type=int, default=1, help='Accumulate gradients across N minibatches.')
parser.add_argument('--memory_saving_gradients', default=False, action='store_true', help='Use gradient checkpointing to reduce vram usage.')
parser.add_argument('--only_train_transformer_layers', default=False, action='store_true', help='Restrict training to the transformer blocks.')
parser.add_argument('--restore_from', type=str, default='latest', help='Either "latest", "fresh", or a path to a checkpoint file')
parser.add_argument('--run_name', type=str, default='run1', help='Run id. Name of subdirectory in checkpoint/ and samples/')
parser.add_argument('--sample_every', metavar='N', type=int, default=100, help='Generate samples every N steps')
parser.add_argument('--sample_length', metavar='TOKENS', type=int, default=1023, help='Sample this many tokens')
parser.add_argument('--sample_num', metavar='N', type=int, default=1, help='Generate this many samples')
parser.add_argument('--save_every', metavar='N', type=int, default=1000, help='Write a checkpoint every N steps')

parser.add_argument('--val_dataset', metavar='PATH', type=str, default=None, help='Dataset for validation loss, defaults to --dataset.')
parser.add_argument('--val_batch_size', metavar='SIZE', type=int, default=2, help='Batch size for validation.')
parser.add_argument('--val_batch_count', metavar='N', type=int, default=40, help='Number of batches for validation.')
parser.add_argument('--val_every', metavar='STEPS', type=int, default=0, help='Calculate validation loss every STEPS steps.')
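
# Initialize Horovod. The script is launched once per GPU; each process gets a
# rank, and rank 0 alone saves checkpoints, generates samples, and runs validation.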
hvd.init()


def maketree(path):
    try:
        os.makedirs(path)
    except:
        pass


def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        args.only_train_transformer_layers = True
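
    # Pin each worker to a single GPU chosen by its Horovod local rank, and let
    # TensorFlow grow GPU memory on demand rather than reserving it all up front.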
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=40)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars
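
        # Build the training ops. Both branches below average gradients across
        # Horovod workers on every step before applying them.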
        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            # AccumulatingOptimizer is not a tf.train.Optimizer, so it cannot itself be
            # wrapped by hvd.DistributedOptimizer; wrap the inner Adam instead, so each
            # minibatch gradient is allreduced before being accumulated locally.
            opt = AccumulatingOptimizer(
                opt=hvd.DistributedOptimizer(
                    tf.train.AdamOptimizer(learning_rate=args.learning_rate)),
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            # apply_gradients() returns the mean loss over the accumulated minibatches,
            # which is what the 'loss' summary records here.
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            opt = hvd.DistributedOptimizer(opt)
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            # tf.gradients / memory_saving_gradients bypass the Horovod wrapper's
            # compute_gradients, so average the gradients across workers explicitly.
            opt_grads = [hvd.allreduce(grad) for grad in opt_grads]
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)
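
        # Broadcast the initial variable values from rank 0 to all other workers so
        # that everyone starts from identical weights (run right after restoring below).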
        bcast = hvd.broadcast_global_variables(0)

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            val_chunks = load_dataset(enc, args.val_dataset, args.combine) if args.val_dataset else chunks
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size, 'tokens')
        print(str(hvd.local_rank()), 'Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]
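
        # Each worker draws its own 1024-token batches via sample_batch(), so the
        # effective global batch size per step is args.batch_size * hvd.size().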
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
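                # Only rank 0 saves checkpoints, writes samples, and runs validation;
                # every rank trains and records its own loss summary.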
                if hvd.rank() == 0:
                    if counter % args.save_every == 0:
                        save()
                    if counter % args.sample_every == 0:
                        generate_samples()
                    if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                        validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(
                            opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                if hvd.rank() == 0:
                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(
                            counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()


if __name__ == '__main__':
    main()