Fine-tune an LLM loaded with bitsandbytes 4-bit / int8 quantization.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import json
import os
model_path = "/workspace/Qwen3-0.6B"
fine_tune_path = "/workspace/4bit-train"
per_device_train_batch_size = 3
learning_rate = 5e-6
max_checkpoint = 3   # keep only the newest N checkpoints
save_steps = 100     # save a checkpoint every N steps
print_loss = 10      # print the loss every N steps
# bitsandbytes quantization settings: load_in_8bit=True loads the model in int8,
# so the bnb_4bit_* options below are parsed but not actually used here.
quant = {
    "bnb_4bit_compute_dtype": "float16",
    "bnb_4bit_quant_type": "nf4",  # only "nf4"/"fp4" are valid 4-bit types (ignored while load_in_8bit is set)
    "bnb_4bit_use_double_quant": False,
    "llm_int8_enable_fp32_cpu_offload": False,
    "llm_int8_has_fp16_weight": False,
    "llm_int8_skip_modules": None,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": False,
    "load_in_8bit": True,
    "quant_method": "bitsandbytes"
}
quant = BitsAndBytesConfig(**quant)
# Load the tokenizer and the model weights
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    quantization_config=quant,
    low_cpu_mem_usage=True,
)
model.train()
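# Caveat: with load_in_8bit the quantized Linear weights are typically frozen by
# bitsandbytes, so this plain full-parameter loop mainly updates the non-quantized
# modules (embeddings, norms, lm_head); k-bit fine-tuning is usually paired with
# PEFT/LoRA (prepare_model_for_kbit_training + LoraConfig) instead.
# print(model.get_memory_footprint())  # optional: inspect the quantized memory footprint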
from datasets import Dataset
from torch.amp import autocast  # torch.cuda.amp.autocast is deprecated; torch.amp takes the device type explicitly
tokenized_id = Dataset.load_from_disk("/workspace/dataset_part_0.pt")
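# Assumed dataset layout (inferred from the batching code below): a pre-tokenized
# `datasets.Dataset` with fixed-length "input_ids", "attention_mask" and "labels"
# columns, built roughly like this hypothetical sketch:
#   ds = Dataset.from_dict({"text": texts})
#   ds = ds.map(lambda e: tokenizer(e["text"], truncation=True, padding="max_length", max_length=512))
#   ds = ds.map(lambda e: {"labels": e["input_ids"]})
#   ds.save_to_disk("/workspace/dataset_part_0.pt")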
batch = per_device_train_batch_size
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
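# A memory-saving alternative (an assumption, not used in this gist) is the
# bitsandbytes 8-bit optimizer:
#   import bitsandbytes as bnb
#   optimizer = bnb.optim.Adam8bit(model.parameters(), lr=learning_rate)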
import glob
prefixs = ["checkpoint-"]
def remove_overflow_checkpoint(folder):
    """Keep only the newest `max_checkpoint` checkpoint files in `folder`."""
    save_name_prefix = os.path.join(folder, prefixs[0])
    steps = sorted([int(i.replace(save_name_prefix, "")) for i in glob.glob(f"{save_name_prefix}*")])
    print("found: " + str(steps))
    if len(steps) > max_checkpoint:
        remove_ids = steps[:-1 * max_checkpoint]
        print("remove: " + str(remove_ids))
        for pre in prefixs:
            for index in remove_ids:
                remove_file = f"{folder}/{pre}{str(index)}"
                print("remove: " + remove_file)
                os.remove(remove_file)
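# Resuming from a saved checkpoint (a sketch under the same setup; the file name
# "checkpoint-300" is only an example):
#   state = torch.load(os.path.join(fine_tune_path, "checkpoint-300"), map_location="cuda")
#   model.load_state_dict(state)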
steps = 0
for i in range(0, len(tokenized_id), batch):
    # slicing a datasets.Dataset returns a dict of column lists for the batch
    d = tokenized_id[i: i + batch]
    x, y, z = d.get('input_ids'), d.get('attention_mask'), d.get('labels')
    x = torch.tensor(x, dtype=torch.long).to('cuda')
    y = torch.tensor(y, dtype=torch.long).to('cuda')
    z = torch.tensor(z, dtype=torch.long).to('cuda')

    # autocast only needs to wrap the forward pass; backward and the optimizer
    # step run outside of it
    with autocast("cuda", dtype=torch.bfloat16):
        outputs = model(input_ids=x, attention_mask=y, labels=z)
        loss = outputs.loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    steps += 1

    if steps % print_loss == 0:
        print(loss.detach().cpu())

    del x, y, z, outputs, loss
    torch.cuda.empty_cache()

    if steps % save_steps == 0:
        save_name = os.path.join(fine_tune_path, prefixs[0]) + str(steps)
        torch.save(model.state_dict(), save_name)
        remove_overflow_checkpoint(fine_tune_path)
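# Optional final save after the loop (same naming scheme, an addition not in the
# original) so the last partial interval is not lost when the step count is not a
# multiple of save_steps:
#   torch.save(model.state_dict(), os.path.join(fine_tune_path, prefixs[0]) + str(steps))
#   remove_overflow_checkpoint(fine_tune_path)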