andrewor14
@andrewor14
andrewor14 / gist:5b85119fae46845d07b608d420907423
Created November 11, 2025 20:54
Unsloth FP8 + GRPO test script
# Modeled after https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb
from unsloth import FastLanguageModel
import gc
import os
import re
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/andrewor/local/unsloth-zoo/unsloth_zoo/vllm_utils.py", line 884, in get_state_dict
[rank0]:     weight = qweight[dim_offsets[kk] : dim_offsets[kk + 1]]
[rank0]:   File "/home/andrewor/local/ao/torchao/utils.py", line 662, in _dispatch__torch_function__
[rank0]:     raise e
[rank0]:   File "/home/andrewor/local/ao/torchao/utils.py", line 659, in _dispatch__torch_function__
[rank0]:     return func(*args, **kwargs)
[rank0]: RuntimeError: Cannot set version_counter for inference tensor
[rank0]: Exception raised from set_version_counter at /pytorch/c10/core/TensorImpl.h:2117 (most recent call first):
[rank0]: C++ CapturedTraceback:
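
The preview jumps from the imports straight to the failure. For reference, a minimal sketch of how the setup presumably proceeds before the crash, modeled on the linked Qwen3 GRPO notebook; the model name, dataset, and toy reward are illustrative assumptions, not the gist's contents:

# Hedged sketch of the GRPO setup, not the gist's actual script.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen3-4B",  # assumption: the 4B model from the notebook title
    max_seq_length=2048,
    fast_inference=True,  # vLLM generation backend; the trace above fails in vllm_utils
    # the gist's FP8 loading flag is not visible in the preview, so it is omitted here
)

def reward_short(completions, **kwargs):
    # Toy reward preferring shorter completions (stand-in for a real verifier).
    return [-float(len(c)) for c in completions]

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[reward_short],
    args=GRPOConfig(
        output_dir="/tmp/grpo_test",
        num_generations=8,
        max_completion_length=256,
        max_steps=10,
    ),
    train_dataset=load_dataset("trl-lib/tldr", split="train"),
)
trainer.train()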
@andrewor14
andrewor14 / gist:378a86485e2bc167f851bfd648625b70
Created October 2, 2025 23:33
UnslothGRPOTrainer excerpt
class _UnslothGRPOTrainer(Trainer):
    ...
    def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
        """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
        # For FSDP1, we need to recurse into children and also use summon_full_params
        if visited is None:
            visited = set()
        for child_name, child_module in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            ...
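
The excerpt cuts off inside the traversal loop. A self-contained sketch of the pattern the docstring describes, post-order recursion plus FSDP.summon_full_params, follows; this is not the TRL source, and the vLLM load_weights call and visited-set handling are assumptions:

# Hedged sketch of the FSDP1-to-vLLM sync pattern described in the docstring above.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def sync_fsdp1_params_to_vllm(module: nn.Module, llm_model, prefix: str = "", visited=None):
    if visited is None:
        visited = set()
    # Post-order: recurse into children first so each parameter is synced once.
    for child_name, child_module in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        sync_fsdp1_params_to_vllm(child_module, llm_model, prefix=child_prefix, visited=visited)
    if isinstance(module, FSDP):
        # Gather only this module's full (unsharded) parameters, read-only.
        with FSDP.summon_full_params(module, recurse=False, writeback=False):
            for param_name, param in module.named_parameters(recurse=False):
                full_name = f"{prefix}.{param_name}" if prefix else param_name
                if full_name not in visited:
                    visited.add(full_name)
                    # assumption: vLLM accepts an iterable of (name, tensor) pairs
                    llm_model.load_weights([(full_name, param.data)])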
from unsloth import FastLanguageModel
import torch
from torchao.quantization import Int4WeightOnlyConfig
from transformers import AutoModelForCausalLM, TextStreamer, TorchAoConfig
qat_scheme = "int4"
save_output_path = "/tmp/unsloth_model"
max_seq_length = 2048
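
Given the AutoModelForCausalLM, TorchAoConfig, and Int4WeightOnlyConfig imports above, the trained QAT model is presumably reloaded with real int4 quantization along these lines; a sketch under that assumption, since the save and reload steps are not visible in the preview:

# Hedged sketch: reload the saved QAT checkpoint with int4 weight-only quantization.
quantization_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128))
quantized_model = AutoModelForCausalLM.from_pretrained(
    save_output_path,
    quantization_config=quantization_config,
    torch_dtype="auto",
    device_map="auto",
)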
@andrewor14
andrewor14 / gist:048b5c1bd01b7fa23c53913856a8ef9f
Created September 5, 2025 20:15
Unsloth QAT full fine-tuning
from unsloth import FastLanguageModel
import torch
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
from transformers import AutoModelForCausalLM, TextStreamer, TorchAoConfig
qat_scheme = "fp8-int4"
save_output_path = "/tmp/unsloth_model"
max_seq_length = 2048
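
The preview stops at the config variables. Assuming qat_scheme is forwarded to FastLanguageModel.from_pretrained, as the variable names in these gists suggest, the load step plausibly looks like this; the model name is an illustrative assumption:

# Hedged sketch: load the model with QAT enabled for full fine-tuning.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-3B-Instruct",  # assumption: any supported base model
    max_seq_length=max_seq_length,
    full_finetuning=True,  # QAT here rides on full fine-tuning, not LoRA
    qat_scheme=qat_scheme,  # "fp8-int4": fp8 dynamic activations, int4 weights
)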
@andrewor14
andrewor14 / gist:b0364ac3cb8aa114e46b39d848fa5c8b
Created August 29, 2025 21:50
Unsloth QAT full finetuning test
# Based on https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb
# but with `full_finetuning=True` and without `get_peft_model`
import os
from unsloth import FastLanguageModel
from transformers import TextStreamer
import torch
max_seq_length = 2048
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B",  # model used in the referenced notebook
    max_seq_length=max_seq_length,
    full_finetuning=True,  # as the comment above notes: full fine-tuning, no get_peft_model
)
@andrewor14
andrewor14 / gist:ab650350b69276cf585c008914aaa146
Last active August 29, 2025 20:30
Repro unsloth full finetuning
# Based on https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb
# but with `full_finetuning=True` and without `get_peft_model`
# Output is at the bottom of the gist
import os
from unsloth import FastLanguageModel
from transformers import TextStreamer
import torch
max_seq_length = 2048
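
The preview again stops at the setup. In the linked Alpaca notebook the run continues with a standard SFT pass; a minimal sketch under that assumption, with the dataset, prompt template, and hyperparameters simplified from the notebook's defaults:

# Hedged sketch, not the gist's contents: Alpaca-style SFT on the model loaded above.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

def to_text(example):
    # Simplified Alpaca prompt template from the notebook.
    return {
        "text": f"### Instruction:\n{example['instruction']}\n\n"
                f"### Input:\n{example['input']}\n\n"
                f"### Response:\n{example['output']}" + tokenizer.eos_token
    }

dataset = load_dataset("yahma/alpaca-cleaned", split="train").map(to_text)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,  # older TRL versions take tokenizer= instead
    train_dataset=dataset,
    args=SFTConfig(
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_steps=60,
        learning_rate=2e-4,
        output_dir="outputs",
    ),
)
trainer.train()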
$ GRADIO_SERVER_NAME="0.0.0.0" python test_sayak.py
/home/andrewor/local/ao/torchao/utils.py:408: UserWarning: TORCH_VERSION_AT_LEAST_2_8 is deprecated and will be removed in torchao 0.14.0
warnings.warn(self.msg)
/home/andrewor/local/ao/torchao/utils.py:408: UserWarning: TORCH_VERSION_AT_LEAST_2_7 is deprecated and will be removed in torchao 0.14.0
warnings.warn(self.msg)
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 65.28it/s]
Step 1: Applying QAT observers to the model...
/home/andrewor/local/ao/torchao/quantization/qat/utils.py:84: UserWarning: 'FakeQuantizeConfig' is deprecated and will be removed in a future release. Please use the following API instead:
batch_size: 16
batch_size_val: 8
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-3B-Instruct/
  checkpoint_files:
    - model-00001-of-00002.safetensors
    - model-00002-of-00002.safetensors
  model_type: LLAMA3_2
  output_dir: /home/andrewor/local/logs/tune/Llama3.2-3B_qat
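
This reads as a torchtune checkpointer config for a Llama 3.2 3B QAT run. A config like this is launched with the tune CLI; the process count and config path below are illustrative:

$ tune run --nproc_per_node 2 full_finetune_distributed --config <path-to-this-config>.yaml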
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/andrewor/local/torchtune/recipes/full_finetune_distributed.py", line 982, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:   File "/home/andrewor/local/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:   File "/home/andrewor/local/torchtune/recipes/full_finetune_distributed.py", line 977, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/home/andrewor/local/torchtune/recipes/full_finetune_distributed.py", line 810, in train
[rank0]:     logits = self._model(**batch)
[rank0]:   File "/home/andrewor/local/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl