# Gist by @BMPixel, created April 24, 2024
from collections import Counter
import pickle
from transformers import PreTrainedTokenizerFast
import json
# Load the base tokenizer from the model
base_tokenizer = PreTrainedTokenizerFast.from_pretrained("/cephfs/panwenbo/work/models/Meta-Llama-3-8B")
# Collect all non-special tokens from the Llama 3 8B tokenizer (its special tokens all start with '<|')
base_toks = list(base_tokenizer.get_vocab().keys())
base_toks = [t for t in base_toks if not t.startswith('<|')]
# Read all added Chinese tokens (e.g. 你好) from the zh tokenizer
ext_tokenizer = pickle.load(open('/cephfs/yueli/serialized_object.pkl', 'rb'))
ext_words_raw = list(ext_tokenizer.get_added_vocab().keys())
special_toks = [t for t in ext_words_raw if t.startswith('<')]
ext_words = [t for t in ext_words_raw if not t.startswith('<')]
special_tok2id = {t: ext_tokenizer.encode(t)[0] for t in special_toks}
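# (Added note, not in the original gist.) encode(t)[0] above assumes ext_tokenizer does not
# prepend a BOS or other special token when encoding; if it does, convert_tokens_to_ids is
# the safer lookup:
# special_tok2id = {t: ext_tokenizer.convert_tokens_to_ids(t) for t in special_toks}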
# Some words appear in the zh tokenizer's added vocab but also exist in the base tokenizer vocab; their ids should follow the base tokenizer.
base_and_ext = set(base_toks) & set(ext_words)
print(f"Base and ext words: {base_and_ext}")
ext_words = [w for w in ext_words if w not in base_and_ext]
# Verify that the three sets are pairwise disjoint and together add up to the full ext vocab
assert len(set(base_toks) & set(ext_words)) == 0
assert len(set(base_toks) & set(special_toks)) == 0
assert len(set(ext_words) & set(special_toks)) == 0
all_union = set(base_toks) | set(ext_words) | set(special_toks)
all_ext_words = set(ext_tokenizer.get_vocab().keys())
assert all_ext_words == all_union
# Transform ext_words into ext tokens (e.g. 你 -> ä½ł). tokenizer.json only accepts byte-level tokens, not raw words; the token strings can be acquired by training a new tokenizer from the words.
expanded_tokenizer = base_tokenizer.train_new_from_iterator(ext_words, length=len(ext_words), vocab_size=len(ext_words) * 1000)
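# (Added note, not in the original gist.) vocab_size is presumably set deliberately high so
# that every word in ext_words survives as a single token after training; the exact outcome
# depends on the trainer, so the loop below also handles words that still split into several tokens.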
ext_tok2id = {}
ext_word2tok = {}
for w in ext_words:
    id_ext = ext_tokenizer.encode(w)[0]
    ids = expanded_tokenizer.encode(w)
    if len(ids) != 1:
        print(f"Word {w} encodes to multiple ids: {ids}")
        tok = "".join(expanded_tokenizer.convert_ids_to_tokens(ids))
    else:
        tok = expanded_tokenizer.convert_ids_to_tokens(ids)[0]
    ext_tok2id[tok] = id_ext
    ext_word2tok[w] = tok
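# (Added sketch, not part of the original pipeline.) The word -> byte-level token mapping built
# above can be cross-checked against the GPT-2-style byte-to-unicode table that byte-level BPE
# tokenizers use internally. bytes_to_unicode is an existing helper in transformers; treating it
# as equivalent to how expanded_tokenizer represents tokens is an assumption, so this only prints
# a mismatch count rather than asserting.
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
byte_encoder = bytes_to_unicode()
def word_to_byte_level(word):
    # Map each UTF-8 byte of the word to its printable byte-level surrogate character.
    return "".join(byte_encoder[b] for b in word.encode("utf-8"))
mismatches = [w for w, t in ext_word2tok.items() if word_to_byte_level(w) != t]
print("Words whose byte-level form differs from the trained token:", len(mismatches))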
ext_toks = list(ext_tok2id.keys())
all_toks_set = set(ext_toks) | set(base_toks)
print("DISPLAYING SOME TOKENS")
print("Ext words", len(ext_words))
print(ext_words[-20:])
print("Ext toks", len(ext_toks))
print(ext_toks[-20:])
print("Base toks", len(base_toks))
print(base_toks[-20:])
print("Special toks", len(special_toks))
print(special_toks[-20:])
# Create merges manually, as merges produced by training are not reliable
# First we count the frequency of each token
tokcount = Counter(ext_toks)
for tok in ext_toks:
    for sp in range(1, len(tok)):
        if tok[:sp] in all_toks_set:
            tokcount[tok[:sp]] += 1
        if tok[sp:] in all_toks_set:
            tokcount[tok[sp:]] += 1
# Then we create merges by splitting each token into two parts and checking that both parts are in the vocab
# Neither ordering exactly matches the behaviour of ext_tokenizer
ext_toks = sorted(ext_toks, key=lambda x: (-tokcount[x], x))  # sort tokens by frequency, then lexicographically
# ext_toks = sorted(ext_toks, key=lambda x: ext_tok2id[x])  # sort tokens by their id
merges = []
for tok in ext_toks:
    for sp in range(1, len(tok)):
        le = tok[:sp]
        ri = tok[sp:]
        if le in all_toks_set and ri in all_toks_set:
            merges.append(f"{le} {ri}")
print("Merges size", len(merges))
# Create an expanded tokenizer by injecting the new tokens and merges into the base tokenizer
ext_tokenizer.save_pretrained('tar_tokenizer')
tokenizer_obj = json.load(open('tar_tokenizer/tokenizer.json'))
tokenizer_cfg = json.load(open('tar_tokenizer/tokenizer_config.json'))
added_tokens_decoder = {
    str(i): {
        'content': t,
        'single_word': False,
        'lstrip': False,
        'rstrip': False,
        'special': True,
        'normalized': False
    } for t, i in special_tok2id.items()
}
tokenizer_cfg['added_tokens_decoder'] = added_tokens_decoder
added_tokens = [
    {
        "id": int(i),
        **token
    } for i, token in added_tokens_decoder.items()
]
tokenizer_obj['added_tokens'] = added_tokens
original_merge_set = set(tokenizer_obj['model']['merges'])
filtered_merges = [m for m in merges if m not in original_merge_set]
updated_merges = filtered_merges + tokenizer_obj['model']['merges']
tokenizer_obj['model']['merges'] = updated_merges
updated_vocab = {**tokenizer_obj['model']['vocab'], **ext_tok2id, **special_tok2id}
tokenizer_obj['model']['vocab'] = updated_vocab
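# (Added sanity check, not in the original gist.) The merged vocab must not map two different
# token strings to the same id, otherwise decoding becomes ambiguous.
merged_ids = list(updated_vocab.values())
assert len(merged_ids) == len(set(merged_ids)), "duplicate ids in merged vocab"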
with open('tar_tokenizer/tokenizer.json', 'w') as f:
    json.dump(tokenizer_obj, f, indent=2, ensure_ascii=False)
with open('tar_tokenizer/tokenizer_config.json', 'w') as f:
    json.dump(tokenizer_cfg, f, indent=2, ensure_ascii=False)
reloaded_tokenizer = PreTrainedTokenizerFast.from_pretrained('tar_tokenizer')
# Some simple round-trip tests; the expanded tokenizer behaves slightly differently from the expected (ext) tokenizer when tokenizing Chinese
tests = """你好中国!
你好
在这个快速变化的世界中,保持竞争力并寻找新的机会来发展和成长
we must constantly adapt and learn to stay ahead
I am a student, and I am learning to code
《哈利波特》是一部很好看的电影
<s> 你怎么了 </s>
是两种昆虫
<|begin_of_text|> 你好 <|end_of_text|>
发展和成长
2019冠状病毒病疫情[11][注 4]是由严重急性呼吸系统综合征冠状病毒2(SARS-CoV-2)导致的2019冠状病毒病(COVID-19)所引发的全球大流行疫情[12]。该疾病在2019年末于中华人民共和国湖北省武汉市首次报告,随后在2020年初全球多国报告发现病例,逐渐变成一场全球性大瘟疫[13]。截至2024年4月22日,全球已累计报告逾775,293,616[9]例确诊病例,其中逾7,044,637[9]人死亡[14],病死率约为2.09%[15],是人类历史上大规模流行病之一。世界各国对该病病死率的估计值差异甚大,多数国家该病的观测病死率在0.5%-5.0%之间[16][注 5]。
目前研究表明,SARS-CoV-2最早可能于2019年10月至11月进入人类社会生活并开始传播[18][19][20],而目前明确已知的首宗感染个案于2019年12月1日在武汉市发病[21][注 3]。首位前往医院就诊的患者可能出现于12月12日[24]。12月26日,武汉市呼吸与重症医学科医生张继先最早发现和上报此不明原因肺炎,并怀疑该病属传染病[25][26][27]。2020年1月13日起,疫情陆续蔓延到泰国、日本及韩国等相邻国家[28][29][30],至1月21日则波及到亚洲以外的美国西雅图[31]。1月23日,武汉市新冠肺炎疫情防控指挥部宣布采取疫区封锁隔离措施[32][33],这是近代公共卫生史上第一次对千万人口规模的大城市采取封锁措施[34]。在1月30日,中国境外有3个国家证实出现社区传播,而世界卫生组织亦于当日宣布疫情为“国际关注的突发公共卫生事件”。2月中旬,中国大陆的疫情达到发展高峰,而2月底意大利、韩国与伊朗三国的确诊人数急速增加。2月29日,世卫组织将疫情的全球风险级别提升为“非常高”[35]。3月11日,欧洲与中东各国都出现了大量病例,世卫组织宣布此次疫情已构成“全球大流行”[36][37][38]。此后欧洲[39]、南美洲[40]先后被宣布为本次大流行的中心。10月5日,世卫组织表示,根据“最确切推算”,全球约10%的人口可能已感染病毒[41][42]。截至2021年5月21日,根据世界卫生组织的估计,真正的死亡人数可能高达官方报告的2-3倍,约至少600万-800万人[43]。
"""
for test in tests.split('\n'):
    print(test)
    ext_encoded = ext_tokenizer.encode(test)
    ext_decoded = ext_tokenizer.decode(ext_encoded)
    reloaded_encoded = reloaded_tokenizer.encode(test)
    reloaded_decoded = reloaded_tokenizer.decode(reloaded_encoded)
    ext_decoded_from_reloaded = ext_tokenizer.decode(reloaded_encoded)
    reloaded_decoded_from_ext = reloaded_tokenizer.decode(ext_encoded)
    assert ext_decoded == reloaded_decoded == test == ext_decoded_from_reloaded == reloaded_decoded_from_ext
    print(f"Ext: {ext_encoded}")
    print(f"Rel: {reloaded_encoded}")
    print('-' * 100)
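# (Added, optional follow-up; not in the original gist.) A quick quantitative summary of how
# closely the rebuilt tokenizer tracks ext_tokenizer on the same test lines: identical id
# sequences mean the two tokenizers agree exactly, not just on the decoded text.
lines = tests.split('\n')
identical = sum(ext_tokenizer.encode(t) == reloaded_tokenizer.encode(t) for t in lines)
print(f"Identical encodings: {identical}/{len(lines)}")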