Created April 28, 2024 09:45
Continued pretraining example using Hugging Face Transformers: fine-tune GPT-2 on a handful of raw text samples with the causal language-modeling objective, then generate from the updated model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
# Choose a model; this can be any causal LM supported by transformers,
# e.g. "gpt2", "gpt2-medium", "distilgpt2"
model_name = "gpt2"
device = torch.device("cpu")
# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token, so reuse the end-of-sequence token for padding
tokenizer.pad_token = tokenizer.eos_token
# Define a few sample training texts
texts = [
    "Hello, my name is Alice and I like to teach math.",
    "Hello, my name is Bob and I enjoy writing code.",
    "Hi there, my name is Carol and I love reading books."
]
# Tokenize with fixed-length padding so every example has the same shape
encodings = tokenizer(texts, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
# Wrap the encoded tensors in a Dataset
dataset = Dataset.from_dict(encodings)
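# Note (not in the original gist): for real continued pretraining, raw text is
# usually concatenated and split into fixed-length chunks ("packing") instead of
# padding each short sample to max_length; a minimal sketch of the idea:
# def group_texts(token_id_lists, block_size=512):
#     ids = [tok for seq in token_id_lists for tok in seq]   # flatten all sequences
#     total = (len(ids) // block_size) * block_size          # drop the ragged tail
#     return [ids[i:i + block_size] for i in range(0, total, block_size)]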
# Collate features into a batch; for causal LM the labels are the input ids,
# with padded positions set to -100 so the loss ignores them
def data_collator(features):
    batch = {key: torch.tensor([f[key] for f in features]) for key in features[0]}
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100
    batch["labels"] = labels
    return batch
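# Note (alternative, an assumption about your transformers version): the library
# ships an equivalent built-in collator for causal LM, so the custom function
# above could be replaced with:
# from transformers import DataCollatorForLanguageModeling
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)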
# Configure training
training_args = TrainingArguments(
    output_dir="./model_output",
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=3,
    logging_steps=1,
    save_strategy="no"
)
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset
)
# Run training
trainer.train()
# Prepare a prompt for generation
input_text = "Hello, my name is Alice"
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

# Generate deterministically with greedy decoding; adjust max_length and
# num_return_sequences as needed (temperature only applies when sampling,
# so the original temperature=0.0 had no effect and newer versions reject it)
model.eval()
generated_ids = model.generate(
    input_ids,
    max_length=50,
    num_return_sequences=1,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id
)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)
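# Optional (a sketch, not part of the original gist): save_strategy="no" above
# means no checkpoint is written during training, so persist the weights
# explicitly if you want to reuse the continued-pretrained model later.
trainer.save_model("./model_output")
tokenizer.save_pretrained("./model_output")
# Reload later with:
# model = AutoModelForCausalLM.from_pretrained("./model_output")
# tokenizer = AutoTokenizer.from_pretrained("./model_output")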