import sys
import os
import math

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir))
sys.path.append(ROOT_DIR)
sys.path.append(THIS_DIR)

import torch
import torch.nn as nn
import torch.nn.functional as F
import uuid
import numpy as np
from functools import partial
from torch.nn import TransformerEncoder, TransformerEncoderLayer, Transformer
#from models.x_transformers import ContinuousTransformerWrapper, Decoder, Encoder, AutoregressiveWrapper
from x_transformers import ContinuousTransformerWrapper, Decoder, Encoder, AutoregressiveWrapper
def pick_and_pop(keys, d):
    """Pop `keys` from dict `d` and return them as a new dict."""
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

def groupby_prefix_and_trim(prefix, d):
    """Split `d` into (entries whose keys start with `prefix`, with the prefix stripped; the rest)."""
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs

def group_dict_by_key(cond, d):
    """Split `d` into (entries whose keys satisfy `cond`, entries that do not)."""
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def string_begins_with(prefix, s):
    return s.startswith(prefix)
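# Example of how these helpers are used (illustrative dict; EncDecXTransformer below
# routes 'enc_*' / 'dec_*' keyword arguments to its encoder and decoder with them):
#
#   >>> groupby_prefix_and_trim('enc_', {'enc_dim': 512, 'enc_depth': 6, 'dec_dim': 256})
#   ({'dim': 512, 'depth': 6}, {'dec_dim': 256})
#   >>> pick_and_pop(['max_seq_len'], {'max_seq_len': 1024, 'dim': 512})   # also pops the key from the dict
#   {'max_seq_len': 1024}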
class PositionalEncoding(nn.Module):
    """Learned positional encoding: an embedding table indexed by (1-based) position, added to the input."""

    def __init__(self, d_model, dropout=0.1, max_len=5000, device=None):
        super(PositionalEncoding, self).__init__()
        self.device = device
        self.dropout = nn.Dropout(p=dropout)
        self.lpe = nn.Embedding(max_len + 1, d_model)  # index 0 is unused; positions run from 1 to max_len
        # self.weight = None
        self.indices = torch.arange(max_len).unsqueeze(1) + 1
        if device is not None:
            self.indices = self.indices.to(self.device)

    def init_weights(self):
        initrange = 0.1
        self.lpe.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, indices=None):
        # debug dump of the positional table (disabled: it writes a new file on every forward pass)
        # np.save(str(uuid.uuid4())+".np", self.lpe.weight.data.cpu().numpy())
        if indices is None:
            indices = self.indices[:x.size(0), :]
        # indices = self.dropout(indices)  # dropping out the integer indices fails on LongTensors; dropout is applied to the output instead
        x = x + self.lpe(indices)
        return self.dropout(x)
class LearnedPositionalEncoding(nn.Module):  # note: despite the name, this is the fixed sinusoidal encoding, not a learned one
    def __init__(self, d_model, dropout=0.1, max_len=5000, device=None):
        super(LearnedPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # standard sinusoidal table: pe[pos, 2i] = sin(pos / 10000^(2i/d_model)),
        #                            pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term1 = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        div_term2 = torch.exp(torch.arange(0, (d_model // 2) * 2, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term1)
        pe[:, 1::2] = torch.cos(position * div_term2)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # print(x.shape)
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
class BasicTransformerModel(nn.Module):
    def __init__(self, dout, dinp, nhead, dhid, nlayers, dropout=0.5, device=None, use_pos_emb=False, input_length=0, use_x_transformers=False, opt=None):
        super(BasicTransformerModel, self).__init__()
        self.device = device
        self.model_type = 'Transformer'
        self.use_x_transformers = use_x_transformers
        if not use_x_transformers:
            self.encoder1 = nn.Linear(dinp, dhid)
            #self.pos_encoder = PositionalEncoding(dhid, dropout, device=self.device)
            encoder_layers = TransformerEncoderLayer(dhid, nhead, dhid, dropout)
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
            # self.encoder = nn.Embedding(ntoken, dinp)
            self.dinp = dinp
            self.dhid = dhid
            self.decoder = nn.Linear(dhid, dout)
            self.use_pos_emb = use_pos_emb
            if use_pos_emb:
                assert input_length > 0
                # learnable additive attention bias over (query position, key position),
                # used in place of a positional embedding
                self.pos_emb = nn.Parameter((torch.zeros(input_length, input_length)))
                # self.pos_emb = nn.Parameter((torch.eye(input_length, input_length)))
                # self.pos_emb = nn.Parameter((torch.randn(input_length, input_length))/np.sqrt(dinp))
            self.init_weights()
            #self.pos_encoder.init_weights()
        else:
            self.model = ContinuousTransformerWrapper(
                dim_in = dinp,
                dim_out = dout,
                max_seq_len = 1024,
                use_pos_emb = use_pos_emb,
                attn_layers = Encoder(
                    dim = dhid,
                    depth = nlayers,
                    heads = nhead,
                    rotary_pos_emb = opt.use_rotary_pos_emb,
                    #rel_pos_bias = True
                )
            )

    def generate_square_subsequent_mask(self, sz, prefix_length=1):
        # causal mask, except that every position may also attend to the first `prefix_length` positions
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask[:, :prefix_length] = 1
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
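    # Example: for sz=4 and prefix_length=2, every position can attend to the two
    # prefix positions as well as to its causal past:
    #
    #   model.generate_square_subsequent_mask(4, prefix_length=2)
    #   tensor([[0., 0., -inf, -inf],
    #           [0., 0., -inf, -inf],
    #           [0., 0., 0., -inf],
    #           [0., 0., 0., 0.]])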
    def init_weights(self):
        initrange = 0.1
        # self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask=None):
        if not self.use_x_transformers:
            # import pdb;pdb.set_trace()
            src = self.encoder1(src)
            #src *= math.sqrt(self.dhid)
            #src = self.pos_encoder(src)
            #src /= math.sqrt(self.dhid)
            # print(src)
            # print(torch.mm(src[:,0,:],src[:,0,:].T))
            if self.use_pos_emb:
                #print(self.pos_emb)
                #src_mask += self.pos_emb
                if src_mask is not None:
                    output = self.transformer_encoder(src, src_mask + self.pos_emb)
                else:
                    output = self.transformer_encoder(src, self.pos_emb)
                #output = self.transformer_encoder(src, self.pos_emb)
            else:
                if src_mask is not None:
                    output = self.transformer_encoder(src, src_mask)
                else:
                    output = self.transformer_encoder(src)
                #output = self.transformer_encoder(src)
            output = self.decoder(output)
            return output
        else:
            assert src_mask is None
            src = src.permute(1, 0, 2)
            mask = torch.ones(src.shape[0], src.shape[1], device=src.device).bool()
            output = self.model(src, mask=mask)
            # output = self.model(src.permute(1,0,2))
            return output.permute(1, 0, 2)
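# Usage sketch for BasicTransformerModel (illustrative sizes; inputs follow the
# (seq_len, batch, features) layout that nn.TransformerEncoder expects by default):
#
#   model = BasicTransformerModel(dout=4, dinp=8, nhead=2, dhid=16, nlayers=2,
#                                 dropout=0.1, use_pos_emb=True, input_length=10)
#   src = torch.randn(10, 3, 8)   # (input_length, batch, dinp)
#   out = model(src)              # -> (10, 3, 4), i.e. (input_length, batch, dout)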
class EncDecTransformerModel(nn.Module):
    def __init__(self, dout, src_d, tgt_d, nhead, dhid, nlayers, dropout=0.5, device=None, use_pos_emb=False, src_length=0, tgt_length=0, use_x_transformers=False, opt=None):
        super(EncDecTransformerModel, self).__init__()
        self.device = device
        self.model_type = 'Transformer'
        self.use_x_transformers = use_x_transformers
        self.encoder1 = nn.Linear(src_d, dhid)
        self.encoder2 = nn.Linear(tgt_d, dhid)
        if not use_x_transformers:
            self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers, dropout=0, activation="relu")
            #enc_layer = TransformerEncoderLayer(d_model=dhid, nhead=nhead, dropout=0, activation="relu")
            #self.transformerEnc = TransformerEncoder(enc_layer, nlayers)
        else:
            self.transformer = EncDecXTransformer(enc_dim_in=src_d, enc_dim_out=tgt_d, dec_dim_in=tgt_d, dec_dim_out=dout, enc_dim=dhid, dec_dim=dhid, enc_heads=nhead, dec_heads=nhead, enc_depth=nlayers, dec_depth=nlayers, enc_dropout=dropout, dec_dropout=dropout, enc_max_seq_len=1024, dec_max_seq_len=1024)
            #xdecoder = Decoder(dim=dhid, depth=nlayers, heads=nhead, cross_attend=True)
            #self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers, dropout=dropout, activation="gelu", custom_decoder=xdecoder)
            #self.transformer = Transformer(d_model=dhid, nhead=nhead, num_encoder_layers=nlayers, num_decoder_layers=nlayers)
        # self.encoder = nn.Embedding(ntoken, dinp)
        self.src_d = src_d
        self.tgt_d = tgt_d
        self.dhid = dhid
        self.decoder = nn.Linear(dhid, dout)
        self.use_pos_emb = use_pos_emb
        if use_pos_emb:
            assert src_length > 0
            assert tgt_length > 0
            # learnable additive attention biases over (query position, key position)
            self.src_pos_emb = nn.Parameter((torch.zeros(src_length, src_length)))
            self.tgt_pos_emb = nn.Parameter((torch.zeros(tgt_length, tgt_length)))
        # the causal target mask is pre-built from tgt_length, so tgt_length must cover
        # the longest target sequence passed to forward()
        if not use_x_transformers:
            tgt_mask = self.generate_square_subsequent_mask(tgt_length)
        else:
            tgt_mask = self.generate_square_subsequent_mask_bool(tgt_length)
        self.register_buffer("tgt_mask", tgt_mask)
        #a = torch.randn(32,3,512)
        #b = torch.randn(32,3,512)
        #self.register_buffer('a', a)
        #self.register_buffer('b', b)
        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def generate_square_subsequent_mask_bool(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1).bool()
        return mask
    def init_weights(self):
        initrange = 0.1
        # self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, tgt):
        if not self.use_x_transformers:
            # import pdb;pdb.set_trace()
            src = self.encoder1(src)
            tgt = self.encoder2(tgt)
            tgt_mask = self.tgt_mask[:tgt.shape[0], :tgt.shape[0]]
            if self.use_pos_emb:
                tgt_pos_emb = self.tgt_pos_emb[:tgt.shape[0], :tgt.shape[0]]
                # import pdb;pdb.set_trace()
                output = self.transformer(src=src, tgt=tgt, src_mask=self.src_pos_emb, tgt_mask=tgt_pos_emb + tgt_mask)
                #output = self.transformer(src=src, tgt=tgt)
            else:
                output = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
            #output = self.transformer(src=self.a, tgt=self.b)
            #output = self.transformerEnc(src=self.a)
            output = self.decoder(output)
            return output
        else:
            src = self.encoder1(src)
            tgt = self.encoder2(tgt)
            #tgt_mask = self.tgt_mask[:tgt.shape[0], :tgt.shape[0]]
            # if self.use_pos_emb:
            #     tgt_pos_emb = self.tgt_pos_emb[:tgt.shape[0], :tgt.shape[0]]
            #     # import pdb;pdb.set_trace()
            #     output = self.transformer(src=src.permute(1,2,0), tgt=tgt.permute(1,2,0), src_mask=self.src_pos_emb, tgt_mask=tgt_pos_emb+tgt_mask)
            #     #output = self.transformer(src=src, tgt=tgt)
            # else:
            #     output = self.transformer(src=src.permute(1,0,2), tgt=tgt.permute(1,0,2), tgt_mask=tgt_mask)
            # output = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
            output = self.transformer(src=src, tgt=tgt)
            #output = self.transformer(src=self.a, tgt=self.b)
            #output = self.transformer(src=tgt, tgt=tgt) #hmm, that's an interesting way of doing residual attention
            output = self.decoder(output)
            return output
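# Usage sketch for EncDecTransformerModel (illustrative sizes; same (seq_len, batch, features)
# layout as nn.Transformer):
#
#   model = EncDecTransformerModel(dout=4, src_d=6, tgt_d=4, nhead=2, dhid=16, nlayers=2,
#                                  dropout=0.1, use_pos_emb=True, src_length=12, tgt_length=8)
#   src = torch.randn(12, 3, 6)   # (src_length, batch, src_d)
#   tgt = torch.randn(8, 3, 4)    # (tgt_length, batch, tgt_d)
#   out = model(src, tgt)         # -> (8, 3, 4), i.e. (tgt_length, batch, dout)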
class EncDecXTransformer(nn.Module):
    def __init__(
        self,
        *,
        # dim,
        tie_token_emb = False,
        **kwargs
    ):
        super().__init__()
        enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
        dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
        # import pdb;pdb.set_trace()
        # assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
        # enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
        enc_transformer_kwargs = pick_and_pop(['max_seq_len'], enc_kwargs)
        # enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
        # dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
        dec_transformer_kwargs = pick_and_pop(['max_seq_len'], dec_kwargs)
        self.encoder = ContinuousTransformerWrapper(
            **enc_transformer_kwargs,
            attn_layers = Encoder(**enc_kwargs)
        )
        self.decoder = ContinuousTransformerWrapper(
            **dec_transformer_kwargs,
            attn_layers = Decoder(cross_attend = True, **dec_kwargs)
        )
        if tie_token_emb:
            self.decoder.token_emb = self.encoder.token_emb
        # self.decoder = AutoregressiveWrapper(self.decoder)

    @torch.no_grad()
    def generate(self, seq_in, seq_out_start, seq_len, src_mask = None):
        encodings = self.encoder(seq_in, return_embeddings = True, mask = src_mask)
        return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = src_mask)

    def forward(self, src, tgt, src_mask = None, tgt_mask = None):
        enc = self.encoder(src, mask = src_mask, return_embeddings = True)
        #out = self.decoder(tgt, context = enc, mask = tgt_mask, context_mask = src_mask)
        out = self.decoder(tgt, context = enc)
        return out