Last active: January 18, 2023 18:37
Save yangkky/f8deb5a36b884a4068a7207fc51fb8cf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
import os | |
from datetime import datetime | |
import pathlib | |
import torch | |
import torch.multiprocessing as mp | |
from torch.optim import Adam | |
from torch.optim.lr_scheduler import LambdaLR | |
from apex.optimizers import FusedAdam | |
from torch.utils.data import DataLoader, RandomSampler | |
import torch.distributed as dist | |
import torch.nn.functional as F | |
from apex.parallel import DistributedDataParallel as DDP | |
from apex import amp | |
import numpy as np | |
from sequence_models.gnn import StructEncoderDecoder, cat_neighbors_nodes, BidirectionalStruct2SeqDecoder | |
from sequence_models.convolutional import ByteNetLM | |
from sequence_models.constants import PROTEIN_ALPHABET, PAD, MASK, START, STOP | |
from sequence_models.samplers import SortishSampler, ApproxBatchSampler | |
from sequence_models.datasets import UniRefDataset, TRRDataset | |
from sequence_models.collaters import StructureCollater, MLMCollater, SimpleCollater | |
from sequence_models.losses import MaskedCrossEntropyLoss | |
from sequence_models.metrics import MaskedAccuracy | |
from sequence_models.utils import transformer_lr, Tokenizer | |
# Expanded user home directory; used below to locate default data,
# cache, and checkpoint paths.
home = str(pathlib.Path.home())
def main():
    """Parse command-line arguments and launch one training process per GPU.

    Sets the rendezvous environment variables expected by
    ``torch.distributed`` (``init_method='env://'`` in ``train``) and spawns
    ``train`` once per GPU on this node via ``torch.multiprocessing.spawn``.
    """

    def _str2bool(value):
        # Accept common boolean spellings.  Without this, argparse stores the
        # raw string, and any non-empty string -- including 'False' -- is
        # truthy downstream.
        if isinstance(value, bool):
            return value
        return value.lower() in ('true', '1', 'yes', 'y', 't')

    parser = argparse.ArgumentParser()
    parser.add_argument('out_fpath', type=str, nargs='?', default=os.getenv('PT_OUTPUT_DIR', '/tmp') + '/')
    parser.add_argument('--task', default='mlm')  # 'mlm' (masked LM) or 'lm' (causal LM)
    parser.add_argument('-w', '--weights_fpath', required=False)  # pretrained CNN checkpoint to load
    parser.add_argument('-f', '--freeze', action='store_true')  # freeze the encoder during training
    parser.add_argument('--no_gnn', action='store_true')  # skip the structure decoder
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument('--esm', action='store_true')  # use a pretrained ESM-1b encoder
    parser.add_argument('--logsoftmax', action='store_true')
    parser.add_argument('--full', action='store_true')
    parser.add_argument('--test', action='store_true')  # evaluate on the test loader and exit
    parser.add_argument('--test2', default=None)  # alternate test split key in cath_splits.json
    parser.add_argument('--r', default=128, type=int)
    parser.add_argument('--k', default=5, type=int)
    parser.add_argument('--d_cnn', default=1280, type=int)
    parser.add_argument('--n_cnn', default=56, type=int)
    parser.add_argument('--activation', default='gelu')
    # BUGFIX: previously `--slim` had no type, so `--slim False` stored the
    # truthy string 'False' and silently enabled slim mode.  Parse it as a
    # boolean; the default (False) and the CLI shape are unchanged.
    parser.add_argument('--slim', default=False, type=_str2bool)
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('-off', '--offset', default=0, type=int,
                        help='Number of GPU devices to skip.')
    parser.add_argument('--dataset', default=None)
    args = parser.parse_args()
    if args.esm:
        # The pretrained ESM encoder is always kept frozen.
        args.freeze = True
    args.world_size = args.gpus * args.nodes
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8881'
    mp.spawn(train, nprocs=args.gpus, args=(args,))
def train(gpu, args):
    """Per-process entry point: build data, model, and optimizer, then train.

    Runs under ``mp.spawn`` with one process per GPU.  ``gpu`` is the local
    GPU index on this node; ``args`` is the parsed namespace from ``main``.
    Rank 0 additionally builds validation/test loaders, writes checkpoints
    and metrics, and drives early stopping.
    """
    _ = torch.manual_seed(0)
    # Global rank across all nodes.
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=rank)
    # args.offset lets the job skip the first N devices on the node.
    torch.cuda.set_device(gpu + args.offset)
    device = torch.device('cuda:' + str(gpu + args.offset))
    # ---- hyperparameters (CLI-driven and hard-coded) ----
    n_tokens = len(PROTEIN_ALPHABET)
    d_cnn = args.d_cnn
    n_cnn_layers = args.n_cnn
    kernel_size = args.k
    r = args.r
    slim = args.slim
    activation = args.activation
    node_features = 10
    edge_features = 11
    dropout = args.dropout
    use_mpnn = True
    n_structure_layers = 4
    n_connections = 30  # neighbors per residue for the structure collater
    d_embed = 8
    d_gnn = 256
    bucket_size = 1000
    max_tokens = 6000  # token budget per training batch
    max_batch_size = 100
    epochs = 1000
    lr = 1e-3
    opt_level = 'O2'  # apex mixed-precision level
    warmup_steps = 1000
    train_steps = 1e10  # NOTE(review): only used by the commented-out mid-epoch validation below
    max_len = 1024
    pad_idx = PROTEIN_ALPHABET.index(PAD)
    start_idx = PROTEIN_ALPHABET.index(START)
    stop_idx = PROTEIN_ALPHABET.index(STOP)
    if args.esm or args.task == 'lm':
        # Reserve positions for the start/stop (or cls/eos) tokens.
        max_len -= 2
    try:
        # os.getenv returns None when PT_DATA_DIR is unset, so the `+`
        # raises TypeError and we fall back to the local data directory.
        # NOTE(review): bare except also swallows unrelated errors;
        # `except TypeError:` would be safer.
        data_dir = os.getenv('PT_DATA_DIR') + '/'
        ptjob = True
    except:
        data_dir = home + '/data/'
        ptjob = False
    if args.dataset is not None:
        dataset = args.dataset
    elif ptjob:
        dataset = 'cath'
    else:
        dataset = "uniclust/cath"
    # ---- collater: MLM masking vs. plain padding, plus structure features ----
    if args.task == 'mlm':
        collater = MLMCollater(PROTEIN_ALPHABET)
    else:
        collater = SimpleCollater(PROTEIN_ALPHABET, pad=True, backwards=False)
    if dataset != 'esm':
        # Wrap to also emit node/edge/connection tensors for the GNN decoder.
        collater = StructureCollater(collater, n_connections=n_connections)
    # ---- training dataset ----
    if dataset == 'esm':
        data_dir = data_dir + 'esm/'
        with open(data_dir + 'splits.json') as f:
            splits = json.load(f)
        metadata = np.load(data_dir + 'lengths_and_offsets.npz')
        train_idx = splits['train']
        # Per-sequence lengths, clipped to max_len, for length-bucketed sampling.
        len_train = np.minimum(metadata['ells'][train_idx], max_len)
        ds_train = UniRefDataset(data_dir, 'train', structure=False, pdb=False,
                                 p_drop=0.0, max_len=max_len)
    elif dataset != 'trr':
        data_dir = data_dir + dataset + '/'
        with open(data_dir + 'splits.json') as f:
            splits = json.load(f)
        metadata = np.load(data_dir + 'lengths_and_offsets.npz')
        train_idx = splits['train']
        len_train = np.minimum(metadata['ells'][train_idx], max_len)
        ds_train = UniRefDataset(data_dir, 'train', structure=True, pdb=True,
                                 p_drop=0.0, max_len=max_len)
    else:
        ds_train = TRRDataset(data_dir + '/trrosetta/trrosetta/', 'train', bin=False,
                              return_msa=False, max_len=max_len, untokenize=True)
        len_train = np.load(data_dir + 'trrosetta/trrosetta/train_lengths.npz')['ells']
        len_train = np.minimum(len_train, max_len)
    # Length-sorted sharded sampler + token-budgeted batching.
    train_sortish_sampler = SortishSampler(len_train, bucket_size, num_replicas=args.world_size, rank=rank)
    train_sampler = ApproxBatchSampler(train_sortish_sampler, max_tokens, max_batch_size, len_train)
    dl_train = DataLoader(dataset=ds_train,
                          batch_sampler=train_sampler,
                          num_workers=8,
                          collate_fn=collater)
    # ---- validation/test loaders: rank 0 only (only rank 0 validates) ----
    if rank == 0:
        if dataset == 'esm':
            with open(data_dir + 'splits.json') as f:
                splits = json.load(f)
            metadata = np.load(data_dir + 'lengths_and_offsets.npz')
            # For 'esm' the 'valid' split doubles as the test set.
            test_idx = splits['valid']
            len_test = np.minimum(metadata['ells'][test_idx], max_len)
            ds_test = UniRefDataset(data_dir, 'valid', structure=False, pdb=False,
                                    p_drop=0.0, max_len=max_len)
            test_sortish_sampler = RandomSampler(len_test)
            # Evaluation batches can be larger: no gradients are stored.
            max_tokens = 4 * max_tokens
            test_sampler = ApproxBatchSampler(test_sortish_sampler, max_tokens, max_batch_size, len_test)
            dl_test = DataLoader(dataset=ds_test,
                                 batch_sampler=test_sampler,
                                 num_workers=8,
                                 collate_fn=collater)
            ds_valid = ds_test
            len_valid = len_test
        elif dataset != 'trr':
            with open(data_dir + 'splits.json') as f:
                splits = json.load(f)
            metadata = np.load(data_dir + 'lengths_and_offsets.npz')
            valid_idx = splits['valid']
            len_valid = np.minimum(metadata['ells'][valid_idx], max_len)
            ds_valid = UniRefDataset(data_dir, 'valid', structure=True, pdb=True,
                                     p_drop=0.0, max_len=max_len)
            test_idx = splits['test']
            len_test = np.minimum(metadata['ells'][test_idx], max_len)
            ds_test = UniRefDataset(data_dir, 'test', structure=True, pdb=True,
                                    p_drop=0.0, max_len=max_len)
            if args.test2 is not None:
                # Override the test indices with a named split from
                # cath_splits.json (key given by --test2).
                with open(home + '/workspace/data/esm/cath_splits.json') as f:
                    sp = json.load(f)
                ds_test.idx = sp[args.test2]
                len_test = np.minimum(metadata['ells'][ds_test.idx], max_len)
            test_sortish_sampler = SortishSampler(len_test, bucket_size)
            test_sampler = ApproxBatchSampler(test_sortish_sampler, max_tokens, max_batch_size, len_test)
            dl_test = DataLoader(dataset=ds_test,
                                 batch_sampler=test_sampler,
                                 num_workers=8,
                                 collate_fn=collater)
        else:
            # trRosetta has no separate test loader.
            ds_valid = TRRDataset(data_dir + '/trrosetta/trrosetta/', 'valid', bin=False, return_msa=False,
                                  max_len=max_len, untokenize=True)
            len_valid = np.load(data_dir + 'trrosetta/trrosetta/valid_lengths.npz')['ells']
            len_valid = np.minimum(len_valid, max_len)
            dl_test = None
        valid_sortish_sampler = SortishSampler(len_valid, bucket_size)
        valid_sampler = ApproxBatchSampler(valid_sortish_sampler, max_tokens, max_batch_size, len_valid)
        dl_valid = DataLoader(dataset=ds_valid,
                              batch_sampler=valid_sampler,
                              num_workers=8,
                              collate_fn=collater)
    # Initiate model
    # Decoder: bidirectional structure-conditioned GNN for MLM, or an
    # encoder/decoder stack for the causal LM task.
    if args.task == 'mlm':
        decoder = BidirectionalStruct2SeqDecoder(n_tokens, node_features, edge_features,
                                                 d_gnn, num_decoder_layers=n_structure_layers,
                                                 dropout=dropout, use_mpnn=use_mpnn,
                                                 one_hot_src=False).to(device)
    else:
        decoder = StructEncoderDecoder(n_tokens, node_features, edge_features, d_gnn, src_node=True,
                                       num_encoder_layers=n_structure_layers - 1, num_decoder_layers=1,
                                       use_mpnn=use_mpnn, one_hot_src=False, dropout=dropout).to(device)
    if args.esm:
        # Pretrained ESM-1b encoder; build a mapping from this project's
        # tokenization to ESM's alphabet indices.
        from esm.pretrained import load_model_and_alphabet
        encoder, alphabet = load_model_and_alphabet(home + "/.cache/torch/checkpoints/esm1b_t33_650M_UR50S.pt")
        tokenizer = Tokenizer(PROTEIN_ALPHABET)
        to_esm = {}
        for p in PROTEIN_ALPHABET:
            k = tokenizer.a_to_t[p]
            if p == PAD:
                to_esm[k] = alphabet.padding_idx
            elif p in alphabet.tok_to_idx:
                to_esm[k] = alphabet.tok_to_idx[p]
            elif p == MASK:
                to_esm[k] = alphabet.mask_idx
            elif p == START:
                to_esm[k] = alphabet.cls_idx
            elif p == STOP:
                to_esm[k] = alphabet.eos_idx
            else:
                to_esm[k] = alphabet.unk_idx
        # Permutation that reorders ESM logit dimensions into this
        # project's alphabet order.
        dim_reorder = torch.tensor([to_esm[k] for k in range(len(PROTEIN_ALPHABET))])
        encoder = encoder.to(device)
    else:
        # Train a ByteNet CNN encoder from scratch (causal for LM).
        if args.task == 'mlm':
            causal = False
        else:
            causal = True
        encoder = ByteNetLM(n_tokens, d_embed, d_cnn, n_cnn_layers, kernel_size, r, final_ln=True,
                            slim=slim, activation=activation, causal=causal, padding_idx=pad_idx).to(device)
        if args.weights_fpath is not None:
            # if ptjob:
            #     args.weights_fpath = os.getenv('PT_DATA_DIR') + '/' + args.weights_fpath
            print('Loading weights from ' + args.weights_fpath + '...')
            sd = torch.load(args.weights_fpath, map_location=device)
            # if args.no_gnn:
            # Strip the 'module.' prefix that DDP adds to parameter names.
            cnn_sd = sd['model_state_dict']
            cnn_sd = {k.split('module.')[1]: v for k, v in cnn_sd.items()}
            encoder.load_state_dict(cnn_sd)
            # else:
            #     gnn_sd = sd['decoder_state_dict']
            #     decoder.load_state_dict(gnn_sd)
            #     cnn_sd = sd['encoder_state_dict']
            #     cnn_sd = {k.split('module.')[1]: v for k, v in cnn_sd.items()}
            #     encoder.load_state_dict(cnn_sd)
    # ---- optimizer / mixed precision / DDP ----
    if args.esm:
        # Only the decoder trains; the frozen ESM encoder stays outside
        # amp and DDP.
        optimizer = FusedAdam(list(decoder.parameters()), lr=lr)
        decoder, optimizer = amp.initialize(decoder, optimizer, opt_level=opt_level)
        decoder = DDP(decoder)
    else:
        if args.freeze:
            optimizer = FusedAdam(list(decoder.parameters()), lr=lr)
            (encoder, decoder), optimizer = amp.initialize([encoder, decoder], optimizer, opt_level=opt_level)
        else:
            # Lower LR for the (pretrained) encoder than for the fresh decoder.
            optimizer = FusedAdam([{'params': encoder.parameters(), 'lr': 1e-4},
                                   {'params': decoder.parameters(), 'lr': lr}])
            (encoder, decoder), optimizer = amp.initialize([encoder, decoder], optimizer, opt_level=opt_level)
        decoder = DDP(decoder)
        encoder = DDP(encoder)
    scheduler = LambdaLR(optimizer, transformer_lr(warmup_steps))
    loss_func = MaskedCrossEntropyLoss()
    accu_func = MaskedAccuracy()

    def epoch(encoder, decoder, train, current_step=0):
        """Run one pass over the train (train=True) or valid loader.

        Returns the running validation loss when train is False, otherwise
        the number of optimizer steps taken.  Closes over dl_train,
        dl_valid, and the epoch counter ``e`` from the enclosing scope.
        """
        start_time = datetime.now()
        if train:
            if not args.freeze:
                encoder = encoder.train()
            if args.freeze:
                # NOTE(review): identical to the branch above; a frozen
                # encoder was presumably meant to stay in eval mode --
                # confirm intent.
                encoder = encoder.train()
            if decoder is not None:
                decoder = decoder.train()
            loader = dl_train
            t = 'Training:'
        else:
            encoder = encoder.eval()
            if decoder is not None:
                decoder = decoder.eval()
            loader = dl_valid
            t = 'Validating:'
        losses = []
        accus = []
        ns = []
        chunk_time = datetime.now()
        n_seen = 0
        if train:
            # Each rank sees roughly 1/world_size of the training set.
            n_total = len(ds_train) // args.world_size
        else:
            n_total = len(ds_valid)
        for i, batch in enumerate(loader):
            new_loss, new_accu, new_n = step(encoder, decoder, batch, train)
            # Token-weighted running averages.
            losses.append(new_loss * new_n)
            accus.append(new_accu * new_n)
            ns.append(new_n)
            n_seen += len(batch[0])
            total_n = sum(ns)
            rloss = sum(losses) / total_n
            raccu = sum(accus) / total_n
            if train:
                nsteps = current_step + i + 1
            else:
                nsteps = i
            # `e` is the epoch counter from the enclosing for-loop (or 0 in
            # --test mode).
            print('\r%s Epoch %d of %d Step %d Example %d of %d loss = %.4f accu = %.4f'
                  % (t, e + 1, epochs, nsteps, n_seen, n_total, rloss, raccu),
                  end='')
            if train:
                # Keep only a sliding window so the running averages track
                # recent batches (and the lists stay bounded).
                losses = losses[-999:]
                accus = accus[-999:]
                ns = ns[-999:]
            # if (nsteps) % train_steps == 0 and rank == 0:
            #     print('\nTraining complete in ' + str(datetime.now() - chunk_time))
            #     with torch.no_grad():
            #         _ = epoch(encoder, decoder, False, current_step=nsteps)
            #     chunk_time = datetime.now()
        if not train:
            print('\nValidation complete in ' + str(datetime.now() - start_time))
            return rloss
        elif rank == 0:
            print('\nEpoch complete in ' + str(datetime.now() - start_time))
        # NOTE(review): `i` (and `rloss` above) are unbound if the loader is
        # empty.
        return i

    def step(encoder, decoder, batch, train):
        """One forward (and, if train, backward/optimizer) step on a batch.

        Returns (loss, accuracy, n_masked_tokens) as Python scalars.
        """
        if args.dataset == 'esm':
            # Sequence-only batch from the MLM collater.
            src, tgt, mask = batch
            src = src.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)
        else:
            if args.task == 'lm':
                # Causal LM: target is the unmodified sequence; a START
                # token is prepended to the input so position t predicts
                # token t.
                src, nodes, edges, connections, edge_mask = batch
                tgt = src.detach().clone()
                mask = (src != PROTEIN_ALPHABET.index(PAD)).float()
                n, ell = src.shape
                starts = torch.zeros(n, 1) + start_idx
                starts = starts.long()
                src = torch.cat([starts, src], dim=-1)
            else:
                src, tgt, mask, nodes, edges, connections, edge_mask = batch
            src = src.to(device)
            tgt = tgt.to(device)
            mask = mask.to(device)
            nodes = nodes.cuda()
            edges = edges.cuda()
            connections = connections.cuda()
            edge_mask = edge_mask.cuda()
        input_mask = (src != pad_idx).float().unsqueeze(-1)
        if args.esm:
            # Re-tokenize into ESM's alphabet with cls/eos framing, run the
            # frozen encoder, then copy per-sequence logits back into a
            # padded tensor aligned with src.
            n, ell = src.shape
            esm_src = torch.zeros(n, ell + 2) + alphabet.padding_idx
            esm_src[:, 0] = alphabet.cls_idx
            tokenized = [[to_esm[s.item()] for s in sr if s != pad_idx] + [alphabet.eos_idx] for sr in src]
            ells = []
            for i, t in enumerate(tokenized):
                el = len(t)
                ells.append(el - 1)  # length without the eos token
                esm_src[i, 1:el + 1] = torch.tensor(t)
            esm_src = esm_src.to(device).long()
            with torch.no_grad():
                e = encoder(esm_src)['logits']
            # 33 = size of the ESM-1b output vocabulary.
            embeddings = torch.zeros(n, ell, 33)
            embeddings = embeddings.to(device)
            for i, (ee, el) in enumerate(zip(e, ells)):
                embeddings[i, :el, :] = ee[1:el + 1]  # drop the cls position
            # Reorder logit dims into this project's alphabet order.
            embeddings = embeddings[:, :, dim_reorder]
        else:
            if args.freeze:
                with torch.no_grad():
                    embeddings = encoder(src, input_mask=input_mask)
            else:
                embeddings = encoder(src, input_mask=input_mask)
        # slice out the extra position
        if args.task == 'lm':
            # Drop the final position added by the prepended START token.
            embeddings = embeddings[:, :-1, :]
        if (args.esm and args.test) or args.no_gnn:
            # Score the encoder outputs directly, without the GNN decoder.
            outputs = embeddings
        else:
            if args.logsoftmax:
                embeddings = F.log_softmax(embeddings, dim=-1)
            outputs = decoder(nodes, edges, connections, embeddings, edge_mask)
        loss = loss_func(outputs, tgt, mask)
        accu = accu_func(outputs, tgt, mask)
        if train:
            optimizer.zero_grad()
            # apex amp loss scaling for mixed precision.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            scheduler.step()
        return loss.item(), accu.item(), mask.sum().item()

    total_steps = 0
    n_parameters = sum(p.numel() for p in decoder.parameters())
    if rank == 0:
        print('%d model parameters' %n_parameters)
        print('%d training sequences' %len(len_train))
        print('%d validation sequences' %len(len_valid))
    if args.test:
        # Evaluation-only mode: run the test loader through `epoch` and exit.
        # NOTE(review): dl_test/dl_valid exist only on rank 0, so this path
        # presumably assumes a single-GPU run -- confirm.
        e = 0
        with torch.no_grad():
            dl_valid = dl_test
            _ = epoch(encoder, decoder, False)
        return
    # ---- training loop with checkpointing and early stopping (rank 0) ----
    best_valid_loss = 100
    patience = 20
    min_epochs = 500
    waiting = 0
    m_file = args.out_fpath + 'metrics.csv'
    best_path = args.out_fpath + 'best.pt'
    for e in range(epochs):
        # Reshuffle the length buckets deterministically per epoch.
        train_sortish_sampler.set_epoch(e)
        total_steps += epoch(encoder, decoder, True, current_step=total_steps)
        if rank == 0:
            nsteps = total_steps
            model_path = args.out_fpath + 'checkpoint%d.tar' % nsteps
            torch.save({
                'step': nsteps,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, model_path)
            with torch.no_grad():
                vloss = epoch(encoder, decoder, False, current_step=total_steps)
            with open(m_file, 'a') as f:
                f.write(str(nsteps))
                f.write(',')
                f.write(str(vloss))
                f.write('\n')
            if vloss < best_valid_loss:
                best_valid_loss = vloss
                waiting = 0
                best_epoch = e + 1
                best_path = model_path
                # save_me = {'step': nsteps, 'valid_loss': vloss, 'epoch': e + 1}
                # save_me['encoder_state_dict'] = encoder.state_dict()
                # save_me['decoder_state_dict'] = decoder.state_dict()
                # save_me['optimizer_state_dict'] = optimizer.state_dict()
                # torch.save(save_me, best_path)
            else:
                waiting += 1
                # NOTE(review): only rank 0 breaks here; other ranks keep
                # looping -- confirm this is acceptable for the job runner.
                if waiting >= patience and e > min_epochs:
                    break
    # ---- final test pass on the best checkpoint (rank 0 only) ----
    if rank == 0 and dl_test is not None:
        # NOTE(review): `best_epoch` is unbound if validation loss never
        # improved on best_valid_loss.
        print('Loading checkpoint from epoch %d and testing...' %best_epoch)
        sd = torch.load(best_path)
        if not args.esm:
            encoder.load_state_dict(sd['encoder_state_dict'])
            encoder = encoder.eval()
        decoder.load_state_dict(sd['decoder_state_dict'])
        decoder = decoder.eval()
        with torch.no_grad():
            dl_valid = dl_test
            _ = epoch(encoder, decoder, False)
# Script entry point: parse args and spawn one training process per GPU.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment