UniXcoder in Python
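Three Python snippets: a small UniXcoderEmbeddings wrapper for computing sentence-level embeddings, the UniXcoder model class it builds on (Microsoft's MIT-licensed implementation), and an end-to-end example that scores a natural-language query against two code snippets.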
import torch
from torch.nn.functional import normalize

from .unixcoder import UniXcoder


class UniXcoderEmbeddings:
    def __init__(self, model_name="microsoft/unixcoder-base"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = UniXcoder(model_name).to(self.device)

    def get_embedding(self, text):
        # Tokenize in encoder-only mode and mean-pool the token embeddings
        # into one sentence embedding per input.
        tokens_ids = self.model.tokenize(text, max_length=512, mode="<encoder-only>")
        source_ids = torch.tensor(tokens_ids).to(self.device)
        _, embedding = self.model(source_ids)
        return embedding

    def similarity(self, nl_embedding, code_embedding):
        # Cosine similarity: L2-normalize both embeddings, then take the dot product.
        norm_code_embedding = normalize(code_embedding, p=2, dim=1)
        norm_nl_embedding = normalize(nl_embedding, p=2, dim=1)
        similarity = torch.einsum("ac,bc->ab", norm_nl_embedding, norm_code_embedding)
        # .item() assumes a single embedding on each side (a 1x1 similarity matrix).
        return similarity.item()
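
The similarity method above is cosine similarity written as an einsum over L2-normalized vectors. As a quick sanity check, here is a standalone sketch (not part of the gist) that compares that formulation against torch.nn.functional.cosine_similarity on random vectors; the 768-dimensional size is only an assumption matching the hidden size of unixcoder-base.

import torch
from torch.nn.functional import normalize, cosine_similarity

# Stand-in "embeddings"; 768 is assumed to match microsoft/unixcoder-base.
a = torch.randn(1, 768)
b = torch.randn(1, 768)

# The gist's formulation: dot product of L2-normalized vectors.
einsum_sim = torch.einsum("ac,bc->ab", normalize(a, p=2, dim=1), normalize(b, p=2, dim=1))

# PyTorch's built-in cosine similarity for comparison.
builtin_sim = cosine_similarity(a, b, dim=1)

print(einsum_sim.item(), builtin_sim.item())  # the two values should agree

The UniXcoder wrapper that the embedding class imports is Microsoft's MIT-licensed model class, reproduced next.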
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig


class UniXcoder(nn.Module):
    def __init__(self, model_name):
        """
        Build UniXcoder.

        Parameters:

        * `model_name`- huggingface model card name. e.g. microsoft/unixcoder-base
        """
        super(UniXcoder, self).__init__()
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.config = RobertaConfig.from_pretrained(model_name)
        self.config.is_decoder = True
        self.model = RobertaModel.from_pretrained(model_name, config=self.config)

        self.register_buffer("bias", torch.tril(torch.ones((1024, 1024), dtype=torch.uint8)).view(1, 1024, 1024))
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.lm_head.weight = self.model.embeddings.word_embeddings.weight
        self.lsm = nn.LogSoftmax(dim=-1)

        self.tokenizer.add_tokens(["<mask0>"], special_tokens=True)

    def tokenize(self, inputs, mode="<encoder-only>", max_length=512, padding=False):
        """
        Convert string to token ids

        Parameters:

        * `inputs`- list of input strings.
        * `max_length`- The maximum total source sequence length after tokenization.
        * `padding`- whether to pad source sequence length to max_length.
        * `mode`- which mode the sequence will use. i.e. <encoder-only>, <decoder-only>, <encoder-decoder>
        """
        assert mode in ["<encoder-only>", "<decoder-only>", "<encoder-decoder>"]
        assert max_length < 1024

        tokenizer = self.tokenizer

        tokens_ids = []
        for x in inputs:
            tokens = tokenizer.tokenize(x)
            if mode == "<encoder-only>":
                tokens = tokens[:max_length - 4]
                tokens = [tokenizer.cls_token, mode, tokenizer.sep_token] + tokens + [tokenizer.sep_token]
            elif mode == "<decoder-only>":
                tokens = tokens[-(max_length - 3):]
                tokens = [tokenizer.cls_token, mode, tokenizer.sep_token] + tokens
            else:
                tokens = tokens[:max_length - 5]
                tokens = [tokenizer.cls_token, mode, tokenizer.sep_token] + tokens + [tokenizer.sep_token]

            tokens_id = tokenizer.convert_tokens_to_ids(tokens)
            if padding:
                tokens_id = tokens_id + [self.config.pad_token_id] * (max_length - len(tokens_id))
            tokens_ids.append(tokens_id)
        return tokens_ids

    def decode(self, source_ids):
        """ Convert token ids to string """
        predictions = []
        for x in source_ids:
            prediction = []
            for y in x:
                t = y.cpu().numpy()
                t = list(t)
                if 0 in t:
                    t = t[:t.index(0)]
                text = self.tokenizer.decode(t, clean_up_tokenization_spaces=False)
                prediction.append(text)
            predictions.append(prediction)
        return predictions

    def forward(self, source_ids):
        """ Obtain token embeddings and sentence embeddings """
        mask = source_ids.ne(self.config.pad_token_id)
        token_embeddings = self.model(source_ids, attention_mask=mask.unsqueeze(1) * mask.unsqueeze(2))[0]
        sentence_embeddings = (token_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)
        return token_embeddings, sentence_embeddings

    def generate(self, source_ids, decoder_only=True, eos_id=None, beam_size=5, max_length=64):
        """ Generate sequence given context (source_ids) """

        # Set encoder attention mask: bidirectional for <encoder-decoder>, unidirectional for <decoder-only>
        if decoder_only:
            mask = self.bias[:, :source_ids.size(-1), :source_ids.size(-1)]
        else:
            mask = source_ids.ne(self.config.pad_token_id)
            mask = mask.unsqueeze(1) * mask.unsqueeze(2)

        if eos_id is None:
            eos_id = self.config.eos_token_id

        device = source_ids.device

        # Decoding using beam search
        preds = []
        zero = torch.LongTensor(1).fill_(0).to(device)
        source_len = list(source_ids.ne(1).sum(-1).cpu().numpy())
        length = source_ids.size(-1)
        encoder_output = self.model(source_ids, attention_mask=mask)
        for i in range(source_ids.shape[0]):
            context = [[x[i:i + 1, :, :source_len[i]].repeat(beam_size, 1, 1, 1) for x in y]
                       for y in encoder_output.past_key_values]
            beam = Beam(beam_size, eos_id, device)
            input_ids = beam.getCurrentState().clone()
            context_ids = source_ids[i:i + 1, :source_len[i]].repeat(beam_size, 1)
            out = encoder_output.last_hidden_state[i:i + 1, :source_len[i]].repeat(beam_size, 1, 1)
            for _ in range(max_length):
                if beam.done():
                    break
                if _ == 0:
                    hidden_states = out[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
                    input_ids = beam.getCurrentState().clone()
                else:
                    length = context_ids.size(-1) + input_ids.size(-1)
                    out = self.model(input_ids,
                                     attention_mask=self.bias[:, context_ids.size(-1):length, :length],
                                     past_key_values=context).last_hidden_state
                    hidden_states = out[:, -1, :]
                    out = self.lsm(self.lm_head(hidden_states)).data
                    beam.advance(out)
                    input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
                    input_ids = torch.cat((input_ids, beam.getCurrentState().clone()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:beam_size]
            pred = [torch.cat([x.view(-1) for x in p] + [zero] * (max_length - len(p))).view(1, -1) for p in pred]
            preds.append(torch.cat(pred, 0).unsqueeze(0))

        preds = torch.cat(preds, 0)

        return preds

class Beam(object):
    def __init__(self, size, eos, device):
        self.size = size
        self.device = device
        # The score for each translation on the beam.
        self.scores = torch.FloatTensor(size).zero_().to(device)
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step.
        self.nextYs = [torch.LongTensor(size).fill_(0).to(device)]
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # Time and k pair for finished.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        batch = self.nextYs[-1].view(-1, 1)
        return batch

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Given log-probs over words for every last beam `wordLk`,
        compute and update the beam search.

        Parameters:

        * `wordLk`- probs of advancing from the last step (K x words)

        Returns: True if beam search is complete.
        """
        numWords = wordLk.size(1)

        # Sum the previous scores.
        if len(self.prevKs) > 0:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId is flattened beam x word array, so calculate which
        # word and beam each score came from
        prevK = torch.div(bestScoresId, numWords, rounding_mode="floor")
        self.prevKs.append(prevK)
        self.nextYs.append((bestScoresId - prevK * numWords))

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                s = self.scores[i]
                self.finished.append((s, len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS and no global score.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        if len(self.finished) != self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    s = self.scores[i]
                    unfinished.append((s, len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size - len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
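
Beyond embeddings, the UniXcoder class above also exposes autoregressive generation through generate() and decode(). The following is a minimal sketch (not part of the gist) of how code completion could look with it; the context string and beam settings are placeholders, and the import assumes the model class above is saved as unixcoder.py.

import torch
from unixcoder import UniXcoder  # assumes the class above lives in unixcoder.py

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UniXcoder("microsoft/unixcoder-base").to(device)

# Partial function to complete (placeholder context).
context = "def f(data, file_path):\n    # write json data into file_path"
tokens_ids = model.tokenize([context], max_length=512, mode="<decoder-only>")
source_ids = torch.tensor(tokens_ids).to(device)

# Beam search over continuations; generate() returns token ids of shape (batch, beam_size, max_length).
prediction_ids = model.generate(source_ids, decoder_only=True, beam_size=3, max_length=64)
completions = model.decode(prediction_ids)
print(context + completions[0][0])  # top-ranked beam for the first (and only) input

The last snippet in the gist is an end-to-end example of zero-shot natural-language-to-code search with the embedding wrapper.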
from .unixcoder import UniXcoderEmbeddings

provider = UniXcoderEmbeddings()

# Encode maximum function
max_func = "def f(a,b): if a>b: return a else return b"
max_func_embedding = provider.get_embedding([max_func])

# Encode minimum function
min_func = "def f(a,b): if a<b: return a else return b"
min_func_embedding = provider.get_embedding([min_func])

# Encode natural language
nl = "return maximum value"
nl_embedding = provider.get_embedding([nl])

# Calculate cosine similarity between NL and two functions
max_func_nl_similarity = provider.similarity(nl_embedding, max_func_embedding)
min_func_nl_similarity = provider.similarity(nl_embedding, min_func_embedding)

print(max_func_nl_similarity)
print(min_func_nl_similarity)
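
For the query "return maximum value", the score against max_func is expected to be the higher of the two printed values. Because similarity() is plain cosine similarity, it works just as well for code-to-code pairs; a small follow-up sketch reusing the embeddings computed above:

# Code-to-code similarity: the two functions differ only in their comparison operator,
# so their embeddings should be close to each other.
code_code_similarity = provider.similarity(max_func_embedding, min_func_embedding)
print(code_code_similarity)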