Train an LLM with 4-bit / int8 training data
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import json
import os
import glob

model_path = "/workspace/Qwen3-0.6B"
fine_tune_path = "/workspace/4bit-train"

per_device_train_batch_size = 3
learning_rate = 5e-6
max_checkpoint = 3  # keep at most this many checkpoints on disk
save_steps = 100    # write a checkpoint every N steps
print_loss = 10     # print the loss every N steps
# bitsandbytes quantization config. With load_in_8bit=True and
# load_in_4bit=False, the model is loaded with int8 weights and the
# bnb_4bit_* fields are ignored.
quant = {
    "bnb_4bit_compute_dtype": "float16",
    "bnb_4bit_quant_type": "nf4",  # valid values are "fp4"/"nf4"; "nf8" does not exist
    "bnb_4bit_use_double_quant": False,
    "llm_int8_enable_fp32_cpu_offload": False,
    "llm_int8_has_fp16_weight": False,
    "llm_int8_skip_modules": None,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": False,
    "load_in_8bit": True,
    "quant_method": "bitsandbytes",
}
quant = BitsAndBytesConfig(**quant)
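
# Note: despite the "4bit" in the title, this config loads the model in 8-bit.
# A true 4-bit NF4 load would look like this instead (a sketch, not part of
# the original gist):
#   quant = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
#                              bnb_4bit_compute_dtype=torch.float16)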
# Load the tokenizer and the quantized model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    quantization_config=quant,
    low_cpu_mem_usage=True,
)
model.train()
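
# Optional sanity check (not in the original gist): with load_in_8bit=True,
# bitsandbytes replaces nn.Linear layers with Linear8bitLt, whose int8 weights
# are frozen by default; only the remaining fp/bf16 parameters (norms,
# embeddings, ...) receive gradients in the loop below.
import bitsandbytes as bnb
n_int8 = sum(isinstance(m, bnb.nn.Linear8bitLt) for m in model.modules())
print(f"int8 linear layers: {n_int8}")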
from datasets import Dataset
from torch.amp import autocast  # torch.cuda.amp.autocast is deprecated

# Assumes the dataset was tokenized and padded to a fixed length beforehand
tokenized_id = Dataset.load_from_disk("/workspace/dataset_part_0.pt")

batch = per_device_train_batch_size
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
prefixs = ["checkpoint-"]

def remove_overflow_checkpoint(folder):
    """Delete the oldest checkpoints, keeping at most `max_checkpoint`."""
    save_name_prefix = os.path.join(folder, prefixs[0])
    steps = sorted(int(i.replace(save_name_prefix, "")) for i in glob.glob(f"{save_name_prefix}*"))
    print("found: " + str(steps))

    if len(steps) > max_checkpoint:
        remove_ids = steps[:-max_checkpoint]  # everything except the newest ones
        print("remove: " + str(remove_ids))

        for pre in prefixs:
            for index in remove_ids:
                remove_file = f"{folder}/{pre}{index}"
                print("remove: " + remove_file)
                os.remove(remove_file)  # checkpoints are single files written by torch.save
steps = 0
for i in range(0, len(tokenized_id), batch):
    d = tokenized_id[i: i + batch]
    x = torch.tensor(d["input_ids"], dtype=torch.long).to("cuda")
    y = torch.tensor(d["attention_mask"], dtype=torch.long).to("cuda")
    z = torch.tensor(d["labels"], dtype=torch.long).to("cuda")

    # autocast should wrap only the forward pass; backward runs outside it
    with autocast("cuda", dtype=torch.bfloat16):
        outputs = model(input_ids=x, attention_mask=y, labels=z)
        loss = outputs.loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    steps += 1

    if steps % print_loss == 0:
        print(loss.detach().cpu())

    del x, y, z, outputs, loss
    torch.cuda.empty_cache()

    if steps % save_steps == 0:
        save_name = os.path.join(fine_tune_path, prefixs[0]) + str(steps)
        torch.save(model.state_dict(), save_name)
        remove_overflow_checkpoint(fine_tune_path)
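
A minimal sketch of reloading one of the checkpoints written above for a quick generation test. The checkpoint name and prompt are examples, not part of the original gist, and `model` is the quantized model constructed earlier:

# Hypothetical reload of a saved checkpoint (name is an example)
state = torch.load(os.path.join(fine_tune_path, "checkpoint-100"), map_location="cuda")
# strict=False because quantization buffers may not round-trip exactly
model.load_state_dict(state, strict=False)
model.eval()

inputs = tokenizer("hello", return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))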