import torch
from torch import nn

from .transformer import BasicTransformerModel
from models import BaseModel
from .util.generation import autoregressive_generation_multimodal


class TransformerModel(BaseModel):
    def __init__(self, opt):
        super().__init__(opt)
        opt = self.opt
        input_mods = self.input_mods
        output_mods = self.output_mods
        dins = self.dins
        douts = self.douts
        input_lengths = self.input_lengths
        self.input_mod_nets = []
        self.output_mod_nets = []
        self.module_names = []
        # One shallow (2-layer) transformer per input modality, projecting that
        # modality's dins[i] features into a shared dhid-dimensional latent space.
        for i, mod in enumerate(input_mods):
            net = BasicTransformerModel(opt.dhid, dins[i], opt.nhead, opt.dhid, 2,
                                        opt.dropout, self.device, use_pos_emb=True,
                                        input_length=input_lengths[i],
                                        use_x_transformers=opt.use_x_transformers,
                                        opt=opt)
            name = "_input_" + mod
            setattr(self, "net" + name, net)
            self.input_mod_nets.append(net)
            self.module_names.append(name)
        # One deeper (nlayers) transformer per output modality, reading the
        # concatenated latents of all inputs and emitting douts[i] features.
        for i, mod in enumerate(output_mods):
            net = BasicTransformerModel(douts[i], opt.dhid, opt.nhead, opt.dhid,
                                        opt.nlayers, opt.dropout, self.device,
                                        use_pos_emb=opt.use_pos_emb_output,
                                        input_length=sum(input_lengths),
                                        use_x_transformers=opt.use_x_transformers,
                                        opt=opt)
            name = "_output_" + mod
            setattr(self, "net" + name, net)
            self.output_mod_nets.append(net)
            self.module_names.append(name)
        # The attention masks are currently always full (all-zero), so no
        # option flag guards their generation.
        self.generate_full_masks()
        self.inputs = []
        self.targets = []
        self.criterion = nn.MSELoss()
    def name(self):
        return "Transformer"

    @staticmethod
    def modify_commandline_options(parser, opt):
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--nlayers', type=int, default=6)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--use_pos_emb_output', action='store_true',
                            help="whether to use positional embeddings for the output modality transformers")
        parser.add_argument('--use_rotary_pos_emb', action='store_true',
                            help="whether to use rotary position embeddings")
        parser.add_argument('--use_x_transformers', action='store_true',
                            help="whether to use the x-transformers library implementation")
        return parser
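
    # Example invocation (hypothetical: this gist only shows the model class,
    # so the training script name and the --model flag are assumptions):
    #   python train.py --model transformer --dhid 512 --nlayers 6 --nhead 8 \
    #       --dropout 0.1 --use_pos_emb_output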

    def generate_full_masks(self):
        input_mods = self.input_mods
        output_mods = self.output_mods
        input_lengths = self.input_lengths
        self.src_masks = []
        for i, mod in enumerate(input_mods):
            mask = torch.zeros(input_lengths[i], input_lengths[i])
            self.register_buffer('src_mask_' + str(i), mask)
            self.src_masks.append(mask)
        self.output_masks = []
        for i, mod in enumerate(output_mods):
            mask = torch.zeros(sum(input_lengths), sum(input_lengths))
            self.register_buffer('out_mask_' + str(i), mask)
            self.output_masks.append(mask)
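
    # Note: PyTorch attention treats a float attn_mask additively (it is added
    # to the attention scores), so these all-zero masks leave attention fully
    # unrestricted, i.e. equivalent to passing no mask at all. See the sanity
    # check at the bottom of the file.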

    def forward(self, data):
        # In Lightning, forward defines the prediction/inference actions.
        latents = []
        for i, mod in enumerate(self.input_mods):
            latents.append(self.input_mod_nets[i].forward(data[i]))
        # Concatenate the per-modality latents along the time (sequence) axis.
        latent = torch.cat(latents)
        outputs = []
        for i, mod in enumerate(self.output_mods):
            # Each output net attends over the full concatenated sequence and
            # keeps only its first output_lengths[i] time steps.
            output = self.output_mod_nets[i].forward(latent)[:self.output_lengths[i]]
            outputs.append(output)
        return outputs
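
    # training_step repeats the encoder pass inline rather than calling
    # self.forward, and sums one MSE term per output modality against the
    # targets prepared by set_inputs.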
    def training_step(self, batch, batch_idx):
        self.set_inputs(batch)
        latents = []
        for i, mod in enumerate(self.input_mods):
            latents.append(self.input_mod_nets[i].forward(self.inputs[i]))
        latent = torch.cat(latents)
        loss_mse = 0
        for i, mod in enumerate(self.output_mods):
            output = self.output_mod_nets[i].forward(latent)[:self.output_lengths[i]]
            loss_mse += self.criterion(output, self.targets[i])
        # if self.opt.precision == 16:
        #     loss_mse *= 100  # loss scaling
        self.log('mse_loss', loss_mse)
        return loss_mse

    # Drafted but inactive optimizer overrides; optimization currently falls
    # through to the inherited behavior:
    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.opt.learning_rate)
    #     return [optimizer]
    #
    # def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
    #                    optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    #     optimizer.zero_grad()
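

# ---------------------------------------------------------------------------
# Usage sketch (hedged): the opt fields consumed by BaseModel (those behind
# input_mods, dins, douts, input_lengths, ...) are defined outside this gist,
# so only the call pattern is shown:
#   model = TransformerModel(opt)
#   outputs = model(inputs)  # one tensor per output modality
#
# Below is a minimal runnable sanity check for the zero-mask convention used
# by generate_full_masks. Because of the relative imports above, run it as a
# module, e.g. `python -m models.transformer` (the exact module path is an
# assumption). An all-zero additive attn_mask should be a no-op.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
    x = torch.randn(5, 2, 16)      # (seq_len, batch, embed_dim)
    full_mask = torch.zeros(5, 5)  # same all-zero convention as src_mask_i
    out_masked, _ = mha(x, x, x, attn_mask=full_mask)
    out_plain, _ = mha(x, x, x)
    assert torch.allclose(out_masked, out_plain, atol=1e-6)
    print("all-zero mask is equivalent to no mask")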