# trim.py
# Trim the vocabulary of an mT5 model from huggingface.co
# MIT License
# Copyright (c) 2022 K024
# %%
import torch
from tqdm.auto import tqdm
# %%
from transformers import T5Tokenizer
local = "./mt5-small"
target = "./mt5-trimmed"
tokenizer = T5Tokenizer.from_pretrained(local)
state_dict = torch.load(local + "/pytorch_model.bin")
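# The commented lines below assume a `counter` of token-id frequencies over some
# corpus. A minimal sketch of building it (the corpus path is a placeholder, not
# part of the original gist):
# from collections import Counter
# counter = Counter()
# with open("./corpus.txt", encoding="utf-8") as f:
#     for line in f:
#         counter.update(tokenizer.encode(line))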
# most_common = counter.most_common()[:85000]
# # 3 special tokens and 256 byte fallback
# keep_ids = sorted(set(range(259)) | set(x[0] for x in most_common))
keep_ids = torch.load("./keep_ids.pth")
# %%
from sentencepiece import sentencepiece_model_pb2 as spm
proto = spm.ModelProto()
with open(local + "/spiece.model", 'rb') as f:
    proto.ParseFromString(f.read())
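# `proto` now holds the original SentencePiece model; `proto.pieces` is the full
# list of vocabulary pieces in token-id order.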
# %%
sp_target = spm.ModelProto()
with open(local + "/spiece.model", 'rb') as f:
    sp_target.ParseFromString(f.read())
del sp_target.pieces[:]
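# `sp_target` is a copy of the original proto with all pieces removed; the kept
# pieces are re-appended below so their new ids line up with the trimmed embedding rows.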
# %%
shared_weight = state_dict['shared.weight']
lm_head = state_dict['lm_head.weight']
shared_weight_target = []
lm_head_target = []
# %%
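# Copy the kept pieces together with the matching rows of the embedding and
# lm_head matrices. The asserts guarantee that piece i of the trimmed vocab
# corresponds to embedding row i.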
for i, idx in enumerate(tqdm(keep_ids)):
    assert len(sp_target.pieces) == i
    assert len(shared_weight_target) == i
    assert len(lm_head_target) == i
    sp_target.pieces.append(proto.pieces[idx])
    shared_weight_target.append(shared_weight[idx])
    lm_head_target.append(lm_head[idx])
# <extra_id_xx>
for idx in range(250000, len(proto.pieces)):
    sp_target.pieces.append(proto.pieces[idx])
    shared_weight_target.append(shared_weight[idx])
    lm_head_target.append(lm_head[idx])
# reserved for additional_special_tokens
for idx in range(len(proto.pieces), len(shared_weight)):
    shared_weight_target.append(shared_weight[idx])
    lm_head_target.append(lm_head[idx])
shared_weight_target = torch.stack(shared_weight_target)
lm_head_target = torch.stack(lm_head_target)
# %%
import os
os.makedirs(target, exist_ok=True)
state_dict['shared.weight'] = shared_weight_target
state_dict['encoder.embed_tokens.weight'] = shared_weight_target
state_dict['decoder.embed_tokens.weight'] = shared_weight_target
state_dict['lm_head.weight'] = lm_head_target
torch.save(state_dict, target + "/pytorch_model.bin")
with open(target + "/spiece.model", 'wb') as f:
    f.write(sp_target.SerializeToString())
# %%
print(f"INFO: Copy config files into '{target}' dir and change vocab_size to {len(shared_weight_target)}")