import torch
import torch.nn.functional as F

# BATCH_SIZE and greeedy_decode_sentence are assumed to be defined elsewhere
# in the accompanying article/notebook.

def train(train_iter, val_iter, model, optim, num_epochs, use_gpu=True):
    train_losses = []
    valid_losses = []
    for epoch in range(num_epochs):
        train_loss = 0
        valid_loss = 0

        # Train model
        model.train()
        for i, batch in enumerate(train_iter):
            src = batch.src.cuda() if use_gpu else batch.src
            trg = batch.trg.cuda() if use_gpu else batch.trg

            # Change to shape (bs, max_seq_len)
            src = src.transpose(0, 1)
            # Change to shape (bs, max_seq_len+1), since right-shifted
            trg = trg.transpose(0, 1)
            trg_input = trg[:, :-1]
            targets = trg[:, 1:].contiguous().view(-1)

            # Padding masks: 0.0 at real tokens, -inf at pad positions.
            # They are built here but left unused in the model call below.
            src_mask = (src != 0)
            src_mask = src_mask.float().masked_fill(src_mask == 0, float('-inf')).masked_fill(src_mask == 1, float(0.0))
            src_mask = src_mask.cuda() if use_gpu else src_mask

            trg_mask = (trg_input != 0)
            trg_mask = trg_mask.float().masked_fill(trg_mask == 0, float('-inf')).masked_fill(trg_mask == 1, float(0.0))
            trg_mask = trg_mask.cuda() if use_gpu else trg_mask

            # Look-ahead mask: position i may only attend to positions <= i
            size = trg_input.size(1)
            np_mask = torch.triu(torch.ones(size, size) == 1).transpose(0, 1)
            np_mask = np_mask.float().masked_fill(np_mask == 0, float('-inf')).masked_fill(np_mask == 1, float(0.0))
            np_mask = np_mask.cuda() if use_gpu else np_mask

            # Forward, backprop, optimizer step
            optim.zero_grad()
            preds = model(src.transpose(0, 1), trg_input.transpose(0, 1), tgt_mask=np_mask)  # src_mask / tgt_key_padding_mask left disabled here
            preds = preds.transpose(0, 1).contiguous().view(-1, preds.size(-1))
            loss = F.cross_entropy(preds, targets, ignore_index=0, reduction='sum')
            loss.backward()
            optim.step()

            train_loss += loss.item() / BATCH_SIZE

        # Validate model
        model.eval()
        with torch.no_grad():
            for i, batch in enumerate(val_iter):
                src = batch.src.cuda() if use_gpu else batch.src
                trg = batch.trg.cuda() if use_gpu else batch.trg

                # Change to shape (bs, max_seq_len)
                src = src.transpose(0, 1)
                # Change to shape (bs, max_seq_len+1), since right-shifted
                trg = trg.transpose(0, 1)
                trg_input = trg[:, :-1]
                targets = trg[:, 1:].contiguous().view(-1)

                src_mask = (src != 0)
                src_mask = src_mask.float().masked_fill(src_mask == 0, float('-inf')).masked_fill(src_mask == 1, float(0.0))
                src_mask = src_mask.cuda() if use_gpu else src_mask

                trg_mask = (trg_input != 0)
                trg_mask = trg_mask.float().masked_fill(trg_mask == 0, float('-inf')).masked_fill(trg_mask == 1, float(0.0))
                trg_mask = trg_mask.cuda() if use_gpu else trg_mask

                size = trg_input.size(1)
                np_mask = torch.triu(torch.ones(size, size) == 1).transpose(0, 1)
                np_mask = np_mask.float().masked_fill(np_mask == 0, float('-inf')).masked_fill(np_mask == 1, float(0.0))
                np_mask = np_mask.cuda() if use_gpu else np_mask

                preds = model(src.transpose(0, 1), trg_input.transpose(0, 1), tgt_mask=np_mask)  # src_mask / tgt_key_padding_mask left disabled here
                preds = preds.transpose(0, 1).contiguous().view(-1, preds.size(-1))
                loss = F.cross_entropy(preds, targets, ignore_index=0, reduction='sum')
                valid_loss += loss.item()

        # Log after each epoch
        print(f'''Epoch [{epoch+1}/{num_epochs}] complete. Train Loss: {train_loss/len(train_iter):.3f}. Val Loss: {valid_loss/len(val_iter):.3f}''')

        # Save the best model so far
        if valid_loss / len(val_iter) < min(valid_losses, default=1e9):
            print("saving state dict")
            torch.save(model.state_dict(), "checkpoint_best_epoch.pt")

        train_losses.append(train_loss / len(train_iter))
        valid_losses.append(valid_loss / len(val_iter))

        # Check an example translation after each epoch
        sentences = ["This is an example to check how our model is performing."]
        for sentence in sentences:
            print(f"Original Sentence: {sentence}")
            print(f"Translated Sentence: {greeedy_decode_sentence(model, sentence)}")

    return train_losses, valid_losses
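
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). It assumes the batches
# yielded by train_iter / val_iter carry .src and .trg LongTensors of shape
# (seq_len, batch_size) with 0 as the pad index, as a torchtext BucketIterator
# would produce, and that the model embeds token ids itself before running
# nn.Transformer. ToyTransformer, ToyBatch, the vocabulary sizes and
# hyperparameters below are illustrative assumptions, not the author's setup.
# ---------------------------------------------------------------------------
import torch.nn as nn
from collections import namedtuple

class ToyTransformer(nn.Module):
    """Token ids of shape (seq_len, bs) -> logits of shape (seq_len, bs, tgt_vocab)."""
    def __init__(self, src_vocab, tgt_vocab, d_model=64):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab, d_model, padding_idx=0)
        self.tgt_emb = nn.Embedding(tgt_vocab, d_model, padding_idx=0)
        self.transformer = nn.Transformer(d_model=d_model, nhead=4,
                                          num_encoder_layers=2,
                                          num_decoder_layers=2,
                                          dim_feedforward=128)
        self.out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, tgt_mask=None):
        enc_in = self.src_emb(src)
        dec_in = self.tgt_emb(tgt)
        hidden = self.transformer(enc_in, dec_in, tgt_mask=tgt_mask)
        return self.out(hidden)

if __name__ == "__main__":
    BATCH_SIZE = 4  # normally defined alongside the data iterators

    ToyBatch = namedtuple("ToyBatch", ["src", "trg"])

    def toy_iter(n_batches=3, seq_len=8, vocab=100):
        # Random token ids standing in for real batches; trg is one step longer
        # because the decoder input is the right-shifted target.
        return [ToyBatch(torch.randint(1, vocab, (seq_len, BATCH_SIZE)),
                         torch.randint(1, vocab, (seq_len + 1, BATCH_SIZE)))
                for _ in range(n_batches)]

    def greeedy_decode_sentence(model, sentence):
        # Stand-in so the sketch runs standalone; the article defines the real greedy decoder.
        return "<decoding stub>"

    model = ToyTransformer(src_vocab=100, tgt_vocab=100)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    train(toy_iter(), toy_iter(), model, optimizer, num_epochs=1, use_gpu=False)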