-
-
Save sar/d9ef872af2a1129887955d3ab88960b2 to your computer and use it in GitHub Desktop.
Fine-tune GPT-2 1558M on a Titan RTX using IBM Tensorflow Large Model Support v2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Usage: | |
# PYTHONPATH=src ./train --dataset <file|directory|glob> | |
# Trains the 1558M model on a TITAN RTX using IBM TensorFlow Large Model Support (TFLMS).
# TFLMS inserts explicit tensor swaps between GPU and CPU memory into the graph, extending the usable GPU memory.
# Because this graph contains while loops, you have to use TFLMS v2 (which works with TF 1.x).
# | |
# You must install IBM Large Model Support 2.0 for TensorFlow:
# download the archive linked below, extract it, and install the egg it contains.
# https://github.com/IBM/tensorflow-large-model-support | |
# https://www.ibm.com/support/knowledgecenter/SS5SF7_1.6.0/navigation/pai_getstarted_tflmsv2.html | |
# https://developer.ibm.com/linuxonpower/2019/06/11/tensorflow-large-model-support-resources/ | |
# direct link to TFLMS v2 egg | |
# https://public.dhe.ibm.com/ibmdl/export/pub/software/server/ibm-ai/conda/linux-64/tensorflow-large-model-support-2.0.2-py36_970.gfa57a9e.tar.bz2 | |
# | |
# Note: adds extra switches for time-based throttling, early termination and GPU selection, and re-arranges the code for TFLMS.
# | |
import argparse | |
import json | |
import os | |
import sys | |
import numpy as np | |
import tensorflow as tf | |
import time | |
import datetime | |
import tqdm | |
from tensorflow.core.protobuf import rewriter_config_pb2 | |
from tensorflow_large_model_support import LMS | |
import model | |
import sample | |
import encoder | |
from load_dataset import load_dataset, Sampler | |
from accumulate import AccumulatingOptimizer | |
import memory_saving_gradients | |
# Checkpoints go to CHECKPOINT_DIR/<run_name>, generated samples to SAMPLE_DIR/<run_name>.
CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'

# Command-line interface. ArgumentDefaultsHelpFormatter appends each default to its --help text.
parser = argparse.ArgumentParser(
    description='Fine-tune GPT-2 on your custom dataset.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# --- data / model ---
parser.add_argument('--dataset', metavar='PATH', type=str, required=True, help='Input file, directory, or glob pattern (utf-8 text, or preencoded .npz files).')
parser.add_argument('--model_name', metavar='MODEL', type=str, default='117M', help='Pretrained model name')
parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate input files with <|endoftext|> separator into chunks of this minimum size')
parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.')

# --- optimization ---
parser.add_argument('--batch_size', metavar='SIZE', type=int, default=1, help='Batch size')
parser.add_argument('--learning_rate', metavar='LR', type=float, default=0.0001, help='Learning rate for Adam')
parser.add_argument('--accumulate_gradients', metavar='N', type=int, default=1, help='Accumulate gradients across N minibatches.')
parser.add_argument('--memory_saving_gradients', default=False, action='store_true', help='Use gradient checkpointing to reduce vram usage.')
parser.add_argument('--only_train_transformer_layers', default=False, action='store_true', help='Restrict training to the transformer blocks.')
# Only these two optimizers are implemented; `choices` rejects anything else at parse time.
parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'], help='Optimizer. <adam|sgd>.')
parser.add_argument('--noise', type=float, default=0.0, help='Add noise to input training data to regularize against typos.')

# --- sampling ---
parser.add_argument('--top_k', type=int, default=40, help='K for top-k sampling.')
parser.add_argument('--top_p', type=float, default=0.0, help='P for top-p sampling. Overrides top_k if set > 0.')

# --- run control ---
parser.add_argument('--restore_from', type=str, default='latest', help='Either "latest", "fresh", or a path to a checkpoint file')
parser.add_argument('--run_name', type=str, default='run1', help='Run id. Name of subdirectory in checkpoint/ and samples/')
# sys.maxsize acts as "no step limit"; -1.0 as "no loss target" (a cross-entropy average never drops below it).
parser.add_argument('--run_until_count', metavar='N', type=int, default=sys.maxsize, help='exit program when counter reaches N steps')
parser.add_argument('--run_until_loss', metavar='N', type=float, default=-1.0, help='exit program when avg loss < N')
parser.add_argument('--sample_every', metavar='N', type=int, default=100, help='Generate samples every N steps')
parser.add_argument('--sample_length', metavar='TOKENS', type=int, default=1023, help='Sample this many tokens')
parser.add_argument('--sample_num', metavar='N', type=int, default=1, help='Generate this many samples')
parser.add_argument('--save_every', metavar='N', type=int, default=1000, help='Write a checkpoint every N steps')

# --- validation ---
parser.add_argument('--val_dataset', metavar='PATH', type=str, default=None, help='Dataset for validation loss, defaults to --dataset.')
parser.add_argument('--val_batch_size', metavar='SIZE', type=int, default=2, help='Batch size for validation.')
parser.add_argument('--val_batch_count', metavar='N', type=int, default=40, help='Number of batches for validation.')
parser.add_argument('--val_every', metavar='STEPS', type=int, default=0, help='Calculate validation loss every STEPS steps.')

# --- hardware / throttling ---
parser.add_argument('--use_gpu', metavar='DEVICES', type=str, default='', help='Value for CUDA_VISIBLE_DEVICES, e.g. "0" or "0,1" (empty leaves it unchanged).')
parser.add_argument('--use_throttle', metavar='"START STOP"', type=str, default='', help='Throttle training between these two local hours (24h clock), e.g. "9 18". Empty disables throttling.')
parser.add_argument('--throttle_sleep', type=float, default=1.0, help='Time to sleep in seconds.')

# --- IBM Large Model Support (TFLMS v2) tuning, passed straight to LMS(...) ---
parser.add_argument('--swapout_threshold', type=int, default=4096, help='LMS swapout_threshold')
parser.add_argument('--swapin_ahead', type=int, default=1024, help='LMS swapin_ahead')
parser.add_argument('--swapin_groupby', type=int, default=256, help='LMS swapin_groupby')
parser.add_argument('--sync_mode', type=int, default=0, help='LMS sync_mode')
def maketree(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    An already-existing directory is not an error. Any other failure, such
    as permission denied or a path component that is a regular file,
    propagates as ``OSError``. It is therefore reported here rather than as a
    confusing error when something is later written into the directory.
    ``KeyboardInterrupt`` is never swallowed.
    """
    os.makedirs(path, exist_ok=True)
def randomize(context, hparams, p):
    """Corrupt input tokens at random, for typo-style noise regularization.

    Each element of ``context`` is independently replaced, with probability
    ``p``, by a token id drawn uniformly from ``[0, hparams.n_vocab)``.
    If ``p`` is not positive, ``context`` is returned untouched and no ops
    are added to the graph.
    """
    if not p > 0:
        return context
    shape = tf.shape(context)
    replace = tf.random.uniform(shape=shape) < p
    random_tokens = tf.random.uniform(
        shape=shape, minval=0, maxval=hparams.n_vocab, dtype=tf.int32)
    return tf.where(replace, random_tokens, context)
def _parse_throttle_hours(spec):
    """Parse a ``--use_throttle`` value such as ``"9 18"``.

    Returns ``None`` when throttling is disabled (empty or blank spec),
    otherwise a ``(start_hour, stop_hour)`` tuple of ints.

    Raises:
        ValueError: if the spec is not exactly two integers.
    """
    fields = spec.split()
    if not fields:
        return None
    if len(fields) != 2:
        raise ValueError('expected two hours "START STOP", got {!r}'.format(spec))
    return int(fields[0]), int(fields[1])


def main():
    """Fine-tune GPT-2 with IBM TFLMS v2, configured by the module-level ``parser``.

    All graph construction happens first: model, loss, optional validation
    loss, sampler, optimizer, summaries and saver. Then LMS rewrites the
    default graph (``lms_model.run``), and only after that is the session
    opened.

    Inside the session it restores a checkpoint and loads the dataset. It then
    trains until one of these happens:
      * the process is interrupted (Ctrl-C, which saves first);
      * ``--run_until_count`` is hit;
      * the running average loss drops below ``--run_until_loss``.
    """
    args = parser.parse_args()

    # Parse the throttle window once, up front, so a malformed value fails
    # immediately instead of crashing mid-training.
    try:
        throttle_hours = _parse_throttle_hours(args.use_throttle)
    except ValueError as err:
        parser.error('--use_throttle: {}'.format(err))

    if args.use_gpu:
        # Set before the GPU device query below so TF only sees these devices.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.use_gpu
        print('Set CUDA_VISIBLE_DEVICES =[{}]'.format(os.environ['CUDA_VISIBLE_DEVICES']))
    print('TF version: {}'.format(tf.__version__))
    gpu_name = tf.test.gpu_device_name()
    if gpu_name:
        print('Default GPU Device: {}'.format(gpu_name))
    else:
        print("Please install GPU version of TF")
    tf.logging.set_verbosity(tf.logging.INFO)

    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF

    with tf.name_scope('de_optimizer'):
        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            # sys.exit takes a single argument (the message).
            sys.exit('Bad optimizer: {}'.format(args.optimizer))

    context = tf.placeholder(tf.int32, [args.batch_size, None])
    context_in = randomize(context, hparams, args.noise)
    output = model.model(hparams=hparams, X=context_in)
    # Next-token objective: logits at position t are scored against token t+1.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=context[:, 1:], logits=output['logits'][:, :-1]))

    if args.val_every > 0:
        val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
        val_output = model.model(hparams=hparams, X=val_context)
        val_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
        val_loss_summary = tf.summary.scalar('val_loss', val_loss)

    tf_sample = sample.sample_sequence(
        hparams=hparams,
        length=args.sample_length,
        context=context,
        batch_size=args.batch_size,
        temperature=1.0,
        top_k=args.top_k,
        top_p=args.top_p)

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

    if args.accumulate_gradients > 1:
        if args.memory_saving_gradients:
            sys.exit("Memory saving gradients are not implemented for gradient accumulation yet.")
        opt = AccumulatingOptimizer(
            opt=opt,
            var_list=train_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
        # Here opt_apply's result is what gets logged as the loss.
        summary_loss = tf.summary.scalar('loss', opt_apply)
    else:
        if args.memory_saving_gradients:
            opt_grads = memory_saving_gradients.gradients(loss, train_vars)
        else:
            opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

    summary_log = tf.summary.FileWriter(
        os.path.join(CHECKPOINT_DIR, args.run_name))

    saver = tf.train.Saver(
        var_list=all_vars,
        max_to_keep=5,
        keep_checkpoint_every_n_hours=2)

    print("BEGIN LMS")
    # Enable Large Model Support: rewrites the default graph in place,
    # inserting GPU<->CPU swap ops per the thresholds below.
    lms_model = LMS(swapout_threshold=args.swapout_threshold,
                    swapin_ahead=args.swapin_ahead,
                    swapin_groupby=args.swapin_groupby,
                    sync_mode=args.sync_mode)
    lms_model.run(tf.get_default_graph())

    print("BEGIN SESS")
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        elif os.path.isdir(args.restore_from):
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        else:
            # A checkpoint prefix such as checkpoint/run1/model-1000
            # (not a real file on disk, so it is not a directory).
            ckpt = args.restore_from
        if ckpt is None:
            sys.exit('No checkpoint found to restore from ({})'.format(args.restore_from))
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            """Write a checkpoint tagged with the current counter, and the counter file."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}'.format(counter)))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            """Print --sample_num samples and write them to SAMPLE_DIR/<run_name>/samples-<counter>."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            # Format the file name before joining so braces in run_name are
            # not treated as format fields; honour --encoding for output.
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}'.format(counter)),
                    'w', encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average val_loss over the fixed validation batches; log and print it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feeding the averaged value into the val_loss tensor reuses the
            # summary op to log the mean rather than a single batch.
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            """Draw one training batch of 1024-token windows."""
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        # Exponentially-decayed (sum, weight) pair; avg = sum / weight.
        avg_loss = (0.0, 0.0)
        # Wall-clock start of training, used for the elapsed-time column.
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(
                            opt_compute, feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summary_loss))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summary_loss),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1

                # Early termination (KHC mods): save, then leave main().
                avg_ls = avg_loss[0] / avg_loss[1]
                if (args.run_until_count != sys.maxsize
                        and counter % args.run_until_count == 0):
                    save()
                    return
                if avg_ls < args.run_until_loss:
                    save()
                    return

                # Time-of-day throttling. Uses its own names so the
                # elapsed-time reference `start_time` is never overwritten.
                if throttle_hours is not None:
                    throttle_start, throttle_stop = throttle_hours
                    now_hour = datetime.datetime.now().hour
                    # Half-open window [start, stop): "9 18" throttles 09:00-17:59.
                    if throttle_start <= now_hour < throttle_stop:
                        time.sleep(args.throttle_sleep)
                        print(' {} <= {} < {} => Throttle {}'.format(
                            throttle_start, now_hour, throttle_stop, args.throttle_sleep))
        except KeyboardInterrupt:
            print('interrupted')
            save()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment