Skip to content

Instantly share code, notes, and snippets.

@egorsmkv
Created July 16, 2025 15:53
Show Gist options
  • Save egorsmkv/759c69d3b1c0be8dc587b4e3ae20394b to your computer and use it in GitHub Desktop.
Convert LFM2 model to BF16 and remove `lm_head.weight`
"""Convert an LFM2 .safetensors checkpoint to bfloat16 and drop `lm_head.weight`.

The dropped key is presumably tied to the embedding weights and therefore
redundant in the exported checkpoint — confirm for your model before reusing.
"""
import os

import torch

# --- Configuration ---
# Path to the input .safetensors file.
input_filepath = "model.safetensors"
# Path for the new BF16 output file.
output_filepath = "model_bf16.safetensors"


def convert_to_bf16(tensors, drop_keys=("lm_head.weight",)):
    """Return a new dict with floating-point tensors cast to torch.bfloat16.

    Args:
        tensors: mapping of tensor name -> torch.Tensor, as returned by
            safetensors' ``load_file``.
        drop_keys: tensor names to omit from the output entirely
            (default: the LFM2 ``lm_head.weight``).

    Returns:
        A new dict. Floating-point tensors are converted to bfloat16;
        non-floating tensors (e.g. integer buffers) are kept unchanged.
    """
    converted_tensors = {}
    for name, tensor in tensors.items():
        if name in drop_keys:
            print(f" - Skipping and removing '{name}'")
            continue  # Move to the next item in the loop
        if tensor.is_floating_point():
            converted_tensors[name] = tensor.to(torch.bfloat16)
            print(f" - Converted '{name}' from {tensor.dtype} to torch.bfloat16")
        else:
            # Integer/bool buffers must not be cast to a float dtype.
            converted_tensors[name] = tensor
            print(f" - Skipped non-floating point tensor '{name}' ({tensor.dtype})")
    return converted_tensors


def main():
    """Load the checkpoint, convert it, save it, and report file sizes."""
    # Deferred import: lets this module be imported (e.g. for the pure
    # conversion helper) on machines without safetensors installed.
    from safetensors.torch import load_file, save_file

    # --- 1. Load the Tensors ---
    # load_file returns a dict mapping tensor names to torch.Tensor objects.
    print(f"Loading tensors from: {input_filepath}")
    tensors = load_file(input_filepath)
    print(f"Loaded {len(tensors)} tensors.")

    # --- 2. Convert to bfloat16 ---
    converted_tensors = convert_to_bf16(tensors)

    # --- 3. Save the Converted Tensors ---
    print(f"Saving converted BF16 tensors to: {output_filepath}")
    save_file(converted_tensors, output_filepath)

    print("\nConversion complete.")
    print(f"Original file size: {os.path.getsize(input_filepath) / 1e6:.2f} MB")
    print(f"New BF16 file size: {os.path.getsize(output_filepath) / 1e6:.2f} MB")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment