Code for extracting word embeddings from RoBERTa
import torch


def rm_whitespace(s):
    # Strip the leading 'Ġ' marker, RoBERTa's encoding of a preceding space.
    if s.startswith('Ġ'):
        return s[1:]
    else:
        return s
def get_tokens_with_ranges(input_string, tokenizer):
    '''
    RoBERTa's tokenizer prepends 'Ġ' to every token that starts
    a new word, except the first token when the input does not
    begin with a whitespace. E.g.:

    ```
    In [30]: tokenizer.tokenize('I say geography')
    Out[30]: ['I', 'Ġsay', 'Ġgeography']

    In [29]: tokenizer.tokenize(' i say chronomoscopy')
    Out[29]: ['Ġi', 'Ġsay', 'Ġchron', 'om', 'osc', 'opy']
    ```

    This function returns a list of [start, end) ranges giving the
    positions of each word's tokens in the model input (offset by one
    because of the start-of-sentence token), together with the
    tokenized input as PyTorch tensors.
    '''
    assert input_string
    tokens = tokenizer.tokenize(input_string)
    ranges = [
        [0, 1]  # Start-of-sentence token
    ]
    tmp = []
    for i, token in enumerate(tokens):
        idx = i + 1  # Position 0 in the model input is <s>
        if not tmp:
            tmp.append(idx)
        elif token.startswith('Ġ'):
            # A new word begins: close the range of the previous one.
            ranges.append([tmp[0], tmp[-1] + 1])
            tmp = [idx]
        else:
            tmp.append(idx)
    ranges.append([tmp[0], tmp[-1] + 1])  # Close the final word's range
    ranges.append([len(tokens) + 1, len(tokens) + 2])  # End-of-sentence token
    return ranges, tokenizer(input_string, return_tensors='pt')
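
# A quick sanity check of the ranges (a sketch; tokenization as in the
# docstring above, i.e. a 'roberta-base'-style tokenizer, is assumed):
# for 'I say geography' the model input is
# ['<s>', 'I', 'Ġsay', 'Ġgeography', '</s>'], so get_tokens_with_ranges
# returns [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]] -- one [start, end)
# range per word, plus ranges for the <s> and </s> positions.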
def get_word_embeddings(input_string, tokenizer, model, level):
    input_tokens = tokenizer.tokenize(input_string)
    ranges, inputs = get_tokens_with_ranges(input_string, tokenizer)
    with torch.no_grad():
        if level == 'final':
            # No batches for now
            outputs = model(**inputs).last_hidden_state[0]
        else:
            # hidden_states[0] is the embedding-layer output, so level = 1
            # selects the embeddings and level = n the n-th encoder layer.
            outputs = model(
                **inputs, output_hidden_states=True).hidden_states[level - 1][0]
    # Ranges are in model-input coordinates, which include <s> at position 0,
    # so the hidden states are sliced with [start:end], while the token
    # strings (which have no <s>) are sliced with [start - 1:end - 1].
    return {
        rm_whitespace(''.join(input_tokens[start - 1:end - 1])):
            outputs[start:end, :].mean(0).numpy()
        for start, end in ranges[1:-1]
    }
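
A minimal usage sketch (the Hugging Face transformers package and the 'roberta-base' checkpoint are assumed; the variable names are illustrative):

```
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
model.eval()

# One mean-pooled vector per word, taken from the final layer
embeddings = get_word_embeddings('I say geography', tokenizer, model, 'final')
for word, vector in embeddings.items():
    print(word, vector.shape)  # (768,) for roberta-base
```

Each multi-token word is represented by the mean of its subword hidden states; a different pooling (e.g. taking the first subword's vector) can be swapped in by changing the .mean(0) call.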