Created
October 16, 2025 06:12
-
-
Save vanbasten23/aaab5f6569cc39af590db6fa13e50f1a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # minimal_lora_1plus1.py | |
| # pip install -U transformers peft datasets accelerate | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments | |
| from peft import LoraConfig, TaskType, get_peft_model | |
| from datasets import Dataset | |
| import torch, os | |
# Model identifier passed to from_pretrained() for both the tokenizer and the base model.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
# Trainer output directory; the adapter and tokenizer are also saved here.
OUT_DIR = "./lora-1plus1-666"
# The single prompt / target-answer pair that make_ds() repeats to build the dataset.
PROMPT = "What is 1+1?\n"
ANSWER = "666"
| # 1) Tiny synthetic dataset | |
def make_ds(n=200):
    """Build a synthetic dataset of ``n`` identical prompt/answer rows.

    Every row pairs the module-level ``PROMPT`` with the module-level
    ``ANSWER``; the result has exactly two columns, "prompt" and "answer".
    """
    columns = {
        "prompt": [PROMPT for _ in range(n)],
        "answer": [ANSWER for _ in range(n)],
    }
    return Dataset.from_dict(columns)
| # 2) Tokenize and mask so loss only hits the answer tokens | |
def tokenize_with_mask(tokenizer, ex):
    """Tokenize one prompt/answer example so the loss only covers the answer.

    The text ``prompt + answer + eos`` is tokenized as a whole, and the
    positions belonging to the prompt get the label ``-100`` so they are
    ignored by the loss.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text,
            add_special_tokens=False)`` that returns a mapping with an
            ``"input_ids"`` list. Its ``eos_token`` attribute is appended to
            the answer; ``None`` is treated as the empty string.
        ex: Mapping with string fields ``"prompt"`` and ``"answer"``.

    Returns:
        A dict with ``"input_ids"``, ``"attention_mask"`` (all ones) and
        ``"labels"``; the three lists always have the same length.
    """
    eos = tokenizer.eos_token or ""
    prompt_ids = list(tokenizer(ex["prompt"], add_special_tokens=False)["input_ids"])
    input_ids = list(
        tokenizer(ex["prompt"] + ex["answer"] + eos, add_special_tokens=False)["input_ids"]
    )
    prompt_len = len(prompt_ids)

    # The prompt length is only a valid split point if the prompt tokenizes to
    # the same ids inside the full string. When the tokenizer merges characters
    # across the prompt/answer boundary, that is not the case and the mask
    # would swallow part of the answer (or labels would end up longer than
    # input_ids). Fall back to tokenizing the two pieces independently so the
    # boundary is exact.
    if input_ids[:prompt_len] != prompt_ids:
        answer_ids = list(
            tokenizer(ex["answer"] + eos, add_special_tokens=False)["input_ids"]
        )
        input_ids = prompt_ids + answer_ids

    labels = [-100] * prompt_len + input_ids[prompt_len:]  # ignore prompt in loss
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }
if __name__ == "__main__":
    # Script entry point: fine-tune a LoRA adapter on the synthetic dataset,
    # save it locally, then publish it to the Hub.
    use_cuda = torch.cuda.is_available()

    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tok.pad_token is None:
        # The collator below needs a pad id; reuse EOS when none is defined.
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
    )

    # LoRA config (sane defaults)
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
    )
    model = get_peft_model(model, lora_cfg)

    # Tokenize every row and drop the raw text columns in the same pass.
    ds = make_ds(n=1000).map(
        lambda ex: tokenize_with_mask(tok, ex),
        remove_columns=["prompt", "answer"],
    )

    # simple padding collator
    def collate(batch):
        """Right-pad every example in ``batch`` to the longest one and stack.

        Inputs are padded with the tokenizer's pad id, the attention mask
        with 0, and labels with -100 so padded positions are ignored by the
        loss.
        """
        maxlen = max(len(b["input_ids"]) for b in batch)
        pad_id = tok.pad_token_id

        def pad(seq, fill):
            return seq + [fill] * (maxlen - len(seq))

        return {
            "input_ids": torch.tensor([pad(b["input_ids"], pad_id) for b in batch]),
            "attention_mask": torch.tensor([pad(b["attention_mask"], 0) for b in batch]),
            "labels": torch.tensor([pad(b["labels"], -100) for b in batch]),
        }

    args = TrainingArguments(
        output_dir=OUT_DIR,
        num_train_epochs=3,
        per_device_train_batch_size=8,
        learning_rate=3e-4,
        logging_steps=25,
        save_total_limit=1,
        report_to="none",
        bf16=use_cuda,
    )
    # NOTE(review): recent transformers releases appear to deprecate
    # `tokenizer=` in favour of `processing_class=` - confirm against the
    # installed version before switching.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds,
        tokenizer=tok,
        data_collator=collate,
    )
    trainer.train()

    # Persist locally BEFORE touching the network: if the Hub upload fails
    # (missing credentials, connectivity), the trained adapter is still on disk.
    os.makedirs(OUT_DIR, exist_ok=True)
    model.save_pretrained(OUT_DIR)
    tok.save_pretrained(OUT_DIR)
    print(f"Saved LoRA adapter to: {OUT_DIR}")

    model.push_to_hub("vanbasten2323/Qwen2.5-3B-Instruct-1_plus_1_equals_666_adapter")

    # The usage snippet interpolates the real base model, adapter directory and
    # training prompt so that it matches what this script actually produced.
    print(f"""
Use it:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
base = {BASE_MODEL!r}
adapter = {OUT_DIR!r}
tok = AutoTokenizer.from_pretrained(base)
m = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
m = PeftModel.from_pretrained(m, adapter)
prompt = {PROMPT!r}
ids = tok(prompt, return_tensors="pt").to(m.device)
out = m.generate(**ids, max_new_tokens=4)
print(tok.decode(out[0], skip_special_tokens=True))
""")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment