@googlefan256
Created May 23, 2025 04:57
Qwen3 0.6B inference (interactive chat) with tinygrad.
from tinygrad import Tensor, nn, dtypes, TinyJit, Variable, Context
from tinygrad.dtype import DType
from transformers import AutoTokenizer
from typing import Tuple, Optional, List, Dict, Generator
from huggingface_hub import hf_hub_download
import time
from tqdm import tqdm
Tensor.no_grad = True
class Qwen3RMSNorm:
def __init__(self, dim: int, eps=1e-6):
self.eps = eps
self.weight = Tensor.ones(dim)
def __call__(self, x: Tensor) -> Tensor:
input_dtype = x.dtype
x = x.cast(dtypes.float)
variance = x.pow(2).mean(-1, keepdim=True)
x = x * (variance + self.eps).rsqrt()
return self.weight * x.cast(input_dtype)
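# RMSNorm as used throughout Qwen3: no mean subtraction, just scale by the
# inverse root-mean-square of the last axis and by a learned weight. The
# reduction is done in float32 and cast back to the input dtype:
#   y = weight * x / sqrt(mean(x**2, axis=-1, keepdim=True) + eps)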
class Qwen3MLP:
def __init__(self, dim: int, ffn_dim: int):
self.gate_proj = nn.Linear(dim, ffn_dim, bias=False)
self.up_proj = nn.Linear(dim, ffn_dim, bias=False)
self.down_proj = nn.Linear(ffn_dim, dim, bias=False)
def __call__(self, x: Tensor) -> Tensor:
down_proj = self.down_proj(self.gate_proj(x).silu() * self.up_proj(x))
return down_proj
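# SwiGLU feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x)).
# For Qwen3-0.6B this is dim=1024 -> ffn_dim=3072 -> dim=1024 (see main()).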
def rotate_half(x: Tensor):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return Tensor.cat(-x2, x1, dim=-1)
def apply_rotary_pos_emb(
q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, unsqueeze_dim=1
):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
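# Rotary position embedding (RoPE). rotate_half splits the head dimension into
# halves (x1, x2) and returns (-x2, x1), so queries and keys are rotated as
#   q' = q * cos + rotate_half(q) * sin
# with cos/sin broadcast over the head axis via unsqueeze_dim=1.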
def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states.unsqueeze(2).expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
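# Grouped-query attention helper: each of the kv_heads key/value heads serves
# n_rep = att_heads // kv_heads query heads, so the cached tensors are expanded
# from (batch, kv_heads, seq, head_dim) to (batch, kv_heads * n_rep, seq, head_dim)
# before the attention matmuls.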
class Qwen3Attention:
def __init__(
self,
dim: int,
kv_heads: int,
head_dim: int,
att_heads: int,
ctx_len: int,
):
self.q_proj = nn.Linear(dim, att_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(att_heads * head_dim, dim, bias=False)
self.q_norm = Qwen3RMSNorm(head_dim)
self.k_norm = Qwen3RMSNorm(head_dim)
self.att_heads = att_heads
self.head_dim = head_dim
self.kv_heads = kv_heads
self.scaling: float = head_dim**-0.5
self.k_cache = None
self.v_cache = None
self.ctx_len = ctx_len
def __call__(
self,
x: Tensor,
position_embeddings: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor],
real_len: int,
) -> Tensor:
input_shape = x.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(x).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(x).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(x).view(hidden_shape).transpose(1, 2)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, *position_embeddings
)
if self.k_cache is None or self.k_cache.shape[0] != x.shape[0]:
self.k_cache = (
Tensor.zeros(
x.shape[0],
self.kv_heads,
self.ctx_len,
self.head_dim,
dtype=x.dtype,
)
.contiguous()
.realize()
)
if self.v_cache is None or self.v_cache.shape[0] != x.shape[0]:
self.v_cache = (
Tensor.zeros(
x.shape[0],
self.kv_heads,
self.ctx_len,
self.head_dim,
dtype=x.dtype,
)
.contiguous()
.realize()
)
if x.shape[1] > 1:
self.k_cache[:, :, 0:real_len, :].assign(key_states).realize()
self.v_cache[:, :, 0:real_len, :].assign(value_states).realize()
else:
self.k_cache[:, :, real_len - 1 : real_len, :].assign(key_states).realize()
self.v_cache[:, :, real_len - 1 : real_len, :].assign(
value_states
).realize()
key_states = self.k_cache[:, :, 0:real_len, :]
value_states = self.v_cache[:, :, 0:real_len, :]
key_states = repeat_kv(key_states, self.att_heads // self.kv_heads)
value_states = repeat_kv(value_states, self.att_heads // self.kv_heads)
attn_weights = (query_states @ key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights.softmax(axis=-1, dtype=dtypes.float).cast(
query_states.dtype
)
attn_output = attn_weights @ value_states
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
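# KV cache: one preallocated (batch, kv_heads, ctx_len, head_dim) buffer per
# projection and per layer. A multi-token call (prefill) writes positions
# [0, real_len); a single-token call (decode) writes position real_len - 1.
# Attention then runs over the first real_len cached positions, with the
# softmax computed in float32.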
class Qwen3Block:
def __init__(
self,
dim: int,
kv_heads: int,
head_dim: int,
ffn_dim: int,
att_heads: int,
ctx_len: int,
):
self.self_attn = Qwen3Attention(dim, kv_heads, head_dim, att_heads, ctx_len)
self.mlp = Qwen3MLP(dim, ffn_dim)
self.input_layernorm = Qwen3RMSNorm(dim)
self.post_attention_layernorm = Qwen3RMSNorm(dim)
def __call__(
self,
x: Tensor,
position_embeddings: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor],
real_len: int,
) -> Tensor:
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
x,
position_embeddings,
attention_mask,
real_len,
)
x = residual + x
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
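# Pre-norm transformer block: RMSNorm -> attention -> residual add, then
# RMSNorm -> SwiGLU MLP -> residual add.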
class Qwen3RotaryEmbedding:
def __init__(self, rope_theta: int, head_dim: int):
self.inv_freq = 1.0 / (
rope_theta
** (
Tensor.arange(0, head_dim, 2, dtype=dtypes.int64).cast(dtypes.float)
/ head_dim
)
)
def __call__(self, x: Tensor, position_ids: Tensor):
inv_freq_expanded = (
self.inv_freq[None, :, None]
.cast(dtypes.float)
.expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].cast(dtypes.float)
freqs = (
inv_freq_expanded.cast(dtypes.float)
@ position_ids_expanded.cast(dtypes.float)
).transpose(1, 2)
emb = Tensor.cat(freqs, freqs, dim=-1)
return emb.cos().cast(x.dtype), emb.sin().cast(x.dtype)
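# inv_freq[i] = 1 / rope_theta ** (2i / head_dim). The outer product with the
# position ids gives the rotation angles, which are concatenated with
# themselves so cos/sin line up with the two halves used by rotate_half.
# Returned cos/sin have shape (batch, seq_len, head_dim).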
def compute_attention_mask(
dtype: DType,
source_length: int,
target_length: int,
position_ids: Tensor,
batch_size: int,
) -> Optional[Tensor]:
if source_length == 1:
return None  # Masking is skipped for single-token decode: it currently hits an indexing error with BEAM enabled.
causal_mask = Tensor.full(
(source_length, target_length), fill_value=-100, dtype=dtype
)
diagonal_attend_mask = Tensor.arange(0, target_length) > position_ids.reshape(-1, 1)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)
return causal_mask
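# Additive causal mask of shape (batch, 1, source_length, target_length):
# future positions get -100 (a large finite penalty used in place of -inf),
# allowed positions get 0. For single-token decode it returns None; causality
# is already guaranteed there because the cache is sliced to real_len.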
class Qwen3Model:
def __init__(
self,
num_layers: int,
dim: int,
ffn_dim: int,
kv_heads: int,
head_dim: int,
vocab_size: int,
rope_theta: int,
att_heads: int,
ctx_len: int,
):
self.embed_tokens = nn.Embedding(vocab_size, dim)
self.rotary_emb = Qwen3RotaryEmbedding(rope_theta, head_dim)
self.layers = [
Qwen3Block(dim, kv_heads, head_dim, ffn_dim, att_heads, ctx_len)
for _ in range(num_layers)
]
self.norm = Qwen3RMSNorm(dim)
def __call__(
self,
x: Tensor,
real_len: int,
) -> Tensor:
x = self.embed_tokens(x)
position_ids = (
Tensor.arange(0, x.shape[1])
if x.shape[1] > 1
else Tensor.arange(real_len - 1, real_len)
)
attention_mask = compute_attention_mask(
x.dtype,
x.shape[1],
1 + real_len,
position_ids,
x.shape[0],
)
position_embeddings = self.rotary_emb(x, position_ids.unsqueeze(0))
for layer in self.layers:
x = layer(x, position_embeddings, attention_mask, real_len)
x = self.norm(x)
return x
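# Forward pass: embed the tokens, build position ids (absolute positions for a
# multi-token prefill, real_len - 1 for single-token decode), compute the
# rotary cos/sin once, run every block with the shared causal mask, then apply
# the final norm.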
class Qwen3ModelForCausalLM:
def __init__(
self,
num_layers: int,
dim: int,
ffn_dim: int,
kv_heads: int,
head_dim: int,
vocab_size: int,
rope_theta: int,
att_heads: int,
ctx_len: int,
):
self.model = Qwen3Model(
num_layers,
dim,
ffn_dim,
kv_heads,
head_dim,
vocab_size,
rope_theta,
att_heads,
ctx_len,
)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
self.ctx_len = ctx_len
def __call__(
self,
x: Tensor,
real_len: int,
) -> Tensor:
x = self.model(x, real_len)
x = self.lm_head(x[:, -1, :])
return x
@TinyJit
@Context(BEAM=2)
def fused_inference(self, tensor: Tensor, real_len: int) -> Tensor:
logits = self(tensor, real_len)
return sample(logits)
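# fused_inference is captured by TinyJit after its warm-up calls, so the
# per-token compute graph is compiled once and replayed. real_len is passed as
# a bound tinygrad Variable, letting the captured kernels handle any position
# without recompiling, and Context(BEAM=2) turns on tinygrad's beam search over
# kernel optimizations during compilation.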
def sample(
logits: Tensor,
temperature=0.7,
top_p=0.8,
top_k=20,
) -> Tensor:
if temperature < 1e-6:
return logits.argmax()
logits = logits.flatten()
logits = (logits != logits).where(-float("inf"), logits)
t = (logits / temperature).softmax()
counter, counter2 = (
Tensor.arange(t.numel(), device=logits.device).contiguous(),
Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous(),
)
if top_k:
output, output_indices = (
Tensor.zeros(top_k, device=logits.device).contiguous(),
Tensor.zeros(top_k, device=logits.device, dtype=dtypes.int32).contiguous(),
)
for i in range(top_k):
t_argmax = (
t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1
).cast(dtypes.default_int)
output = output + t_max.unsqueeze(0).pad(((i, top_k - i - 1),))
output_indices = output_indices + t_argmax.unsqueeze(0).pad(
((i, top_k - i - 1),)
)
t = (counter == t_argmax).where(0, t)
output_cumsum = output[::-1].cumsum()[::-1] + t.sum()
output = (output_cumsum >= (1 - top_p)) * output
output_indices = (output_cumsum >= (1 - top_p)) * output_indices
output_idx = output.multinomial()
output_token = output_indices[output_idx]
else:
output_token = t.multinomial()
return output_token
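# Sampling: greedy argmax when temperature is effectively zero; otherwise a
# temperature-scaled softmax, top-k selection built by repeatedly extracting
# the current maximum (pure tensor ops, so it can run inside the JIT), a top-p
# cutoff based on the reversed cumulative sum of the kept probabilities, and a
# final multinomial draw over the surviving tokens.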
def generate(
messages: List[Dict[str, str]],
max_new_tokens: int,
model: Qwen3ModelForCausalLM,
tokenizer: AutoTokenizer,
thinking: bool,
) -> Generator[int, None, None]:
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=thinking
)
input_ids = tokenizer(text, return_tensors="np").input_ids.tolist()[0]
for i, input_id in tqdm(enumerate(input_ids[:-1]), total=len(input_ids) - 1):
tensor = Tensor([input_id], dtype=dtypes.int64).unsqueeze(0)
model.fused_inference(
tensor,
Variable(
"real_len",
1,
model.ctx_len,
).bind(i + 1),
)
for _ in range(max_new_tokens):
tensor = Tensor([input_ids[-1]], dtype=dtypes.int64).unsqueeze(0)
next_token = (
model.fused_inference(
tensor,
Variable(
"real_len",
1,
model.ctx_len,
).bind(len(input_ids)),
)
.numpy()[0]
.item()
)
if next_token == tokenizer.eos_token_id:
break
yield next_token
input_ids.append(next_token)
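# The prompt is prefilled one token at a time with shape (1, 1), which keeps
# every fused_inference call shape-identical for the JIT; the decode loop then
# feeds back the last sampled token until EOS or max_new_tokens is reached.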
def main():
tokenizer = AutoTokenizer.from_pretrained("unsloth/Qwen3-0.6B")
model = Qwen3ModelForCausalLM(
num_layers=28,
dim=1024,
ffn_dim=3072,
kv_heads=8,
head_dim=128,
vocab_size=151936,
rope_theta=1000000,
att_heads=16,
ctx_len=40960,
)
path = hf_hub_download("unsloth/Qwen3-0.6B", "model.safetensors")
state_dict = nn.state.safe_load(path)
state_dict["model.rotary_emb.inv_freq"] = model.model.rotary_emb.inv_freq
state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
nn.state.load_state_dict(model, state_dict)
total_params = sum(param.numel() for param in nn.state.get_parameters(model))
print(f"Total parameters in the model: {total_params / 1e6:.2f}M")
print("Start warmup...")
model.fused_inference(
Tensor([0], dtype=dtypes.int64).unsqueeze(0),
Variable(
"real_len",
1,
model.ctx_len,
).bind(1),
).realize()
print("Warmup done.")
chat = []
while True:
print("User: ", end="")
user_input = input().strip()
if user_input == "exit":
break
if user_input == "clear":
chat.clear()
print("Chat history cleared.")
continue
chat.append({"role": "user", "content": user_input})
tokens = []
print("Assistant: ", end="", flush=True)
now = time.time()
for i, token in enumerate(
generate(
chat,
max_new_tokens=1000,
model=model,
tokenizer=tokenizer,
thinking=True,
)
):
if i == 0:
now = time.time()
tokens.append(token)
print(tokenizer.decode([token]), end="", flush=True)
print()
print(
f"TPS: {(len(tokens) / (time.time() - now)):.2f} / Output tokens: {len(tokens)}"
)
chat.append(
{
"role": "assistant",
"content": tokenizer.decode(tokens).split("</think>")[-1],
}
)
if __name__ == "__main__":
main()
googlefan256 commented May 23, 2025

CUDA=1 python main.py

requirements.txt

huggingface-hub
jinja2
numpy
tinygrad
torch
tqdm
transformers
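To set up, something like pip install -r requirements.txt (or installing the packages above directly) should be enough; without CUDA=1, tinygrad picks its default device.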
