# Training script for protein sequence models with structure-conditioned GNN
# decoders. Source: public GitHub gist by @yangkky (last active 2023-01-18).
import argparse
import json
import os
from datetime import datetime
import pathlib
import torch
import torch.multiprocessing as mp
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from apex.optimizers import FusedAdam
from torch.utils.data import DataLoader, RandomSampler
import torch.distributed as dist
import torch.nn.functional as F
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
import numpy as np
from sequence_models.gnn import StructEncoderDecoder, cat_neighbors_nodes, BidirectionalStruct2SeqDecoder
from sequence_models.convolutional import ByteNetLM
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK, START, STOP
from sequence_models.samplers import SortishSampler, ApproxBatchSampler
from sequence_models.datasets import UniRefDataset, TRRDataset
from sequence_models.collaters import StructureCollater, MLMCollater, SimpleCollater
from sequence_models.losses import MaskedCrossEntropyLoss
from sequence_models.metrics import MaskedAccuracy
from sequence_models.utils import transformer_lr, Tokenizer
# Absolute path of the user's home directory; used to build default data paths.
home = str(pathlib.Path.home())
def main():
    """Parse CLI arguments and spawn one training process per local GPU.

    Sets the rendezvous environment variables for torch.distributed
    (single-node master at localhost:8881), computes the world size from
    nodes * gpus, and launches ``train`` once per GPU via
    ``torch.multiprocessing.spawn``.
    """
    parser = argparse.ArgumentParser()
    # Output directory; falls back to $PT_OUTPUT_DIR or /tmp.
    parser.add_argument('out_fpath', type=str, nargs='?',
                        default=os.getenv('PT_OUTPUT_DIR', '/tmp') + '/')
    parser.add_argument('--task', default='mlm')  # 'mlm' or 'lm' (see train())
    parser.add_argument('-w', '--weights_fpath', required=False)
    parser.add_argument('-f', '--freeze', action='store_true')
    parser.add_argument('--no_gnn', action='store_true')
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument('--esm', action='store_true')
    parser.add_argument('--logsoftmax', action='store_true')
    parser.add_argument('--full', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--test2', default=None)
    parser.add_argument('--r', default=128, type=int)
    parser.add_argument('--k', default=5, type=int)
    parser.add_argument('--d_cnn', default=1280, type=int)
    parser.add_argument('--n_cnn', default=56, type=int)
    parser.add_argument('--activation', default='gelu')
    # Fixed: was `default=False` with no action/type, so any value passed on
    # the command line arrived as a truthy *string* ('False' included). It is
    # now a proper boolean flag; the default (False) is unchanged.
    parser.add_argument('--slim', action='store_true')
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('-off', '--offset', default=0, type=int,
                        help='Number of GPU devices to skip.')
    parser.add_argument('--dataset', default=None)
    args = parser.parse_args()
    if args.esm:
        # A pretrained ESM encoder is always kept frozen (see train()).
        args.freeze = True
    args.world_size = args.gpus * args.nodes
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8881'
    mp.spawn(train, nprocs=args.gpus, args=(args,))
def train(gpu, args):
    """Per-process training entry point, spawned once per GPU by mp.spawn.

    Initializes NCCL distributed training, builds the dataset/dataloaders
    (UniRef-style structure data, trRosetta, or precomputed ESM data),
    constructs an encoder (ByteNetLM, or pretrained ESM-1b when --esm) and a
    structure-conditioned GNN decoder, then runs the train/validate loop with
    early stopping. Rank 0 additionally checkpoints after every epoch and
    runs the final test pass.

    NOTE(review): this block's indentation was reconstructed from the control
    flow of a whitespace-stripped paste — nesting choices flagged below
    should be confirmed against the original.

    Args:
        gpu: local GPU index on this node (0..args.gpus-1), from mp.spawn.
        args: argparse namespace from main(), augmented with world_size.
    """
    _ = torch.manual_seed(0)  # same seed in every process for identical init
    # Global rank of this process across all nodes.
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=rank)
    torch.cuda.set_device(gpu + args.offset)
    device = torch.device('cuda:' + str(gpu + args.offset))
    # ---- model hyperparameters (CNN encoder dims come from the CLI) ----
    n_tokens = len(PROTEIN_ALPHABET)
    d_cnn = args.d_cnn
    n_cnn_layers = args.n_cnn
    kernel_size = args.k
    r = args.r
    slim = args.slim
    activation = args.activation
    node_features = 10  # per-residue node feature size fed to the GNN
    edge_features = 11  # per-edge feature size fed to the GNN
    dropout = args.dropout
    use_mpnn = True
    n_structure_layers = 4
    n_connections = 30  # neighbors per residue in the structure graph
    d_embed = 8
    d_gnn = 256
    # ---- batching / optimization hyperparameters ----
    bucket_size = 1000
    max_tokens = 6000       # token budget per batch (ApproxBatchSampler)
    max_batch_size = 100
    epochs = 1000
    lr = 1e-3
    opt_level = 'O2'        # apex amp mixed-precision optimization level
    warmup_steps = 1000
    train_steps = 1e10      # only referenced by commented-out mid-epoch eval
    max_len = 1024
    pad_idx = PROTEIN_ALPHABET.index(PAD)
    start_idx = PROTEIN_ALPHABET.index(START)
    stop_idx = PROTEIN_ALPHABET.index(STOP)
    if args.esm or args.task == 'lm':
        # Reserve two positions for start/stop (cls/eos) tokens.
        max_len -= 2
    # Detect a platform job via $PT_DATA_DIR; os.getenv returns None when the
    # variable is unset, so the '+' raises TypeError and we fall back to ~/data.
    try:
        data_dir = os.getenv('PT_DATA_DIR') + '/'
        ptjob = True
    except:  # NOTE(review): bare except; only TypeError is expected here
        data_dir = home + '/data/'
        ptjob = False
    if args.dataset is not None:
        dataset = args.dataset
    elif ptjob:
        dataset = 'cath'
    else:
        dataset = "uniclust/cath"
    # Sequence collater; wrapped with structure features unless 'esm' data.
    if args.task == 'mlm':
        collater = MLMCollater(PROTEIN_ALPHABET)
    else:
        collater = SimpleCollater(PROTEIN_ALPHABET, pad=True, backwards=False)
    if dataset != 'esm':
        collater = StructureCollater(collater, n_connections=n_connections)
    # ---- training dataset ----
    if dataset == 'esm':
        data_dir = data_dir + 'esm/'
        with open(data_dir + 'splits.json') as f:
            splits = json.load(f)
        metadata = np.load(data_dir + 'lengths_and_offsets.npz')
        train_idx = splits['train']
        # Sequence lengths clipped to max_len (used for length-bucketing).
        len_train = np.minimum(metadata['ells'][train_idx], max_len)
        ds_train = UniRefDataset(data_dir, 'train', structure=False, pdb=False,
                                 p_drop=0.0, max_len=max_len)
    elif dataset != 'trr':
        data_dir = data_dir + dataset + '/'
        with open(data_dir + 'splits.json') as f:
            splits = json.load(f)
        metadata = np.load(data_dir + 'lengths_and_offsets.npz')
        train_idx = splits['train']
        len_train = np.minimum(metadata['ells'][train_idx], max_len)
        ds_train = UniRefDataset(data_dir, 'train', structure=True, pdb=True,
                                 p_drop=0.0, max_len=max_len)
    else:
        ds_train = TRRDataset(data_dir + '/trrosetta/trrosetta/', 'train', bin=False,
                              return_msa=False, max_len=max_len, untokenize=True)
        len_train = np.load(data_dir + 'trrosetta/trrosetta/train_lengths.npz')['ells']
        len_train = np.minimum(len_train, max_len)
    # Length-sorted, rank-sharded sampler producing token-budgeted batches.
    train_sortish_sampler = SortishSampler(len_train, bucket_size, num_replicas=args.world_size, rank=rank)
    train_sampler = ApproxBatchSampler(train_sortish_sampler, max_tokens, max_batch_size, len_train)
    dl_train = DataLoader(dataset=ds_train,
                          batch_sampler=train_sampler,
                          num_workers=8,
                          collate_fn=collater)
    # ---- validation/test datasets: built on rank 0 only ----
    # NOTE(review): dl_valid/dl_test exist only on rank 0; the --test path
    # below calls epoch() on every rank — confirm non-zero ranks never reach it.
    if rank == 0:
        if dataset == 'esm':
            with open(data_dir + 'splits.json') as f:
                splits = json.load(f)
            metadata = np.load(data_dir + 'lengths_and_offsets.npz')
            test_idx = splits['valid']
            len_test = np.minimum(metadata['ells'][test_idx], max_len)
            ds_test = UniRefDataset(data_dir, 'valid', structure=False, pdb=False,
                                    p_drop=0.0, max_len=max_len)
            test_sortish_sampler = RandomSampler(len_test)
            max_tokens = 4 * max_tokens  # no gradients, so larger eval batches
            test_sampler = ApproxBatchSampler(test_sortish_sampler, max_tokens, max_batch_size, len_test)
            dl_test = DataLoader(dataset=ds_test,
                                 batch_sampler=test_sampler,
                                 num_workers=8,
                                 collate_fn=collater)
            # For 'esm' data the valid split doubles as the test split.
            ds_valid = ds_test
            len_valid = len_test
        elif dataset != 'trr':
            with open(data_dir + 'splits.json') as f:
                splits = json.load(f)
            metadata = np.load(data_dir + 'lengths_and_offsets.npz')
            valid_idx = splits['valid']
            len_valid = np.minimum(metadata['ells'][valid_idx], max_len)
            ds_valid = UniRefDataset(data_dir, 'valid', structure=True, pdb=True,
                                     p_drop=0.0, max_len=max_len)
            test_idx = splits['test']
            len_test = np.minimum(metadata['ells'][test_idx], max_len)
            ds_test = UniRefDataset(data_dir, 'test', structure=True, pdb=True,
                                    p_drop=0.0, max_len=max_len)
            if args.test2 is not None:
                # Optional alternative test split pulled from a CATH split file.
                with open(home + '/workspace/data/esm/cath_splits.json') as f:
                    sp = json.load(f)
                ds_test.idx = sp[args.test2]
                len_test = np.minimum(metadata['ells'][ds_test.idx], max_len)
            test_sortish_sampler = SortishSampler(len_test, bucket_size)
            test_sampler = ApproxBatchSampler(test_sortish_sampler, max_tokens, max_batch_size, len_test)
            dl_test = DataLoader(dataset=ds_test,
                                 batch_sampler=test_sampler,
                                 num_workers=8,
                                 collate_fn=collater)
        else:
            ds_valid = TRRDataset(data_dir + '/trrosetta/trrosetta/', 'valid', bin=False, return_msa=False,
                                  max_len=max_len, untokenize=True)
            len_valid = np.load(data_dir + 'trrosetta/trrosetta/valid_lengths.npz')['ells']
            len_valid = np.minimum(len_valid, max_len)
            dl_test = None  # no test split for trRosetta data
        valid_sortish_sampler = SortishSampler(len_valid, bucket_size)
        valid_sampler = ApproxBatchSampler(valid_sortish_sampler, max_tokens, max_batch_size, len_valid)
        dl_valid = DataLoader(dataset=ds_valid,
                              batch_sampler=valid_sampler,
                              num_workers=8,
                              collate_fn=collater)
    # Initiate model
    # Bidirectional GNN decoder for masked LM; encoder/decoder stack for
    # the autoregressive task.
    if args.task == 'mlm':
        decoder = BidirectionalStruct2SeqDecoder(n_tokens, node_features, edge_features,
                                                 d_gnn, num_decoder_layers=n_structure_layers,
                                                 dropout=dropout, use_mpnn=use_mpnn,
                                                 one_hot_src=False).to(device)
    else:
        decoder = StructEncoderDecoder(n_tokens, node_features, edge_features, d_gnn, src_node=True,
                                       num_encoder_layers=n_structure_layers - 1, num_decoder_layers=1,
                                       use_mpnn=use_mpnn, one_hot_src=False, dropout=dropout).to(device)
    if args.esm:
        # Pretrained ESM-1b as a frozen encoder; build a mapping from this
        # project's token ids to ESM alphabet ids.
        from esm.pretrained import load_model_and_alphabet
        encoder, alphabet = load_model_and_alphabet(home + "/.cache/torch/checkpoints/esm1b_t33_650M_UR50S.pt")
        tokenizer = Tokenizer(PROTEIN_ALPHABET)
        to_esm = {}
        for p in PROTEIN_ALPHABET:
            k = tokenizer.a_to_t[p]
            if p == PAD:
                to_esm[k] = alphabet.padding_idx
            elif p in alphabet.tok_to_idx:
                to_esm[k] = alphabet.tok_to_idx[p]
            elif p == MASK:
                to_esm[k] = alphabet.mask_idx
            elif p == START:
                to_esm[k] = alphabet.cls_idx
            elif p == STOP:
                to_esm[k] = alphabet.eos_idx
            else:
                to_esm[k] = alphabet.unk_idx
        # Permutation used to reorder ESM logits into this project's alphabet order.
        dim_reorder = torch.tensor([to_esm[k] for k in range(len(PROTEIN_ALPHABET))])
        encoder = encoder.to(device)
    else:
        if args.task == 'mlm':
            causal = False
        else:
            causal = True
        encoder = ByteNetLM(n_tokens, d_embed, d_cnn, n_cnn_layers, kernel_size, r, final_ln=True,
                            slim=slim, activation=activation, causal=causal, padding_idx=pad_idx).to(device)
    # Optionally warm-start the encoder from a saved checkpoint.
    # NOTE(review): reconstructed nesting places this at function level; it
    # loads a ByteNet-style 'model_state_dict', so it presumably applies only
    # to the non-ESM encoder — confirm.
    if args.weights_fpath is not None:
        # if ptjob:
        #     args.weights_fpath = os.getenv('PT_DATA_DIR') + '/' + args.weights_fpath
        print('Loading weights from ' + args.weights_fpath + '...')
        sd = torch.load(args.weights_fpath, map_location=device)
        # if args.no_gnn:
        cnn_sd = sd['model_state_dict']
        # Strip the DistributedDataParallel 'module.' prefix from saved keys.
        cnn_sd = {k.split('module.')[1]: v for k, v in cnn_sd.items()}
        encoder.load_state_dict(cnn_sd)
        # else:
        #     gnn_sd = sd['decoder_state_dict']
        #     decoder.load_state_dict(gnn_sd)
        #     cnn_sd = sd['encoder_state_dict']
        #     cnn_sd = {k.split('module.')[1]: v for k, v in cnn_sd.items()}
        #     encoder.load_state_dict(cnn_sd)
    # ---- optimizer / mixed precision / DDP wrapping ----
    if args.esm:
        # Only the decoder trains; the ESM encoder stays out of amp/DDP.
        optimizer = FusedAdam(list(decoder.parameters()), lr=lr)
        decoder, optimizer = amp.initialize(decoder, optimizer, opt_level=opt_level)
        decoder = DDP(decoder)
    else:
        if args.freeze:
            # Frozen encoder: optimize decoder parameters only.
            optimizer = FusedAdam(list(decoder.parameters()), lr=lr)
            (encoder, decoder), optimizer = amp.initialize([encoder, decoder], optimizer, opt_level=opt_level)
        else:
            # Smaller LR for the (pre-initialized) encoder than the decoder.
            optimizer = FusedAdam([{'params': encoder.parameters(), 'lr': 1e-4},
                                   {'params': decoder.parameters(), 'lr': lr}])
            (encoder, decoder), optimizer = amp.initialize([encoder, decoder], optimizer, opt_level=opt_level)
        decoder = DDP(decoder)
        encoder = DDP(encoder)
    scheduler = LambdaLR(optimizer, transformer_lr(warmup_steps))
    loss_func = MaskedCrossEntropyLoss()
    accu_func = MaskedAccuracy()

    def epoch(encoder, decoder, train, current_step=0):
        """Run one pass over the train (train=True) or valid loader.

        Returns the running weighted loss when validating, or the number of
        optimizer steps taken when training. Uses dl_train/dl_valid, the
        epoch counter `e`, and `rank` from the enclosing scope.
        """
        start_time = datetime.now()
        if train:
            if not args.freeze:
                encoder = encoder.train()
            if args.freeze:
                # NOTE(review): both branches call .train(); this one was
                # possibly meant to be .eval() for a frozen encoder — confirm.
                encoder = encoder.train()
            if decoder is not None:
                decoder = decoder.train()
            loader = dl_train
            t = 'Training:'
        else:
            encoder = encoder.eval()
            if decoder is not None:
                decoder = decoder.eval()
            loader = dl_valid
            t = 'Validating:'
        losses = []
        accus = []
        ns = []
        chunk_time = datetime.now()
        n_seen = 0
        if train:
            # Each rank sees roughly 1/world_size of the training set.
            n_total = len(ds_train) // args.world_size
        else:
            n_total = len(ds_valid)
        for i, batch in enumerate(loader):
            new_loss, new_accu, new_n = step(encoder, decoder, batch, train)
            # Weight per-batch metrics by token count for a running average.
            losses.append(new_loss * new_n)
            accus.append(new_accu * new_n)
            ns.append(new_n)
            n_seen += len(batch[0])
            total_n = sum(ns)
            rloss = sum(losses) / total_n
            raccu = sum(accus) / total_n
            if train:
                nsteps = current_step + i + 1
            else:
                nsteps = i
            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %.4f accu = %.4f'
                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss, raccu),
                  end='')
            if train:
                # Keep only a sliding window so the running average tracks
                # recent batches (and the lists stay bounded).
                losses = losses[-999:]
                accus = accus[-999:]
                ns = ns[-999:]
            # if (nsteps) % train_steps == 0 and rank == 0:
            #     print('\nTraining complete in ' + str(datetime.now() - chunk_time))
            #     with torch.no_grad():
            #         _ = epoch(encoder, decoder, False, current_step=nsteps)
            #     chunk_time = datetime.now()
        if not train:
            print('\nValidation complete in ' + str(datetime.now() - start_time))
            return rloss
        elif rank == 0:
            print('\nEpoch complete in ' + str(datetime.now() - start_time))
        # Every rank must return the step count (total_steps accumulates it),
        # so this sits outside the rank guard.
        return i

    def step(encoder, decoder, batch, train):
        """Run one batch: forward, loss/accuracy, and (if train) a backward
        + optimizer/scheduler step through apex amp.

        Returns (loss, accuracy, number of scored tokens) as Python scalars.
        """
        if args.dataset == 'esm':
            src, tgt, mask = batch
            src = src.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)
        else:
            if args.task == 'lm':
                # Autoregressive LM: target is the sequence itself; prepend
                # the start token to the input.
                src, nodes, edges, connections, edge_mask = batch
                tgt = src.detach().clone()
                mask = (src != PROTEIN_ALPHABET.index(PAD)).float()
                n, ell = src.shape
                starts = torch.zeros(n, 1) + start_idx
                starts = starts.long()
                src = torch.cat([starts, src], dim=-1)
            else:
                src, tgt, mask, nodes, edges, connections, edge_mask = batch
            src = src.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)
            nodes = nodes.cuda()
            edges = edges.cuda()
            connections = connections.cuda()
            edge_mask = edge_mask.cuda()
        input_mask = (src != pad_idx).float().unsqueeze(-1)
        if args.esm:
            # Re-tokenize into the ESM alphabet with cls/eos framing and
            # right-padding, then take ESM logits as the embeddings.
            n, ell = src.shape
            esm_src = torch.zeros(n, ell + 2) + alphabet.padding_idx
            esm_src[:, 0] = alphabet.cls_idx
            tokenized = [[to_esm[s.item()] for s in sr if s != pad_idx] + [alphabet.eos_idx] for sr in src]
            ells = []
            for i, t in enumerate(tokenized):
                el = len(t)
                ells.append(el - 1)  # length without the eos token
                esm_src[i, 1:el + 1] = torch.tensor(t)
            esm_src = esm_src.to(device).long()
            with torch.no_grad():
                e = encoder(esm_src)['logits']
            # 33 is presumably the ESM-1b alphabet size — confirm.
            embeddings = torch.zeros(n, ell, 33)
            embeddings = embeddings.to(device)
            for i, (ee, el) in enumerate(zip(e, ells)):
                # Drop the cls position; keep per-residue logits only.
                embeddings[i, :el, :] = ee[1:el + 1]
            # Reorder logit channels into this project's alphabet order.
            embeddings = embeddings[:, :, dim_reorder]
        else:
            if args.freeze:
                with torch.no_grad():
                    embeddings = encoder(src, input_mask=input_mask)
            else:
                embeddings = encoder(src, input_mask=input_mask)
            # slice out the extra position
            if args.task == 'lm':
                embeddings = embeddings[:, :-1, :]
        if (args.esm and args.test) or args.no_gnn:
            # Skip the GNN: score the encoder output directly.
            outputs = embeddings
        else:
            if args.logsoftmax:
                embeddings = F.log_softmax(embeddings, dim=-1)
            outputs = decoder(nodes, edges, connections, embeddings, edge_mask)
        loss = loss_func(outputs, tgt, mask)
        accu = accu_func(outputs, tgt, mask)
        if train:
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            scheduler.step()
        return loss.item(), accu.item(), mask.sum().item()

    total_steps = 0
    n_parameters = sum(p.numel() for p in decoder.parameters())
    if rank == 0:
        print('%d model parameters' % n_parameters)
        print('%d training sequences' % len(len_train))
        print('%d validation sequences' % len(len_valid))
    if args.test:
        # Evaluation-only mode: run one pass over the test loader and exit.
        e = 0  # epoch() prints `e + 1` from the enclosing scope
        with torch.no_grad():
            dl_valid = dl_test
            _ = epoch(encoder, decoder, False)
        return
    # ---- training loop with early stopping ----
    best_valid_loss = 100
    patience = 20
    min_epochs = 500
    waiting = 0
    m_file = args.out_fpath + 'metrics.csv'
    best_path = args.out_fpath + 'best.pt'
    for e in range(epochs):
        # Reshuffle buckets differently each epoch, per rank.
        train_sortish_sampler.set_epoch(e)
        total_steps += epoch(encoder, decoder, True, current_step=total_steps)
        if rank == 0:
            nsteps = total_steps
            model_path = args.out_fpath + 'checkpoint%d.tar' % nsteps
            torch.save({
                'step': nsteps,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, model_path)
            with torch.no_grad():
                vloss = epoch(encoder, decoder, False, current_step=total_steps)
            with open(m_file, 'a') as f:
                f.write(str(nsteps))
                f.write(',')
                f.write(str(vloss))
                f.write('\n')
            if vloss < best_valid_loss:
                best_valid_loss = vloss
                waiting = 0
                best_epoch = e + 1
                best_path = model_path
                # save_me = {'step': nsteps, 'valid_loss': vloss, 'epoch': e + 1}
                # save_me['encoder_state_dict'] = encoder.state_dict()
                # save_me['decoder_state_dict'] = decoder.state_dict()
                # save_me['optimizer_state_dict'] = optimizer.state_dict()
                # torch.save(save_me, best_path)
            else:
                waiting += 1
                # NOTE(review): only rank 0 breaks here; other ranks keep
                # looping — confirm this does not deadlock the process group.
                if waiting >= patience and e > min_epochs:
                    break
    # ---- final test pass with the best checkpoint (rank 0 only) ----
    if rank == 0 and dl_test is not None:
        print('Loading checkpoint from epoch %d and testing...' % best_epoch)
        sd = torch.load(best_path)
        if not args.esm:
            encoder.load_state_dict(sd['encoder_state_dict'])
            encoder = encoder.eval()
        decoder.load_state_dict(sd['decoder_state_dict'])
        decoder = decoder.eval()
        with torch.no_grad():
            dl_valid = dl_test  # epoch() validates against dl_valid
            _ = epoch(encoder, decoder, False)
# Script entry point: parse args and spawn the per-GPU training processes.
if __name__ == '__main__':
    main()