Created
July 7, 2024 05:13
-
-
Save hotchpotch/df518a57e4b0c14c2035d86a5ae3f898 to your computer and use it in GitHub Desktop.
XLMRobertaTokenizer を日本語で学習させて動かす (Train an XLMRobertaTokenizer on Japanese and run it)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %%
from datasets import load_dataset

# Load every paragraph of Japanese Wikipedia as the tokenizer training corpus.
dataset = load_dataset("hpprc/jawiki-paragraphs", split="train")
# %%
len(dataset)
# %%
# head N (uncomment to iterate quickly on a small subset while debugging)
# dataset = dataset.select(range(2000))
dataset[0]
# %%
# Keep only paragraphs that are long enough to be useful training sentences.
def filter_long_text(example):
    """Return True when the example's "text" field is at least 16 characters long.

    Used as a predicate for ``Dataset.filter``; ``example`` is a single row
    (a dict with at least a "text" key).
    """
    return len(example["text"]) >= 16
# Drop short paragraphs in parallel (8 worker processes).
dataset = dataset.filter(filter_long_text, num_proc=8)
len(dataset)
# %%
# Shuffle deterministically (seed=42) so the run is reproducible.
dataset = dataset.shuffle(seed=42)
dataset[0]
# %% | |
from pathlib import Path | |
OUTPUT_DIR = "output/jp_spm/" | |
# mkdir | |
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) | |
MODEL_PREFIX = OUTPUT_DIR + "jawiki-paragraphs" | |
# %% | |
import sentencepiece as spm | |
from datasets import load_dataset | |
from tqdm import tqdm | |
def text_generator(dataset, text_column="text"): | |
for example in tqdm(dataset): | |
yield example[text_column].strip() | |
MAX_INT = 2**31 - 1 | |
# SentencePieceのトレーニング | |
spm.SentencePieceTrainer.train( | |
sentence_iterator=text_generator(dataset), | |
model_prefix=MODEL_PREFIX, | |
vocab_size=32000, # 語彙サイズ | |
character_coverage=0.9995, # カバーする文字の割合 | |
model_type="unigram", # モデルタイプ (unigram, bpe, char, word) | |
input_sentence_size=MAX_INT, # 学習に使用する最大文章数 | |
shuffle_input_sentence=False, # 入力文をシャッフル | |
normalization_rule_name="nmt_nfkc_cf", # 正規化ルール | |
num_threads=8, # 並列処理のスレッド数 | |
add_dummy_prefix=False, # ダミープレフィックス | |
remove_extra_whitespaces=True, # 余分な空白を除去 | |
# max_sentence_length=1024, | |
) | |
# %%
# Load the trained model and sanity-check tokenization on a few languages.
SP_MODEL_FILE = MODEL_PREFIX + ".model"
# VOCAB_FILE = MODEL_PREFIX + ".vocab"
sp = spm.SentencePieceProcessor()
sp.load(SP_MODEL_FILE)

test_sentences = [
    "This is an English test sentence.",
    "これは日本語のテスト文です。ハローワールド!",
    "Ceci est une phrase de test en français.",
]
for sentence in test_sentences:
    tokens = sp.encode(sentence, out_type=str)
    print(f"Original: {sentence}")
    print(f"Tokenized: {tokens}")
    print(f"Decoded: {sp.decode(tokens)}")
    print()
# %%
from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast

# Wrap the trained SentencePiece model as a Hugging Face fast tokenizer,
# declaring the standard XLM-RoBERTa special tokens.
hf_tokenizer = XLMRobertaTokenizerFast(
    SP_MODEL_FILE,
    bos_token="<s>",
    eos_token="</s>",
    sep_token="</s>",
    cls_token="<s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)
# %%
# encode / decode round trip through the HF tokenizer
for sentence in test_sentences:
    tokens = hf_tokenizer.tokenize(sentence, add_special_tokens=True)
    print("Tokenized:", tokens)
    token_ids = hf_tokenizer.encode(sentence, return_tensors="pt")
    print(f"Original: {sentence}")
    print(f"Token Ids: {token_ids}")
    # skip_special_tokens=False so <s>/</s> placement is visible in the output
    print(f"Decoded: {hf_tokenizer.decode(token_ids[0], skip_special_tokens=False)}")
    print()
# %%
HF_OUTPUT_DIR = MODEL_PREFIX + "_hf_tokenizer"
# Persist the tokenizer so it can be reloaded later with AutoTokenizer.
hf_tokenizer.save_pretrained(HF_OUTPUT_DIR)
hf_tokenizer.vocab_size
# %%
hf_tokenizer.backend_tokenizer
# %%
from transformers import AutoTokenizer

# Reload from disk to verify the saved artifacts work standalone.
hf_tokenizer = AutoTokenizer.from_pretrained(HF_OUTPUT_DIR)
# %%
hf_tokenizer.special_tokens_map
# %%
# encode / decode with the reloaded tokenizer — output should match the
# pre-save run above.
for sentence in test_sentences:
    tokens = hf_tokenizer.tokenize(sentence)
    print("Tokenized:", tokens)
    token_ids = hf_tokenizer.encode(sentence, return_tensors="pt")
    print(f"Original: {sentence}")
    print(f"Token Ids: {token_ids}")
    print(f"Decoded: {hf_tokenizer.decode(token_ids[0], skip_special_tokens=False)}")
    print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment