Sanity check Qwen3
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import torch
import gc
# Sanity check 1: `from_pretrained` does not consume any random state
set_seed(0)
random_tensor_1 = torch.randint(0, 1000, (1, 10))

set_seed(0)
model_1 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="auto", torch_dtype=torch.bfloat16)
random_tensor_2 = torch.randint(0, 1000, (1, 10))

set_seed(0)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B", padding_side="left")
random_tensor_3 = torch.randint(0, 1000, (1, 10))

if (random_tensor_1 - random_tensor_2).abs().sum() > 0:
    raise ValueError("Difference found (model `from_pretrained`)!")
else:
    print("Model `from_pretrained` does not consume any random state")

if (random_tensor_1 - random_tensor_3).abs().sum() > 0:
    raise ValueError("Difference found (tokenizer `from_pretrained`)!")
else:
    print("Tokenizer `from_pretrained` does not consume any random state")
# Sanity check 2: no difference in parameters or buffers at load time
torch.set_printoptions(precision=50)  # print many more decimal places than usual, for a stricter string comparison
model_1_parameters = str(list(model_1.parameters()))
model_1_buffers = str(list(model_1.buffers()))
for _ in range(10):
    model_2 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="auto", torch_dtype=torch.bfloat16)
    model_2_parameters = str(list(model_2.parameters()))
    model_2_buffers = str(list(model_2.buffers()))
    if model_1_parameters != model_2_parameters or model_1_buffers != model_2_buffers:
        raise ValueError("Difference found! (weight init)")
    else:
        print("New model has same parameters and buffers")
    del model_2
    torch.cuda.empty_cache()
    gc.collect()
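# (Hedged aside: comparing `str(...)` reprs depends on tensor print truncation for
# large tensors, so it is a proxy rather than an exact check. A stricter helper,
# sketched here and not part of the original script, compares state dicts elementwise.)
def state_dicts_match(m1, m2) -> bool:
    sd1, sd2 = m1.state_dict(), m2.state_dict()
    return sd1.keys() == sd2.keys() and all(torch.equal(sd1[k], sd2[k]) for k in sd1)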
# Sanity check 3: `forward` is deterministic
long_input_text = "This is a long test " * 100
model_inputs = tokenizer(long_input_text, return_tensors="pt").to(model_1.device)
logits_1 = model_1(**model_inputs).logits
for _ in range(10):
    logits_2 = model_1(**model_inputs).logits
    if (logits_1 - logits_2).abs().sum() > 0:
        raise ValueError("Difference found! (forward)")
    else:
        print("Forward is deterministic")
# Sanity check 4: `generate` is deterministic (with do_sample=True + seed)
set_seed(0)
generated_1 = model_1.generate(**model_inputs, do_sample=True, return_dict_in_generate=True, output_scores=True, max_new_tokens=100)
for _ in range(10):
    set_seed(0)
    generated_2 = model_1.generate(**model_inputs, do_sample=True, return_dict_in_generate=True, output_scores=True, max_new_tokens=100)
    same_scores = True
    for score_1, score_2 in zip(generated_1.scores, generated_2.scores):  # token-level scores must match step by step
        if (score_1 - score_2).abs().sum() > 0:
            same_scores = False
            break
    if not same_scores:
        raise ValueError("Difference found! (generate)")
    else:
        print("Generate is deterministic")