suryavanshi/1a3f3fc49f7c6b95f96464120c49f105
Last active March 13, 2022 23:25
from transformers import pipeline
import random

# Masked-language model that proposes replacement words
unmasker = pipeline('fill-mask', model='bert-base-cased')

input_text = "I went to see a movie in the theater"
orig_text_list = input_text.split()
len_input = len(orig_text_list)

# Random index of the word to replace (randint is inclusive on both ends,
# so this picks 1..len_input-1 and never masks the first word)
rand_idx = random.randint(1, len_input - 1)
orig_word = orig_text_list[rand_idx]

# Replace the chosen word with BERT's mask token
new_text_list = orig_text_list.copy()
new_text_list[rand_idx] = '[MASK]'
new_mask_sent = ' '.join(new_text_list)
print("Masked sentence->", new_mask_sent)
# e.g. I went to [MASK] a movie in the theater

augmented_text_list = unmasker(new_mask_sent)

# Take the first (highest-scoring) prediction whose word is not the same as the original word
augmented_text = None
for res in augmented_text_list:
    if res['token_str'] != orig_word:
        augmented_text = res['sequence']
        break
print("Augmented text->", augmented_text)
# e.g. I went to watch a movie in the theater
Hi, thanks for pointing it out. While copying from my notebook to the gist, I forgot to include that line: 'orig_word' is initially the word at the random index, so it's orig_word = orig_text_list[rand_idx]