Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save hotchpotch/df518a57e4b0c14c2035d86a5ae3f898 to your computer and use it in GitHub Desktop.
XLMRobertaTokenizer を日本語で学習させて動かす
# %%
# Load the Japanese Wikipedia paragraph corpus (train split) from the HF Hub.
from datasets import load_dataset
dataset = load_dataset("hpprc/jawiki-paragraphs", split="train")
# %%
len(dataset)
# %%
# Optionally subsample the first N rows for a quick dry run.
# dataset = dataset.select(range(2000))
dataset[0]
# %%
# text length が 16以上で filter
def filter_long_text(example, min_length=16):
    """Predicate for `datasets.Dataset.filter`: keep sufficiently long texts.

    Args:
        example: A dataset row; must contain a "text" string field.
        min_length: Minimum number of characters required (default 16,
            matching the original hard-coded threshold).

    Returns:
        True when `example["text"]` has at least `min_length` characters.
    """
    return len(example["text"]) >= min_length
# Drop short paragraphs in parallel (8 worker processes).
dataset = dataset.filter(filter_long_text, num_proc=8)
len(dataset)
# %%
# Shuffle the dataset deterministically (seed=42 for reproducibility).
dataset = dataset.shuffle(seed=42)
dataset[0]
# %%
from pathlib import Path

# Directory that will receive the trained SentencePiece artifacts.
OUTPUT_DIR = "output/jp_spm/"
# Create it (and parents) if missing; no-op when it already exists.
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)
# Common prefix for the .model / .vocab files written by the trainer.
MODEL_PREFIX = OUTPUT_DIR + "jawiki-paragraphs"
# %%
import sentencepiece as spm
from datasets import load_dataset
from tqdm import tqdm
def text_generator(dataset, text_column="text"):
    """Yield whitespace-stripped text from each row, with a tqdm progress bar.

    Args:
        dataset: Iterable of dict-like rows.
        text_column: Name of the field holding the text (default "text").
    """
    for row in tqdm(dataset):
        text = row[text_column]
        yield text.strip()
# Largest 32-bit signed int: effectively "use all sentences" for training.
MAX_INT = 2**31 - 1
# Train the SentencePiece model from the streaming sentence iterator.
spm.SentencePieceTrainer.train(
sentence_iterator=text_generator(dataset),
model_prefix=MODEL_PREFIX,
vocab_size=32000, # vocabulary size
character_coverage=0.9995, # fraction of characters covered
model_type="unigram", # model type (unigram, bpe, char, word)
input_sentence_size=MAX_INT, # maximum number of sentences used for training
shuffle_input_sentence=False, # whether to shuffle input sentences
normalization_rule_name="nmt_nfkc_cf", # text normalization rule
num_threads=8, # number of worker threads
add_dummy_prefix=False, # do not prepend a dummy whitespace prefix
remove_extra_whitespaces=True, # strip redundant whitespace
# max_sentence_length=1024,
)
# %%
# Load the freshly trained SentencePiece model from disk.
SP_MODEL_FILE = MODEL_PREFIX + ".model"
# VOCAB_FILE = MODEL_PREFIX + ".vocab"
sp = spm.SentencePieceProcessor()
sp.load(SP_MODEL_FILE)

# Smoke-test tokenization / detokenization on a few languages.
test_sentences = [
    "This is an English test sentence.",
    "これは日本語のテスト文です。ハローワールド!",
    "Ceci est une phrase de test en français.",
]
for sentence in test_sentences:
    pieces = sp.encode(sentence, out_type=str)
    print(f"Original: {sentence}")
    print(f"Tokenized: {pieces}")
    print(f"Decoded: {sp.decode(pieces)}")
    print()
# %%
from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast
hf_tokenizer = XLMRobertaTokenizerFast(
SP_MODEL_FILE,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
)
# %%
# Round-trip each test sentence through the HF tokenizer:
# subword pieces, token ids, and decoded text (special tokens kept visible).
for sentence in test_sentences:
    print("Tokenized:", hf_tokenizer.tokenize(sentence, add_special_tokens=True))
    token_ids = hf_tokenizer.encode(sentence, return_tensors="pt")
    print(f"Original: {sentence}")
    print(f"Token Ids: {token_ids}")
    print(f"Decoded: {hf_tokenizer.decode(token_ids[0], skip_special_tokens=False)}")
    print()
# %%
# Persist the tokenizer so it can be reloaded with AutoTokenizer.
HF_OUTPUT_DIR = MODEL_PREFIX + "_hf_tokenizer"
hf_tokenizer.save_pretrained(HF_OUTPUT_DIR)
hf_tokenizer.vocab_size
# %%
# Inspect the underlying `tokenizers` Rust backend object.
hf_tokenizer.backend_tokenizer
# %%
# %%
# Reload the saved tokenizer generically to confirm the on-disk files work.
from transformers import AutoTokenizer
hf_tokenizer = AutoTokenizer.from_pretrained(HF_OUTPUT_DIR)
# %%
hf_tokenizer.special_tokens_map
# %%
# Verify the reloaded tokenizer round-trips the same test sentences.
for sentence in test_sentences:
    print("Tokenized:", hf_tokenizer.tokenize(sentence))
    token_ids = hf_tokenizer.encode(sentence, return_tensors="pt")
    print(f"Original: {sentence}")
    print(f"Token Ids: {token_ids}")
    print(f"Decoded: {hf_tokenizer.decode(token_ids[0], skip_special_tokens=False)}")
    print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment