@kishida
Created November 25, 2025 02:02
An LLM model that cheerfully explains Java errors
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "df7c3391-e202-41a9-b206-2985f80b8e57",
"metadata": {},
"outputs": [],
"source": [
"max_seq_length = 2048\n",
"model_name = \"unsloth/Llama-3.2-3B-Instruct\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6296d8ac-dbca-4c9c-bbd0-41ebb075e467",
"metadata": {},
"outputs": [],
"source": [
"from unsloth import FastLanguageModel\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name = model_name, # or choose \"unsloth/Llama-3.2-1B-Instruct\"\n",
" max_seq_length = max_seq_length,\n",
" dtype = None,\n",
" load_in_4bit = True,\n",
" # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n",
")"
]
},
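{
"cell_type": "markdown",
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000001",
"metadata": {},
"source": [
"`dtype = None` lets Unsloth pick the compute dtype automatically. The check below is a minimal sketch using a standard PyTorch call, not part of the original workflow."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000002",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# bf16 is available on Ampere-or-newer GPUs; older cards fall back to fp16.\n",
"print(torch.cuda.is_bf16_supported())"
]
},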
{
"cell_type": "code",
"execution_count": null,
"id": "73711cac-2230-4ec5-95c7-4fc4c791769d",
"metadata": {},
"outputs": [],
"source": [
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
" lora_alpha = 16,\n",
" lora_dropout = 0, # Supports any, but = 0 is optimized\n",
" bias = \"none\", # Supports any, but = \"none\" is optimized\n",
" # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n",
" use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n",
" random_state = 3407,\n",
" use_rslora = False, # We support rank stabilized LoRA\n",
" loftq_config = None, # And LoftQ\n",
")"
]
},
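{
"cell_type": "markdown",
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000003",
"metadata": {},
"source": [
"After attaching the adapters it is worth confirming how small the trainable fraction is. This assumes the returned model behaves like a regular PEFT model; `print_trainable_parameters()` is a standard PEFT method."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000004",
"metadata": {},
"outputs": [],
"source": [
"# Summarize trainable vs. total parameters (standard PEFT API).\n",
"model.print_trainable_parameters()"
]
},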
{
"cell_type": "code",
"execution_count": null,
"id": "77765bd3-d775-49e5-8160-79caa8e65101",
"metadata": {},
"outputs": [],
"source": [
"from unsloth.chat_templates import get_chat_template\n",
"tokenizer = get_chat_template(\n",
" tokenizer,\n",
" chat_template = \"llama-3.1\",\n",
")"
]
},
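{
"cell_type": "markdown",
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000005",
"metadata": {},
"source": [
"A quick preview of what the `llama-3.1` template renders; the toy messages below are illustrative only, not taken from the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000006",
"metadata": {},
"outputs": [],
"source": [
"# Render a toy conversation with the chat template (no tokenization).\n",
"demo = [\n",
"    {\"role\": \"system\", \"content\": \"You are a good Java error explainer.\"},\n",
"    {\"role\": \"user\", \"content\": \"What does 'cannot find symbol' mean?\"},\n",
"]\n",
"print(tokenizer.apply_chat_template(demo, tokenize = False, add_generation_prompt = True))"
]
},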
{
"cell_type": "markdown",
"id": "9b30812f-5fb4-4202-b035-a9bbed0a3758",
"metadata": {},
"source": [
"## データの準備"
]
},
{
"cell_type": "markdown",
"id": "e40e7dda-5640-4853-bc15-e5f1602379b5",
"metadata": {},
"source": [
"### 読み込み"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3bdef308-62a8-4e8b-822b-4808254d8bda",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"dataset = load_dataset(\"kishida/CompileError-Java-JP-cheerful\", split = \"train\")\n",
"filtered = dataset.filter(lambda x: x[\"description\"] is not None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f0457e5-51f0-48a1-b2cd-6ce663de4251",
"metadata": {},
"outputs": [],
"source": [
"print(len(dataset))\n",
"print(len(filtered))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8e8dfeb-fe34-4dd1-8972-2c26ed02a1f0",
"metadata": {},
"outputs": [],
"source": [
"filtered[3]"
]
},
{
"cell_type": "markdown",
"id": "4530e85c-c9af-4110-bdac-0dc1fa29b84e",
"metadata": {},
"source": [
"### 変換"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bca54d4-85fe-4fc3-98e1-7dbe780c179b",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"filename: {}\n",
"## source\n",
"{}\n",
"\n",
"## error message\n",
"{}\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1dada887-8d30-45ff-8cda-7ddbbb670ff5",
"metadata": {},
"outputs": [],
"source": [
"system_prompt = \"\"\"You are a good Java error explainer.\n",
"Please provide a concise explanation of the compilation error that occurred in the given source code.\n",
"Do not provide the complete code you have fixed.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "696095de-afba-4e85-bcaa-5d23366e6b89",
"metadata": {},
"outputs": [],
"source": [
"def make_conv(example):\n",
" prompt = template.format(example[\"filename\"], example[\"code\"], example[\"compile_message\"])\n",
" return {\"conversations\": [\n",
" {\"content\": system_prompt, \"role\": \"system\"},\n",
" {\"content\": prompt, \"role\": \"user\"},\n",
" {\"content\": example[\"description\"], \"role\": \"assistant\"}]}\n",
"dataset = filtered.map(make_conv, batched = False,)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b6836d5-a354-49f2-8c77-00fc2e3faa26",
"metadata": {},
"outputs": [],
"source": [
"def formatting_prompts_func(examples):\n",
" convos = examples[\"conversations\"]\n",
" texts = [tokenizer.apply_chat_template(conv, tokenize = False, add_generation_prompt = False) for conv in convos]\n",
" return {\"text\": texts, }\n",
"dataset = dataset.map(formatting_prompts_func, batched = True,)"
]
},
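{
"cell_type": "markdown",
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000007",
"metadata": {},
"source": [
"Optionally, a small slice can be held out for manual checks after training. A minimal sketch using the standard `train_test_split` from `datasets`; the 5% ratio and seed are arbitrary choices, and the split is not wired into the trainer below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000008",
"metadata": {},
"outputs": [],
"source": [
"# Optional: hold out a few examples for eyeballing after training.\n",
"# The ratio and seed are arbitrary; the trainer below still uses `dataset`.\n",
"split = dataset.train_test_split(test_size = 0.05, seed = 3407)\n",
"print(len(split[\"train\"]), len(split[\"test\"]))"
]
},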
{
"cell_type": "code",
"execution_count": null,
"id": "b7093231-40ef-4235-9bbb-1441e03d6b2c",
"metadata": {},
"outputs": [],
"source": [
"dataset[3][\"text\"]"
]
},
{
"cell_type": "markdown",
"id": "d38304a8-b44f-42f8-aa23-caccf3523a11",
"metadata": {},
"source": [
"## 学習"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c9db459-f107-4dd7-9697-d29ab64d9453",
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"wandb.login()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b78eeff9-1c3a-4cd9-bf8a-6837a14f338d",
"metadata": {},
"outputs": [],
"source": [
"from trl import SFTConfig, SFTTrainer\n",
"from transformers import DataCollatorForSeq2Seq\n",
"trainer = SFTTrainer(\n",
" model = model,\n",
" tokenizer = tokenizer,\n",
" train_dataset = dataset,\n",
" dataset_text_field = \"text\",\n",
" max_seq_length = max_seq_length,\n",
" data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n",
" packing = False, # Can make training 5x faster for short sequences.\n",
" args = SFTConfig(\n",
" per_device_train_batch_size = 2,\n",
" gradient_accumulation_steps = 4,\n",
" warmup_steps = 5,\n",
" # num_train_epochs = 1, # Set this for 1 full training run.\n",
" max_steps = 60,\n",
" learning_rate = 2e-4,\n",
" logging_steps = 1,\n",
" optim = \"adamw_8bit\",\n",
" weight_decay = 0.001,\n",
" lr_scheduler_type = \"linear\",\n",
" seed = 3407,\n",
" output_dir = \"outputs\",\n",
" report_to = \"wandb\", # Use TrackIO/WandB etc\n",
" ),\n",
")"
]
},
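{
"cell_type": "markdown",
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000009",
"metadata": {},
"source": [
"Unsloth's reference notebooks often also mask the prompt tokens so the loss is computed only on the assistant's reply. A sketch using `train_on_responses_only` from `unsloth.chat_templates`; the marker strings follow Unsloth's llama-3.1 examples and should be verified against the installed version."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000010",
"metadata": {},
"outputs": [],
"source": [
"# Optional: compute loss only on assistant responses (prompt tokens masked).\n",
"# Marker strings follow Unsloth's llama-3.1 example notebooks.\n",
"from unsloth.chat_templates import train_on_responses_only\n",
"trainer = train_on_responses_only(\n",
"    trainer,\n",
"    instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n",
"    response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n",
")"
]
},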
{
"cell_type": "code",
"execution_count": null,
"id": "ed5c3d82-1a7e-4b66-96b1-74fb8ee808e0",
"metadata": {},
"outputs": [],
"source": [
"tokenizer.decode(trainer.train_dataset[5][\"input_ids\"])"
]
},
{
"cell_type": "markdown",
"id": "9e78cf40-fb12-42dd-8059-2b3b3f0bb488",
"metadata": {},
"source": [
"### 学習開始"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "764f3e5d-72be-4367-8c99-d30712ff12fc",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"# @title Show current memory stats\n",
"gpu_stats = torch.cuda.get_device_properties(0)\n",
"start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
"max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
"print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
"print(f\"{start_gpu_memory} GB of memory reserved.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3632ac67-9699-48c9-b8af-b1337de64d08",
"metadata": {},
"outputs": [],
"source": [
"trainer_stats = trainer.train()\n",
"wandb.finish()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d87c1e1-6e42-4e84-8e5e-600ef6fafad5",
"metadata": {},
"outputs": [],
"source": [
"# @title Show final memory and time stats\n",
"used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
"used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
"used_percentage = round(used_memory / max_memory * 100, 3)\n",
"lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
"print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
"print(\n",
" f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",
")\n",
"print(f\"Peak reserved memory = {used_memory} GB.\")\n",
"print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
"print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
"print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
]
},
{
"cell_type": "markdown",
"id": "d659802f-6623-4899-9335-fd44db303421",
"metadata": {},
"source": [
"## 推論"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fdfa15af-b13c-4789-a1f1-e84c8fb39ac3",
"metadata": {},
"outputs": [],
"source": [
"sample = filtered[3]\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": template.format(sample[\"filename\"], sample[\"code\"], sample[\"compile_message\"])}]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e12a6af-79a7-4c83-aede-abc1a8a93b94",
"metadata": {},
"outputs": [],
"source": [
"sample[\"description\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "936d1ff9-fbed-4682-9ed7-60703daacec0",
"metadata": {},
"outputs": [],
"source": [
"FastLanguageModel.for_inference(model) # Enable native 2x faster inference\n",
"inputs = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize = True,\n",
" add_generation_prompt = True, # Must add for generation\n",
" return_tensors = \"pt\",\n",
").to(\"cuda\")\n",
"from transformers import TextStreamer\n",
"text_streamer = TextStreamer(tokenizer, skip_prompt = True)\n",
"_ = model.generate(input_ids = inputs, streamer = text_streamer, max_new_tokens = 128,\n",
" use_cache = True, temperature = 1.5, min_p = 0.1)"
]
},
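{
"cell_type": "markdown",
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000011",
"metadata": {},
"source": [
"A non-streamed variant that returns the reply as a string (a minimal sketch; slicing off the prompt tokens and `skip_special_tokens` are standard Transformers patterns):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000012",
"metadata": {},
"outputs": [],
"source": [
"# Generate without streaming, then decode only the newly generated tokens.\n",
"outputs = model.generate(input_ids = inputs, max_new_tokens = 128,\n",
"                         use_cache = True, temperature = 1.5, min_p = 0.1)\n",
"reply = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens = True)\n",
"print(reply)"
]
},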
{
"cell_type": "markdown",
"id": "2ccf3578-8fdf-4fb0-8893-5244ca92949b",
"metadata": {},
"source": [
"## 保存"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed39e906-d48f-401f-bc62-c5c1a83edadc",
"metadata": {},
"outputs": [],
"source": [
"model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"q4_k_m\")"
]
},
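{
"cell_type": "markdown",
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000013",
"metadata": {},
"source": [
"Besides the GGUF export, the LoRA adapters themselves can be saved with standard `save_pretrained` calls; the directory and repo names below are placeholders."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a2b3c4d-aaaa-4bbb-8ccc-000000000014",
"metadata": {},
"outputs": [],
"source": [
"# Save only the LoRA adapters; the directory name is arbitrary.\n",
"model.save_pretrained(\"lora_model\")\n",
"tokenizer.save_pretrained(\"lora_model\")\n",
"# To publish, set a repo of your own and a token:\n",
"# model.push_to_hub(\"your-name/java-error-explainer-lora\", token = \"hf_...\")"
]
},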
{
"cell_type": "code",
"execution_count": null,
"id": "fc4bab4a-5bd5-4e47-a408-a018a1947029",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}