Created
July 16, 2025 15:53
-
-
Save egorsmkv/759c69d3b1c0be8dc587b4e3ae20394b to your computer and use it in GitHub Desktop.
Convert LFM2 model to BF16 and remove `lm_head.weight`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Convert every floating-point tensor in a .safetensors checkpoint to
# bfloat16, drop `lm_head.weight`, and write the result to a new file.

import os

import torch
from safetensors.torch import load_file, save_file

# --- Configuration ---
# Path to the input .safetensors file.
input_filepath = "model.safetensors"
# Path for the new BF16 output file.
output_filepath = "model_bf16.safetensors"

# --- 1. Load the Tensors ---
# load_file reads the .safetensors file and returns a dictionary where keys
# are tensor names and values are the torch.Tensor objects.
print(f"Loading tensors from: {input_filepath}")
tensors = load_file(input_filepath)
print(f"Loaded {len(tensors)} tensors.")

# --- 2. Convert to bfloat16 ---
# Build a new dictionary so the loaded mapping is left untouched.
converted_tensors = {}
for name, tensor in tensors.items():
    # `lm_head.weight` is deliberately left out of the output file.
    # NOTE(review): presumably the model ties this weight to the input
    # embeddings, making the stored copy redundant -- confirm for your model.
    if name == 'lm_head.weight':
        print(f" - Skipping and removing '{name}'")
        continue

    if tensor.is_floating_point():
        # `tensor` still refers to the original, so the message below reports
        # the source dtype.
        converted_tensors[name] = tensor.to(torch.bfloat16)
        print(f" - Converted '{name}' from {tensor.dtype} to torch.bfloat16")
    else:
        # Non-floating-point tensors are copied over as-is; casting them to
        # bfloat16 would change their meaning.
        converted_tensors[name] = tensor
        print(f" - Skipped non-floating point tensor '{name}' ({tensor.dtype})")

# --- 3. Save the Converted Tensors ---
# Write the "format": "pt" header entry alongside the tensors. Checkpoints
# written by Hugging Face `save_pretrained` carry it, and some loader versions
# refuse .safetensors files that lack it.
print(f"Saving converted BF16 tensors to: {output_filepath}")
save_file(converted_tensors, output_filepath, metadata={"format": "pt"})

print("\nConversion complete.")
print(f"Original file size: {os.path.getsize(input_filepath) / 1e6:.2f} MB")
print(f"New BF16 file size: {os.path.getsize(output_filepath) / 1e6:.2f} MB")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment