Created
November 25, 2025 02:02
-
-
Save kishida/504d2bde149cc571bae554aa555c9612 to your computer and use it in GitHub Desktop.
明るくJavaエラーを説明するLLMモデル
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "df7c3391-e202-41a9-b206-2985f80b8e57", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Run configuration\n", | |
| "model_name = \"unsloth/gemma-3-4b-it\"\n", | |
| "# model_name = \"unsloth/Llama-3.2-3B-Instruct\" # alternative; would also need a chat template other than \"gemma-3\"\n", | |
| "max_seq_length = 2048 # maximum tokens per training sequence" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "6296d8ac-dbca-4c9c-bbd0-41ebb075e467", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from unsloth import FastModel\n", | |
| "# Take the model name and context length from the config cell so editing them there takes effect.\n", | |
| "model, tokenizer = FastModel.from_pretrained(\n", | |
| "    model_name = model_name,\n", | |
| "    max_seq_length = max_seq_length, # also passed to SFTTrainer below\n", | |
| "    load_in_4bit = True, # 4-bit quantization to reduce memory\n", | |
| "    load_in_8bit = False, # a bit more accurate than 4-bit, but uses ~2x memory\n", | |
| "    full_finetuning = False, # LoRA adapters are attached in the next cell instead\n", | |
| "    # token = \"hf_...\", # use one if using gated models\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "73711cac-2230-4ec5-95c7-4fc4c791769d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "model = FastModel.get_peft_model(\n", | |
| "    model,\n", | |
| "    # LoRA hyperparameters\n", | |
| "    r = 8, # adapter rank: larger = more capacity, but may overfit\n", | |
| "    lora_alpha = 8, # recommended: alpha >= r\n", | |
| "    lora_dropout = 0,\n", | |
| "    bias = \"none\",\n", | |
| "    random_state = 3407, # same seed value as SFTConfig(seed = 3407)\n", | |
| "    # Where adapters go: text-only task, so the vision layers get none\n", | |
| "    finetune_vision_layers = False,\n", | |
| "    finetune_language_layers = True,\n", | |
| "    finetune_attention_modules = True,\n", | |
| "    finetune_mlp_modules = True,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "77765bd3-d775-49e5-8160-79caa8e65101", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from unsloth.chat_templates import get_chat_template\n", | |
| "# Render conversations in Gemma-3 turn format whenever tokenizer.apply_chat_template is called.\n", | |
| "tokenizer = get_chat_template(tokenizer, chat_template = \"gemma-3\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "9b30812f-5fb4-4202-b035-a9bbed0a3758", | |
| "metadata": {}, | |
| "source": [ | |
| "## データの準備" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e40e7dda-5640-4853-bc15-e5f1602379b5", | |
| "metadata": {}, | |
| "source": [ | |
| "### 読み込み" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "3bdef308-62a8-4e8b-822b-4808254d8bda", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from datasets import load_dataset\n", | |
| "dataset = load_dataset(\n", | |
| "    \"kishida/CompileError-Java-JP-cheerful\",\n", | |
| "    split = \"train\",\n", | |
| ")\n", | |
| "# Rows without a description have no target answer for make_conv, so drop them.\n", | |
| "filtered = dataset.filter(lambda row: row[\"description\"] is not None)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "8f0457e5-51f0-48a1-b2cd-6ce663de4251", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Row counts before / after dropping rows without a description.\n", | |
| "# (Run this before the cells below that rebind `dataset`.)\n", | |
| "print(len(dataset), len(filtered), sep=\"\\n\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "b8e8dfeb-fe34-4dd1-8972-2c26ed02a1f0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Peek at one raw row; make_conv below uses its filename, code, compile_message and description fields.\n", | |
| "filtered[3]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "4530e85c-c9af-4110-bdac-0dc1fa29b84e", | |
| "metadata": {}, | |
| "source": [ | |
| "### 変換" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "1bca54d4-85fe-4fc3-98e1-7dbe780c179b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# User-turn layout: file name, Java source, then the compiler's error message.\n", | |
| "template = (\n", | |
| "    \"filename: {}\\n\"\n", | |
| "    \"## source\\n\"\n", | |
| "    \"{}\\n\"\n", | |
| "    \"\\n\"\n", | |
| "    \"## error message\\n\"\n", | |
| "    \"{}\\n\"\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "1dada887-8d30-45ff-8cda-7ddbbb670ff5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# System instruction shared by the training conversations and the inference prompt.\n", | |
| "system_prompt = (\n", | |
| "    \"You are a good Java error explainer.\\n\"\n", | |
| "    \"Please provide a concise explanation of the compilation error that occurred in the given source code.\\n\"\n", | |
| "    \"Do not provide the complete code you have fixed.\\n\"\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "696095de-afba-4e85-bcaa-5d23366e6b89", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def make_conv(example):\n", | |
| "    \"\"\"Turn one dataset row into a system / user / assistant conversation.\"\"\"\n", | |
| "    user_prompt = template.format(\n", | |
| "        example[\"filename\"], example[\"code\"], example[\"compile_message\"]\n", | |
| "    )\n", | |
| "    conversation = [\n", | |
| "        {\"content\": system_prompt, \"role\": \"system\"},\n", | |
| "        {\"content\": user_prompt, \"role\": \"user\"},\n", | |
| "        {\"content\": example[\"description\"], \"role\": \"assistant\"},\n", | |
| "    ]\n", | |
| "    return {\"conversations\": conversation}\n", | |
| "\n", | |
| "dataset = filtered.map(make_conv, batched = False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "3b6836d5-a354-49f2-8c77-00fc2e3faa26", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def formatting_prompts_func(examples):\n", | |
| "    \"\"\"Render each conversation with the chat template into one training string.\n", | |
| "\n", | |
| "    Unsloth's Gemma-3 recipe strips the leading \"<bos>\" produced by the chat template,\n", | |
| "    because the tokenizer adds its own BOS again when SFTTrainer tokenizes \"text\";\n", | |
| "    keeping both would start every example with two BOS tokens. removeprefix() is a\n", | |
| "    no-op when the text does not start with \"<bos>\".\n", | |
| "    \"\"\"\n", | |
| "    texts = [\n", | |
| "        tokenizer.apply_chat_template(conv, tokenize = False, add_generation_prompt = False)\n", | |
| "        .removeprefix(\"<bos>\")\n", | |
| "        for conv in examples[\"conversations\"]\n", | |
| "    ]\n", | |
| "    return {\"text\": texts}\n", | |
| "\n", | |
| "dataset = dataset.map(formatting_prompts_func, batched = True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "b7093231-40ef-4235-9bbb-1441e03d6b2c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Inspect one fully rendered training string (the \"text\" column SFTTrainer tokenizes).\n", | |
| "dataset[3][\"text\"]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d38304a8-b44f-42f8-aa23-caccf3523a11", | |
| "metadata": {}, | |
| "source": [ | |
| "## 学習" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "3c9db459-f107-4dd7-9697-d29ab64d9453", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import wandb\n", | |
| "# Log in to Weights & Biases; the trainer below logs to it via report_to = \"wandb\".\n", | |
| "wandb.login()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "b78eeff9-1c3a-4cd9-bf8a-6837a14f338d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from trl import SFTConfig, SFTTrainer\n", | |
| "from transformers import DataCollatorForSeq2Seq\n", | |
| "# Supervised fine-tuning on the rendered \"text\" column.\n", | |
| "# NOTE(review): no response-only masking is applied here, so the loss presumably also covers the\n", | |
| "# system/user prompt tokens; consider unsloth's train_on_responses_only if only the explanation should be learned.\n", | |
| "trainer = SFTTrainer(\n", | |
| " model = model,\n", | |
| " tokenizer = tokenizer,\n", | |
| " train_dataset = dataset,\n", | |
| " dataset_text_field = \"text\",\n", | |
| " max_seq_length = max_seq_length,\n", | |
| " data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n", | |
| " packing = False, # packing disabled; enabling it can speed up training on short sequences\n", | |
| " args = SFTConfig(\n", | |
| " # Effective batch = 2 per device x 4 accumulation steps = 8 sequences per optimizer step.\n", | |
| " per_device_train_batch_size = 2,\n", | |
| " gradient_accumulation_steps = 4,\n", | |
| " warmup_steps = 5,\n", | |
| " # num_train_epochs = 1, # Set this for 1 full training run.\n", | |
| " max_steps = 250, # 250 x 8 = 2,000 sequences; compare with len(filtered) to see how much of an epoch this is\n", | |
| " learning_rate = 2e-4,\n", | |
| " logging_steps = 1,\n", | |
| " optim = \"adamw_8bit\",\n", | |
| " weight_decay = 0.001,\n", | |
| " lr_scheduler_type = \"linear\",\n", | |
| " seed = 3407,\n", | |
| " output_dir = \"outputs\",\n", | |
| " report_to = \"wandb\", # Use TrackIO/WandB etc\n", | |
| " ),\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "ed5c3d82-1a7e-4b66-96b1-74fb8ee808e0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sanity check: decode one tokenized training example to see exactly what the model is trained on\n", | |
| "# (e.g. how many <bos> tokens it starts with).\n", | |
| "tokenizer.decode(trainer.train_dataset[5][\"input_ids\"])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "9e78cf40-fb12-42dd-8059-2b3b3f0bb488", | |
| "metadata": {}, | |
| "source": [ | |
| "### 学習開始" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "764f3e5d-72be-4367-8c99-d30712ff12fc", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# @title Show current memory stats\n", | |
| "import torch\n", | |
| "\n", | |
| "gpu_stats = torch.cuda.get_device_properties(0)\n", | |
| "# Baseline (GiB) before training; the final stats cell subtracts it from the peak.\n", | |
| "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)\n", | |
| "max_memory = round(gpu_stats.total_memory / 1024**3, 3)\n", | |
| "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", | |
| "print(f\"{start_gpu_memory} GB of memory reserved.\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "3632ac67-9699-48c9-b8af-b1337de64d08", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Run fine-tuning; trainer_stats.metrics['train_runtime'] is reported in the final stats cell.\n", | |
| "trainer_stats = trainer.train()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "3ae3de7d-b8bf-4b8a-bac5-c6bdb41ddf6e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# End the Weights & Biases run used for the training logs.\n", | |
| "wandb.finish()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "0d87c1e1-6e42-4e84-8e5e-600ef6fafad5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# @title Show final memory and time stats\n", | |
| "# Peak reserved memory (GiB) and the part added during training (peak minus the pre-training baseline).\n", | |
| "used_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)\n", | |
| "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", | |
| "used_percentage = round(used_memory / max_memory * 100, 3)\n", | |
| "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", | |
| "\n", | |
| "print(\n", | |
| "    f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\",\n", | |
| "    f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\",\n", | |
| "    f\"Peak reserved memory = {used_memory} GB.\",\n", | |
| "    f\"Peak reserved memory for training = {used_memory_for_lora} GB.\",\n", | |
| "    f\"Peak reserved memory % of max memory = {used_percentage} %.\",\n", | |
| "    f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\",\n", | |
| "    sep=\"\\n\",\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d659802f-6623-4899-9335-fd44db303421", | |
| "metadata": {}, | |
| "source": [ | |
| "## 推論" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "fdfa15af-b13c-4789-a1f1-e84c8fb39ac3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# NOTE: filtered[3] is also part of the training data, so this checks fit rather than generalization.\n", | |
| "sample = filtered[3]\n", | |
| "messages = [\n", | |
| "    {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_prompt}]},\n", | |
| "    {\n", | |
| "        \"role\": \"user\",\n", | |
| "        \"content\": [{\n", | |
| "            \"type\": \"text\",\n", | |
| "            \"text\": template.format(sample[\"filename\"], sample[\"code\"], sample[\"compile_message\"]),\n", | |
| "        }],\n", | |
| "    },\n", | |
| "]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5e12a6af-79a7-4c83-aede-abc1a8a93b94", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Reference explanation from the dataset, to compare with the generated one below.\n", | |
| "sample[\"description\"]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "936d1ff9-fbed-4682-9ed7-60703daacec0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from transformers import TextStreamer\n", | |
| "\n", | |
| "# FastModel.for_inference(model) # Enable native 2x faster inference\n", | |
| "inputs = tokenizer.apply_chat_template(\n", | |
| "    messages,\n", | |
| "    tokenize = True,\n", | |
| "    # Append the model-turn header; without it the prompt ends after the user turn\n", | |
| "    # and the model is not cued to start its answer.\n", | |
| "    add_generation_prompt = True,\n", | |
| "    return_tensors = \"pt\",\n", | |
| ").to(\"cuda\")\n", | |
| "text_streamer = TextStreamer(tokenizer, skip_prompt = True) # stream only the newly generated text\n", | |
| "_ = model.generate(input_ids = inputs, streamer = text_streamer,\n", | |
| "                   max_new_tokens = 128, # raise if explanations get cut off\n", | |
| "                   use_cache = True, temperature = 1.5, min_p = 0.1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "2ccf3578-8fdf-4fb0-8893-5244ca92949b", | |
| "metadata": {}, | |
| "source": [ | |
| "## 保存" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "ed39e906-d48f-401f-bc62-c5c1a83edadc", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Export the fine-tuned model as GGUF with q4_k_m quantization under \"gemma3-cheerful\".\n", | |
| "model.save_pretrained_gguf(\"gemma3-cheerful\", tokenizer, quantization_method = \"q4_k_m\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "fc4bab4a-5bd5-4e47-a408-a018a1947029", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.13" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment