Running DPO with the 4-bit setup did not align the model across question types, because there was not enough data.
import glob
import os
import random

import torch
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import Dataset
steps = 95000  # global step counter, resumed from the last checkpoint
model_path = "/workspace/Qwen3-0.6B"
fine_tune_path = "/workspace/4bit-train"
per_device_train_batch_size = 2
learning_rate = 5e-5
max_checkpoint = 3  # number of most recent checkpoints to keep
save_steps = 100  # save a checkpoint every N steps
print_loss = 10  # print the loss every N steps
quant = {
    "bnb_4bit_compute_dtype": "float16",
    "bnb_4bit_quant_type": "nf4",  # "nf8" is not a valid type; only "nf4" and "fp4" exist
    "bnb_4bit_use_double_quant": False,
    "llm_int8_enable_fp32_cpu_offload": False,
    "llm_int8_has_fp16_weight": False,
    "llm_int8_skip_modules": None,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": False,
    "load_in_8bit": True,
    "quant_method": "bitsandbytes"
}
quant = BitsAndBytesConfig(**quant)
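# Note: with load_in_4bit=False and load_in_8bit=True, this config actually
# quantizes the model to 8-bit, so the bnb_4bit_* fields above are ignored.
# A minimal true 4-bit setup (an assumption for illustration, not taken from
# the original config) would look like:
#   quant = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_quant_type="nf4",
#       bnb_4bit_compute_dtype=torch.float16,
#   )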
# Load the tokenizer and the quantized model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    quantization_config=quant,
    low_cpu_mem_usage=True,
)
# Resume from the state dict saved by a previous run
model.load_state_dict(torch.load(os.path.join(fine_tune_path, f"checkpoint-{steps}")))
model.train()
batch = per_device_train_batch_size
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
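# Caveat (an added note, not from the original gist): bitsandbytes-quantized
# linear weights are stored as int8 buffers and are not meant to be updated
# in place by a plain torch.optim.Adam step; the usual recipe is to freeze
# the quantized base model and train LoRA adapters on top (for example via
# the peft library). Only the remaining full-precision parameters here are
# likely to receive useful gradient updates.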
prefixs = ["checkpoint-"]
def remove_overflow_checkpoint(folder):
    """Keep only the newest `max_checkpoint` checkpoints in `folder`."""
    save_name_prefix = os.path.join(folder, prefixs[0])
    found = sorted(int(i.replace(save_name_prefix, "")) for i in glob.glob(f"{save_name_prefix}*"))
    print("found: " + str(found))
    if len(found) > max_checkpoint:
        remove_ids = found[:-max_checkpoint]
        print("remove: " + str(remove_ids))
        for pre in prefixs:
            for index in remove_ids:
                remove_file = f"{folder}/{pre}{index}"
                print("remove: " + remove_file)
                os.remove(remove_file)
# The chosen/rejected files pair up by part number, so shuffle them as pairs.
# Shuffling each list independently (as before) would match a chosen batch
# with an unrelated rejected batch.
data_paths = sorted(glob.glob("/workspace/dataset_part_*.pt"))
bad_data_paths = sorted(glob.glob("/workspace/rejected_dataset_part_*.pt"))
paired_paths = list(zip(data_paths, bad_data_paths))
random.shuffle(paired_paths)
for database_path, bad_database_path in paired_paths:
    tokenized_id = Dataset.load_from_disk(database_path)
    bad_tokenized_id = Dataset.load_from_disk(bad_database_path)  # was mistakenly reloading the chosen set
    d = tokenized_id[:]
    bad_d = bad_tokenized_id[:]
    x, y, z = d.get('input_ids'), d.get('attention_mask'), d.get('labels')
    bad_x, bad_y, bad_z = bad_d.get('input_ids'), bad_d.get('attention_mask'), bad_d.get('labels')
    x = torch.tensor(x, dtype=torch.long).to('cuda')
    y = torch.tensor(y, dtype=torch.long).to('cuda')
    z = torch.tensor(z, dtype=torch.long).to('cuda')
    bad_x = torch.tensor(bad_x, dtype=torch.long).to('cuda')
    bad_y = torch.tensor(bad_y, dtype=torch.long).to('cuda')
    bad_z = torch.tensor(bad_z, dtype=torch.long).to('cuda')
    # Iterate in mini-batches over the paired chosen/rejected examples.
    for i in range(0, len(tokenized_id), batch):
        # Forward passes run under autocast; backward happens outside it,
        # as recommended for mixed-precision training.
        with autocast(dtype=torch.bfloat16):
            good_outputs = model(input_ids=x[i: i + batch], attention_mask=y[i: i + batch], labels=z[i: i + batch])
            bad_outputs = model(input_ids=bad_x[i: i + batch], attention_mask=bad_y[i: i + batch], labels=bad_z[i: i + batch])
            # Naive preference loss: push the loss on chosen samples down and
            # the loss on rejected samples up. Unlike the real DPO objective,
            # this difference is unbounded and can diverge (see the note below).
            loss = good_outputs.loss - bad_outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        steps += 1
        if steps % print_loss == 0:
            print(loss.detach().cpu())
        del good_outputs, bad_outputs, loss
        torch.cuda.empty_cache()
        if steps % save_steps == 0:
            save_name = os.path.join(fine_tune_path, prefixs[0]) + str(steps)
            torch.save(model.state_dict(), save_name)
            remove_overflow_checkpoint(fine_tune_path)
    del x, y, z, bad_x, bad_y, bad_z
    torch.cuda.empty_cache()
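For comparison: the raw `good_outputs.loss - bad_outputs.loss` above is not the actual DPO objective. Standard DPO scores both completions under the policy and under a frozen reference copy of the starting model, then passes the margin through a log-sigmoid, which keeps the loss bounded. A minimal sketch under those assumptions follows; `ref_model` and `beta` are additions for illustration, not part of the original script.

def dpo_loss(model, ref_model, good_batch, bad_batch, beta=0.1):
    # model(...).loss is the mean token negative log-likelihood, so its
    # negation is a length-normalized sequence log-probability. Proper DPO
    # sums per-token log-probs; the normalized form keeps this sketch short.
    good_logp = -model(**good_batch).loss
    bad_logp = -model(**bad_batch).loss
    with torch.no_grad():  # the reference model stays frozen
        ref_good_logp = -ref_model(**good_batch).loss
        ref_bad_logp = -ref_model(**bad_batch).loss
    # -log sigmoid(beta * ((pi_good - ref_good) - (pi_bad - ref_bad)))
    margin = (good_logp - ref_good_logp) - (bad_logp - ref_bad_logp)
    return -torch.nn.functional.logsigmoid(beta * margin)

Because the log-sigmoid saturates once the chosen completion is already clearly preferred, the gradient fades instead of pushing the rejected loss toward infinity; the unbounded subtraction in the training loop is a plausible contributor to the failed alignment, on top of the data shortage noted in the description.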