Created March 9, 2021 18:09
Quick CuBERT Huggingface Utils
{
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 512,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 49988
}
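For reference, this configuration can be loaded directly with Huggingface. A minimal sketch, assuming the JSON above has been saved locally (the file name cubert_config.json is a placeholder, and extra keys such as directionality and the pooler_* fields are simply carried along as attributes rather than used by BertModel):

from transformers import BertConfig, BertModel

config = BertConfig.from_json_file("cubert_config.json")  # placeholder path for the JSON above
model = BertModel(config)  # randomly initialized; the pretrained CuBERT checkpoint must be converted and loaded separately
print(config.hidden_size, config.num_hidden_layers)  # 1024, 24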
import collections
import os
from pathlib import Path
from typing import Callable, List

from tensor2tensor.data_generators import text_encoder
from transformers import BertTokenizer
from transformers.models.bert.tokenization_bert import load_vocab

# These come from the google-research/cubert code; the exact module paths below
# are an assumption about how that package is laid out on your PYTHONPATH.
from cubert.cubert_tokenizer import CuBertTokenizer
from cubert.python_tokenizer import PythonTokenizer


def combine_tokenizer_with_subword(
    initial_tokenizer: CuBertTokenizer,
    subword_tokenizer: text_encoder.SubwordTextEncoder
) -> Callable[[str], List[str]]:
    # Try to match the functionality at
    # https://github.com/google-research/google-research/blob/50c6cd94b5/cubert/code_to_subtokenized_sentences.py#L111-L118
    def tokenize(string: str) -> List[str]:
        toks = initial_tokenizer.tokenize(string)
        return flatten_list(
            subword_tokenizer.decode_list(
                subword_tokenizer.encode_without_tokenizing(token)
            )
            for token in toks
        )
    return tokenize


def flatten_list(t):
    return [item for sublist in t for item in sublist]


class CuBertHugTokenizer(BertTokenizer):
    # A hacky version that seems to work, at least for Python
    def __init__(
        self,
        vocab_file: Path
    ):
        super().__init__(
            vocab_file=vocab_file,
            do_lower_case=False,
            do_basic_tokenize=True,
            unk_token="[UNK]_",
            sep_token="[SEP]_",
            pad_token="<pad>_",
            cls_token="[CLS]_",
            mask_token="[MASK]_"
        )
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        # First pass: CuBERT's Python program tokenizer; second pass: subword
        # encoding over the same vocabulary file.
        self.first_tokenizer = PythonTokenizer(50_000)
        self.subword_tokenizer = text_encoder.SubwordTextEncoder(str(vocab_file))
        self._combined_func = combine_tokenizer_with_subword(
            self.first_tokenizer, self.subword_tokenizer)

    @property
    def do_lower_case(self):
        return False

    def _tokenize(self, text):
        return self._combined_func(text)

    def convert_tokens_to_string(self, tokens):
        raise NotImplementedError

    def _convert_token_to_id(self, token):
        return self.subword_tokenizer._subtoken_string_to_id[token]
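Usage would then look roughly like this (the vocabulary file name is only a placeholder for wherever the CuBERT subword vocab has been downloaded to):

tokenizer = CuBertHugTokenizer("cubert_python_vocab.txt")  # placeholder vocab path

snippet = "def add(a, b):\n    return a + b"
tokens = tokenizer.tokenize(snippet)
print(tokens)
print(tokenizer.convert_tokens_to_ids(tokens))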
Well, actually, what I wanted to test was the fill_mask pipeline. I've tried it with transformers.pipeline, but as you can see in the comment above, the PythonTokenizer does not support tokenizing the mask token... (Am I going in the right direction? I don't know...)
Do you know any way to check that the tokenizer works fine? I'm not sure this is a proper question, but I'm just grasping at straws.
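For what it's worth, this is roughly the kind of check I had in mind (the vocab path is a placeholder, and I'm not certain the token/id mappings agree in both directions):

tokenizer = CuBertHugTokenizer("cubert_python_vocab.txt")  # placeholder path

snippet = "x = foo(1, 2)"
tokens = tokenizer.tokenize(snippet)
ids = tokenizer.convert_tokens_to_ids(tokens)

print(tokens)
print(ids)
print(tokenizer.convert_ids_to_tokens(ids))  # should mirror `tokens` if the vocab and the subword encoder agree
print(tokenizer.mask_token, tokenizer.mask_token_id)  # the mask token should map to a single known id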
Hi @DNGros, thanks for sharing this excellent code!
I have several questions, though... could you help me?
Which load_vocab function should I use? Is it the one in transformers.models.bert.tokenization_bert, or the one in bert.tokenization? I have tried both versions of load_vocab, but I got the following result for all of them: (1) the tokenizer (of course the PythonTokenizer) doesn't seem to recognize the [MASK]_ token, and (2) the ___EOS___ token is weird too.
Did you run into any similar issues, and can you give me any advice?
Thank you!
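In case it helps to narrow this down, here is the sort of check I have been running with the transformers version of load_vocab (the vocab path is a placeholder for my local copy, and I'm not sure how the vocab file quotes its entries, so this is just a first look):

from transformers.models.bert.tokenization_bert import load_vocab

vocab = load_vocab("cubert_python_vocab.txt")  # placeholder for my local vocab file
print(len(vocab))
# See whether the special tokens the wrapper expects actually appear as keys.
for tok in ["[MASK]_", "[UNK]_", "[CLS]_", "[SEP]_", "<pad>_"]:
    print(tok, tok in vocab)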