@boatbomber
Created August 20, 2025 00:51
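# Marimo notebook for LoRA fine-tuning a Qwen3-4B Instruct model with Unsloth and
# TRL's SFTTrainer, with training data pulled from Hugging Face datasets configured
# in the notebook UI. Assuming a standard marimo install, it can be opened
# interactively with `marimo edit <this_file>.py` or executed as a script with
# `python <this_file>.py`.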
import marimo
__generated_with = "0.14.17"
app = marimo.App()

@app.cell
def _():
    import marimo as mo
    import matplotlib.pylab as plt
    from mofresh import refresh_matplotlib, ImageRefreshWidget
    return ImageRefreshWidget, mo, plt, refresh_matplotlib

@app.cell
def _():
    import unsloth
    import os
    import torch
    import pandas as pd
    import numpy as np
    from trl import SFTConfig, SFTTrainer
    from unsloth import FastLanguageModel
    from dotenv import load_dotenv
    load_dotenv()
    return FastLanguageModel, SFTConfig, SFTTrainer, os, pd, torch

@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""# Model Config""")
    return

@app.cell
def _(mo):
    MODEL_NAME = mo.ui.text(value="unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit", label="Model name", full_width=True)
    MODEL_IS_4BIT = mo.ui.checkbox(value=True, label="Load in 4bit")
    MAX_SEQ_LENGTH = mo.ui.number(value=2048, start=64, stop=131072, step=1, label="Max sequence length")
    OUTPUT_MODEL_NAME = mo.ui.text(value="Luau-Qwen3-4B-Instruct-v0.1", label="Output model name", full_width=True)
    mo.vstack([
        MODEL_NAME,
        MODEL_IS_4BIT,
        MAX_SEQ_LENGTH,
        OUTPUT_MODEL_NAME,
    ])
    return MAX_SEQ_LENGTH, MODEL_IS_4BIT, MODEL_NAME, OUTPUT_MODEL_NAME

@app.cell
def _(FastLanguageModel, MAX_SEQ_LENGTH, MODEL_IS_4BIT, MODEL_NAME, os):
    # Load model
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME.value,
        max_seq_length=MAX_SEQ_LENGTH.value,
        load_in_4bit=MODEL_IS_4BIT.value,
        dtype=None,  # None for auto detection
        token=os.getenv("HF_TOKEN"),
    )
    # Do model patching and add fast LoRA weights
    model = FastLanguageModel.get_peft_model(
        model,
        r=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=64,
        lora_dropout=0,  # Dropout = 0 is currently optimized
        bias="none",  # Bias = "none" is currently optimized
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    return model, tokenizer
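
# Note: with r=64 and lora_alpha=64, the LoRA scaling factor (lora_alpha / r) is
# 1.0, so adapter updates are applied at the same scale as the base weights.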

@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""# Training Data""")
    return

@app.cell
def _(mo):
    num_datasets = mo.ui.number(value=2, label="Number of datasets")
    num_datasets
    return (num_datasets,)

@app.cell
def _(mo, num_datasets):
    dataset_names = mo.ui.array([
        mo.ui.text(placeholder="Dataset name", full_width=True) for _ in range(num_datasets.value)
    ])
    dataset_names
    return (dataset_names,)

@app.cell
def _(dataset_names, mo):
    dataset_formatters = mo.ui.array([
        mo.ui.text_area(placeholder="Dataset format string", label=dataset_name, full_width=True, rows=10)
        for dataset_name in dataset_names.value
    ])
    dataset_formatters
    return (dataset_formatters,)
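
# A hypothetical example of a format string for a dataset with "instruction" and
# "response" columns (placeholder names must match the chosen dataset's columns,
# since each row is rendered with str.format(**example) in the next cell):
#
#   ### Instruction:
#   {instruction}
#
#   ### Response:
#   {response}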

@app.cell
def _(dataset_formatters, dataset_names, os, pd, tokenizer):
    from datasets import load_dataset, concatenate_datasets

    datasets = []
    for i in range(len(dataset_names.value)):
        dataset_name = dataset_names.value[i]
        dataset_formatter = dataset_formatters.value[i] + tokenizer.eos_token
        datasets.append(load_dataset(
            dataset_name,
            split="train",
            token=os.getenv("HF_TOKEN"),
        ).map(
            lambda example: {
                "text": dataset_formatter.format(**example),
                "source": dataset_name,
            },
            batched=False,
        ).select_columns([
            "text", "source"
        ]))
    dataset = concatenate_datasets(datasets).shuffle(seed=42)
    pd.DataFrame(dataset)
    return (dataset,)

@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""# Training""")
    return

@app.cell(hide_code=True)
def _(mo):
    LOGGING_STEPS = mo.ui.number(value=1, start=1, stop=1000, step=1, label="Logging steps")
    SAVE_STEPS = mo.ui.number(value=100, start=1, stop=1000, step=1, label="Save steps")
    EPOCHS = mo.ui.number(value=2, start=0.1, stop=5, step=0.1, label="Epochs")
    MAX_STEPS = mo.ui.number(value=-1, start=-1, stop=50000, step=1, label="Max steps (-1 for unlimited)")
    RESUME_FROM_CHECKPOINT = mo.ui.checkbox(value=True, label="Resume training from latest checkpoint")
    run_button = mo.ui.run_button(label="Start training")
    mo.vstack([
        LOGGING_STEPS,
        SAVE_STEPS,
        EPOCHS,
        MAX_STEPS,
        RESUME_FROM_CHECKPOINT,
        run_button,
    ])
    return (
        EPOCHS,
        LOGGING_STEPS,
        MAX_STEPS,
        RESUME_FROM_CHECKPOINT,
        SAVE_STEPS,
        run_button,
    )

@app.cell(hide_code=True)
def _(mo, torch):
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    mo.hstack([
        mo.stat(value=f"{max_memory} GB", label="VRAM Total", caption=gpu_stats.name),
        mo.stat(value=f"{start_gpu_memory} GB", label="VRAM Reserved"),
    ], justify="center")
    return gpu_stats, max_memory, start_gpu_memory
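
# torch.cuda.max_memory_reserved() and gpu_stats.total_memory report bytes;
# dividing by 1024**3 converts them to GiB (displayed as "GB" in these stats).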

@app.cell
def _(
    EPOCHS,
    LOGGING_STEPS,
    MAX_SEQ_LENGTH,
    MAX_STEPS,
    SAVE_STEPS,
    SFTConfig,
    mo,
    run_button,
):
    mo.stop(not run_button.value)
    training_args = SFTConfig(
        output_dir="./outputs/checkpoints",
        max_length=MAX_SEQ_LENGTH.value,
        logging_steps=LOGGING_STEPS.value,
        torch_empty_cache_steps=150,
        save_strategy="steps",
        save_steps=SAVE_STEPS.value,  # Allows us to resume training from the latest checkpoint
        learning_rate=2e-5,  # Lowered to 2e-5 for long training runs
        dataset_num_proc=1,
        num_train_epochs=EPOCHS.value,
        max_steps=MAX_STEPS.value,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # Gradient accumulation mimics a larger batch size
    )
    return (training_args,)
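
# With per_device_train_batch_size=1 and gradient_accumulation_steps=8, the
# effective batch size is 1 * 8 = 8 sequences per optimizer step (per device).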

@app.cell
def _(plt, refresh_matplotlib):
    @refresh_matplotlib
    def loss_linechart(step_loss: dict[int, float]):
        x_values = list(step_loss.keys())
        y_values = list(step_loss.values())
        plt.plot(x_values, y_values)
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.ylim(min(0.2, min(y_values)), max(y_values) * 1.05)
        plt.title("Training Loss Curve")
    return (loss_linechart,)

@app.cell
def _(ImageRefreshWidget, loss_linechart):
    widget = ImageRefreshWidget(src=loss_linechart({1: 1}))
    widget
    return (widget,)

@app.cell
def _(
    RESUME_FROM_CHECKPOINT,
    SFTTrainer,
    dataset,
    loss_linechart,
    mo,
    model,
    os,
    run_button,
    training_args,
    widget,
):
    mo.stop(not run_button.value)
    from transformers import TrainerCallback

    class PlotLogs(TrainerCallback):
        """Pushes the running loss curve into the ImageRefreshWidget on every log step."""

        def __init__(self):
            self.loss_history: dict[int, float] = {}

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs is not None and "loss" in logs:
                self.loss_history[state.global_step] = logs["loss"]
                widget.src = loss_linechart(self.loss_history)

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        callbacks=[PlotLogs()],
    )
    os.environ["UNSLOTH_RETURN_LOGITS"] = "1"
    trainer_stats = trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT.value)
    return (trainer_stats,)
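
# trainer.train() returns a TrainOutput whose .metrics dict includes
# "train_runtime" in seconds; the summary cell below converts it to minutes.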

@app.cell(hide_code=True)
def _(
    gpu_stats,
    max_memory,
    mo,
    run_button,
    start_gpu_memory,
    torch,
    trainer_stats,
):
    mo.stop(not run_button.value)
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    used_percentage = round(used_memory / max_memory * 100, 3)
    lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
    mo.hstack([
        mo.stat(value=f"{round(trainer_stats.metrics['train_runtime'] / 60, 1)} minutes", label="Training runtime"),
        mo.stat(value=f"{max_memory} GB", label="VRAM Total", caption=gpu_stats.name),
        mo.stat(value=f"{used_memory} GB", label="Peak VRAM Reserved"),
        mo.stat(value=f"{used_memory_for_lora} GB", label="Peak VRAM Reserved For Training"),
    ], justify="center")
    return

@app.cell
def _(MODEL_IS_4BIT, OUTPUT_MODEL_NAME, mo, model, run_button, tokenizer):
    mo.stop(not run_button.value)
    # Merge the LoRA adapters into the base model and save the result
    model.save_pretrained_merged(
        f"outputs/{OUTPUT_MODEL_NAME.value}",
        tokenizer,
        save_method="merged_4bit_forced" if MODEL_IS_4BIT.value else "merged_16bit",
    )
    return

if __name__ == "__main__":
    app.run()