import marimo

__generated_with = "0.14.17"
app = marimo.App()
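
# LoRA fine-tuning notebook: loads a Qwen3-4B Instruct model with Unsloth,
# formats one or more Hugging Face datasets into plain-text training examples,
# trains with TRL's SFTTrainer, and saves a merged model, all configured
# through marimo UI controls.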


@app.cell
def _():
    import marimo as mo
    import matplotlib.pyplot as plt
    from mofresh import refresh_matplotlib, ImageRefreshWidget
    return ImageRefreshWidget, mo, plt, refresh_matplotlib


@app.cell
def _():
    # Import unsloth before trl/transformers so its performance patches are applied first.
    import unsloth
    import os
    import torch
    import pandas as pd
    import numpy as np
    from trl import SFTConfig, SFTTrainer
    from unsloth import FastLanguageModel
    from dotenv import load_dotenv

    load_dotenv()
    return FastLanguageModel, SFTConfig, SFTTrainer, os, pd, torch


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""# Model Config""")
    return


@app.cell
def _(mo):
    MODEL_NAME = mo.ui.text(value="unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit", label="Model name", full_width=True)
    MODEL_IS_4BIT = mo.ui.checkbox(value=True, label='Load in 4bit')
    MAX_SEQ_LENGTH = mo.ui.number(value=2048, start=64, stop=131072, step=1, label="Max sequence length")
    OUTPUT_MODEL_NAME = mo.ui.text(value="Luau-Qwen3-4B-Instruct-v0.1", label="Output model name", full_width=True)

    mo.vstack([
        MODEL_NAME,
        MODEL_IS_4BIT,
        MAX_SEQ_LENGTH,
        OUTPUT_MODEL_NAME,
    ])
    return MAX_SEQ_LENGTH, MODEL_IS_4BIT, MODEL_NAME, OUTPUT_MODEL_NAME


@app.cell
def _(FastLanguageModel, MAX_SEQ_LENGTH, MODEL_IS_4BIT, MODEL_NAME, os):
    # Load model
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME.value,
        max_seq_length=MAX_SEQ_LENGTH.value,
        load_in_4bit=MODEL_IS_4BIT.value,
        dtype=None,  # None for auto detection
        token=os.getenv("HF_TOKEN"),
    )

    # Do model patching and add fast LoRA weights
    model = FastLanguageModel.get_peft_model(
        model,
        r=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=64,
        lora_dropout=0,  # Dropout = 0 is currently optimized
        bias="none",  # Bias = "none" is currently optimized
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
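
    # Optional sanity check (plain PyTorch, no Unsloth-specific API assumed):
    # after LoRA patching, only the adapter weights should require gradients.
    _trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    _total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {_trainable:,} / {_total:,} ({_trainable / _total:.2%})")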
    return model, tokenizer


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""# Training Data""")
    return


@app.cell
def _(mo):
    num_datasets = mo.ui.number(value=2, label="Number of datasets")
    num_datasets
    return (num_datasets,)


@app.cell
def _(mo, num_datasets):
    dataset_names = mo.ui.array([
        mo.ui.text(placeholder="Dataset name", full_width=True) for _ in range(num_datasets.value)
    ])
    dataset_names
    return (dataset_names,)


@app.cell
def _(dataset_names, mo):
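    # Each entry below is a str.format template applied per dataset row in the
    # next cell; placeholders must match that dataset's column names. A
    # hypothetical template for a prompt/completion dataset:
    #   "### Prompt:\n{prompt}\n\n### Response:\n{completion}"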
    dataset_formatters = mo.ui.array([
        mo.ui.text_area(placeholder="Dataset format string", label=dataset_name, full_width=True, rows=10)
        for dataset_name in dataset_names.value
    ])
    dataset_formatters
    return (dataset_formatters,)


@app.cell
def _(dataset_formatters, dataset_names, os, pd, tokenizer):
    from datasets import load_dataset, concatenate_datasets

    datasets = []
    for i in range(len(dataset_names.value)):
        dataset_name = dataset_names.value[i]
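        # Append the EOS token so the model learns where each example ends;
        # without it, generations tend to run past the intended completion.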
        dataset_formatter = dataset_formatters.value[i] + tokenizer.eos_token
        datasets.append(load_dataset(
            dataset_name,
            split="train",
            token=os.getenv("HF_TOKEN"),
        ).map(
            lambda example: {
                "text": dataset_formatter.format(**example),
                "source": dataset_name,
            },
            batched=False,
        ).select_columns([
            "text", "source"
        ]))

    dataset = concatenate_datasets(datasets).shuffle(seed=42)
    pd.DataFrame(dataset)
    return (dataset,)


@app.cell(hide_code=True)
def _(mo):
    mo.md(r"""# Training""")
    return


@app.cell(hide_code=True)
def _(mo):
    LOGGING_STEPS = mo.ui.number(value=1, start=1, stop=1000, step=1, label="Logging steps")
    SAVE_STEPS = mo.ui.number(value=100, start=1, stop=1000, step=1, label="Save steps")
    EPOCHS = mo.ui.number(value=2, start=0.1, stop=5, step=0.1, label="Epochs")
    MAX_STEPS = mo.ui.number(value=-1, start=-1, stop=50000, step=1, label="Max steps (-1 for unlimited)")
    RESUME_FROM_CHECKPOINT = mo.ui.checkbox(value=True, label="Resume training from latest checkpoint")
    run_button = mo.ui.run_button(label="Start training")

    mo.vstack([
        LOGGING_STEPS,
        SAVE_STEPS,
        EPOCHS,
        MAX_STEPS,
        RESUME_FROM_CHECKPOINT,
        run_button,
    ])
    return (
        EPOCHS,
        LOGGING_STEPS,
        MAX_STEPS,
        RESUME_FROM_CHECKPOINT,
        SAVE_STEPS,
        run_button,
    )


@app.cell(hide_code=True)
def _(mo, torch):
    # Snapshot reserved VRAM before training so the post-training cell can
    # report how much additional memory the run used.
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

    mo.hstack([
        mo.stat(value=f"{max_memory} GB", label="VRAM Total", caption=gpu_stats.name),
        mo.stat(value=f"{start_gpu_memory} GB", label="VRAM Reserved"),
    ], justify="center")
    return gpu_stats, max_memory, start_gpu_memory


@app.cell
def _(
    EPOCHS,
    LOGGING_STEPS,
    MAX_SEQ_LENGTH,
    MAX_STEPS,
    SAVE_STEPS,
    SFTConfig,
    mo,
    run_button,
):
    mo.stop(not run_button.value)

    training_args = SFTConfig(
        output_dir="./outputs/checkpoints",
        max_length=MAX_SEQ_LENGTH.value,
        logging_steps=LOGGING_STEPS.value,
        torch_empty_cache_steps=150,
        save_strategy="steps",
        save_steps=SAVE_STEPS.value,  # Allows us to resume training from the latest checkpoint
        learning_rate=2e-5,  # 2e-5 is a conservative rate suited to long training runs
        dataset_num_proc=1,
        num_train_epochs=EPOCHS.value,
        max_steps=MAX_STEPS.value,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,  # Effective batch size = 1 x 8 = 8
    )
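    # With this effective batch size the run takes roughly len(dataset) / 8
    # optimizer steps per epoch, unless MAX_STEPS (when > -1) caps it earlier.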
    return (training_args,)


@app.cell
def _(plt, refresh_matplotlib):
    # refresh_matplotlib turns this plotting function into one that returns an
    # image source, which the ImageRefreshWidget below displays and updates.
    @refresh_matplotlib
    def loss_linechart(step_loss: dict[int, float]):
        x_values = list(step_loss.keys())
        y_values = list(step_loss.values())
        plt.plot(x_values, y_values)
        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.ylim(min(0.2, min(y_values)), max(y_values) * 1.05)
        plt.title("Training Loss Curve")
    return (loss_linechart,)


@app.cell
def _(ImageRefreshWidget, loss_linechart):
    # Seed the widget with a single dummy point; the training callback below
    # replaces it with the real loss curve as logs arrive.
    widget = ImageRefreshWidget(src=loss_linechart({1: 1}))
    widget
    return (widget,)


@app.cell
def _(
    RESUME_FROM_CHECKPOINT,
    SFTTrainer,
    dataset,
    loss_linechart,
    mo,
    model,
    os,
    run_button,
    training_args,
    widget,
):
    mo.stop(not run_button.value)

    from transformers import TrainerCallback

    class PlotLogs(TrainerCallback):
        def __init__(self):
            self.loss_history: dict[int, float] = {}

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs is not None and "loss" in logs:
                self.loss_history[state.global_step] = logs["loss"]
                widget.src = loss_linechart(self.loss_history)

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        callbacks=[PlotLogs()],
    )

    os.environ["UNSLOTH_RETURN_LOGITS"] = "1"
    # Note: resume_from_checkpoint=True errors out if no checkpoint exists yet
    # under output_dir, so untick the checkbox for a fresh run.
    trainer_stats = trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT.value)
    return (trainer_stats,)


@app.cell(hide_code=True)
def _(
    gpu_stats,
    max_memory,
    mo,
    run_button,
    start_gpu_memory,
    torch,
    trainer_stats,
):
    mo.stop(not run_button.value)

    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    used_percentage = round(used_memory / max_memory * 100, 3)
    lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

    mo.hstack([
        mo.stat(value=f"{round(trainer_stats.metrics['train_runtime'] / 60, 1)} minutes", label="Training runtime"),
        mo.stat(value=f"{max_memory} GB", label="VRAM Total", caption=gpu_stats.name),
        mo.stat(value=f"{used_memory} GB", label="Peak VRAM Reserved"),
        mo.stat(value=f"{used_memory_for_lora} GB", label="Peak VRAM Reserved For Training"),
    ], justify="center")
    return


@app.cell
def _(MODEL_IS_4BIT, OUTPUT_MODEL_NAME, mo, model, run_button, tokenizer):
    mo.stop(not run_button.value)

    # Merge the LoRA adapters into the base weights and save the result locally.
    model.save_pretrained_merged(
        f"outputs/{OUTPUT_MODEL_NAME.value}",
        tokenizer,
        save_method="merged_4bit_forced" if MODEL_IS_4BIT.value else "merged_16bit",
    )
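    # To publish instead of (or in addition to) saving locally, Unsloth offers a
    # push-to-Hub variant; signature assumed to mirror save_pretrained_merged:
    # model.push_to_hub_merged(f"<hf-username>/{OUTPUT_MODEL_NAME.value}", tokenizer,
    #                          save_method="merged_16bit", token=os.getenv("HF_TOKEN"))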
    return


if __name__ == "__main__":
    app.run()
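
# Open this notebook interactively with `marimo edit <notebook file>`, or serve
# it as an app with `marimo run <notebook file>` (substitute the actual filename).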