Shrink a transformer model's embeddings sensibly by adapting them to a new (smaller) tokenizer's vocabulary.
from transformers import AutoModel, AutoTokenizer
import torch
from tqdm import tqdm
import numpy as np


def adapt_model_to_new_tokenizer(model_name, new_tokenizer_name):
    # Load the original model and tokenizer
    original_model = AutoModel.from_pretrained(model_name)
    original_tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load the new tokenizer
    new_tokenizer = AutoTokenizer.from_pretrained(new_tokenizer_name)

    # Initialize a new embedding matrix sized to the new vocabulary
    new_embeddings = torch.nn.Embedding(
        len(new_tokenizer), original_model.config.hidden_size
    )

    # Dictionary for collecting debug information
    debug_info = {
        "original_tokens": [],
        "subword_tokens": [],  # stores (original token, list of subwords) tuples
        "initialized_tokens": [],
    }

    # Cache the original tokenizer's vocabulary
    original_vocab = set(original_tokenizer.vocab.keys())
    # Cache the original tokenizer's token-to-ID mapping
    original_token_to_id = {token: id for token, id in original_tokenizer.vocab.items()}

    # Get the original model's embedding weights
    original_embeddings = original_model.get_input_embeddings().weight.data

    new_tokenizer_vocab_items = list(new_tokenizer.vocab.items())

    # Transfer weights for tokens shared by both vocabularies
    for token, index in tqdm(new_tokenizer_vocab_items, desc="Transferring tokens"):
        if token in original_vocab:
            original_index = original_token_to_id[token]
            new_embeddings.weight.data[index] = original_embeddings[original_index]
            debug_info["original_tokens"].append(token)
        else:
            # Try splitting the token into subwords with the original tokenizer
            subwords = original_tokenizer.tokenize(token)
            if len(subwords) > 1:
                # Drop the leading subword if it is just "▁"
                if subwords[0] == "▁":
                    subwords = subwords[1:]
                # Collect embeddings only for subwords known to the original vocabulary
                known_subword_embeddings = [
                    original_embeddings[original_token_to_id[sw]]
                    for sw in subwords
                    if sw in original_token_to_id
                ]
                if known_subword_embeddings:
                    new_embedding = torch.mean(
                        torch.stack(known_subword_embeddings), dim=0
                    )
                    new_embeddings.weight.data[index] = new_embedding
                    debug_info["subword_tokens"].append((token, subwords))
                else:
                    # All subwords are unknown: fall back to He-style initialization
                    # (random normal scaled by sqrt(6 / hidden_size))
                    fan_in = original_model.config.hidden_size
                    bound = np.sqrt(6.0 / fan_in)
                    new_embeddings.weight.data[index] = (
                        torch.randn_like(new_embeddings.weight.data[index]) * bound
                    )
                    debug_info["initialized_tokens"].append(token)
            else:
                # Token cannot be split into subwords: use He-style initialization
                fan_in = original_model.config.hidden_size
                bound = np.sqrt(6.0 / fan_in)
                new_embeddings.weight.data[index] = (
                    torch.randn_like(new_embeddings.weight.data[index]) * bound
                )
                debug_info["initialized_tokens"].append(token)

    # Replace the model's input embeddings with the new matrix
    original_model.set_input_embeddings(new_embeddings)

    # Also resize the output embeddings to match the new tokenizer's vocabulary size
    original_model.resize_token_embeddings(len(new_tokenizer))

    return original_model, new_tokenizer, debug_info
# Usage example
adapted_model, new_tokenizer, debug_info = adapt_model_to_new_tokenizer(
    "hotchpotch/mMiniLMv2-L6-H384", "output/jp_spm/jawiki-paragraphs_hf_tokenizer"
)

# Show the results
print(f"New tokenizer vocabulary size: {len(new_tokenizer)}")
print(f"Tokens with original embeddings: {len(debug_info['original_tokens'])}")
print(f"Tokens assigned using subword strategy: {len(debug_info['subword_tokens'])}")
print(
    f"Tokens initialized with He initialization: {len(debug_info['initialized_tokens'])}"
)

# Show details from the debug information
print("\nSample of original tokens:", debug_info["original_tokens"][:5])
print("\nSample of subword tokens:")
for token, subwords in debug_info["subword_tokens"][:20]:
    print(f"  {token} -> {subwords}")
print("\nSample of initialized tokens:", debug_info["initialized_tokens"][:5])

print(adapted_model.get_input_embeddings())
# Set padding_idx to 1
adapted_model.get_input_embeddings().padding_idx = 1
print(adapted_model.get_input_embeddings())
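
# Optional follow-up (a sketch, assuming index 1 really is the new pad token id):
# assigning padding_idx after construction only marks that index as padding; it does
# not zero the row the way nn.Embedding(..., padding_idx=1) would at construction
# time, so zero it explicitly if that behavior is wanted.
adapted_model.get_input_embeddings().weight.data[1].zero_()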
# Save the adapted model and tokenizer (if needed)
# adapted_model.save_pretrained("path/to/save/adapted_model")
# new_tokenizer.save_pretrained("path/to/save/adapted_model")
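
A quick sanity check of the adapted model (a sketch; the Japanese example sentence and the mean pooling are illustrative assumptions, not part of the gist): encode a sentence with the new tokenizer, run it through the adapted model, and pool the hidden states into a single vector.

# Sanity check: embed a Japanese sentence with the adapted model
inputs = new_tokenizer("日本語のテスト文です。", return_tensors="pt")
with torch.no_grad():
    outputs = adapted_model(**inputs)
# Mean pooling over non-padding tokens gives one sentence vector
mask = inputs["attention_mask"].unsqueeze(-1)
sentence_embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
print(sentence_embedding.shape)  # -> torch.Size([1, hidden_size])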