"""
Safetensors deshard: merges shards like "model-00001-of-00005.safetensors"
into a single .safetensors file.
https://github.com/marduk191
"""
import os
import json
import struct
import torch
from safetensors import safe_open
from safetensors.torch import save_file

class MemoryEfficientSafeOpen:
    """Read tensors from a .safetensors file by parsing the header directly.

    Does not support loading the free-form `__metadata__` block.
    """

    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # data offsets are relative to the end of the header:
            # skip the 8-byte header length plus the header itself
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # a .safetensors file starts with an 8-byte little-endian header size,
        # followed by a JSON header mapping tensor names to dtype/shape/offsets
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]
        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable for frombuffer
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
        # process float8 types separately
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)
        # convert to the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if this PyTorch build supports them
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            # alternative: fall back to float16 when float8 is unavailable, e.g.
            #   return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
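

# The class above skips the free-form "__metadata__" entry of the header. As a
# minimal sketch (a hypothetical helper, not part of the original gist), the same
# layout (an 8-byte little-endian length followed by a JSON header) can be reused:
def read_safetensors_metadata(filename):
    """Return the `__metadata__` dict of a .safetensors file, or None if absent."""
    with open(filename, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size).decode("utf-8"))
    return header.get("__metadata__")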


def get_safetensor_files(directory):
    """
    Recursively retrieve all `.safetensors` files within a directory.

    Args:
        directory (str): The directory path to search.

    Returns:
        list: A list of paths to the found `.safetensors` files.
    """
    safetensors_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".safetensors"):
                safetensors_files.append(os.path.join(root, file))
    return safetensors_files
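

# Sharded checkpoints follow the "model-00001-of-00005.safetensors" naming scheme
# mentioned above. A stricter filter (a hypothetical variant, assuming that naming
# convention) avoids picking up an already-merged output file in the same directory
# and returns the shards in a deterministic order:
import re

def get_shard_files(directory):
    """Return only files matching "*-NNNNN-of-NNNNN.safetensors", sorted."""
    pattern = re.compile(r"-\d+-of-\d+\.safetensors$")
    return sorted(
        os.path.join(root, f)
        for root, _, files in os.walk(directory)
        for f in files
        if pattern.search(f)
    )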


def merge_safetensor_files(sftsr_files, output_file="model.safetensors", use_memory_efficient=True):
    """
    Merge multiple `.safetensors` files into a single file.

    Args:
        sftsr_files (list): List of paths to the `.safetensors` files to merge.
        output_file (str): Path for the output merged file.
        use_memory_efficient (bool): Whether to use MemoryEfficientSafeOpen
            instead of safetensors' safe_open.
    """
    tensors = {}
    metadata = None
    for file in sftsr_files:
        if use_memory_efficient:
            # note: this path does not carry over the shards' `__metadata__`
            with MemoryEfficientSafeOpen(file) as sf_tsr:
                for layer in sf_tsr.keys():
                    tensors[layer] = sf_tsr.get_tensor(layer)
        else:
            with safe_open(file, framework="pt") as sf_tsr:
                if metadata is None:
                    metadata = sf_tsr.metadata()
                for layer in sf_tsr.keys():
                    tensors[layer] = sf_tsr.get_tensor(layer)
    save_file(tensors, output_file, metadata)
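

# A quick sanity check (a hypothetical helper, assuming the shards are still on
# disk): the merged file should expose exactly the union of the shards' keys.
def verify_merge(sftsr_files, merged_file):
    """Assert that tensor names in the merged file match the shards' union."""
    shard_keys = set()
    for file in sftsr_files:
        with MemoryEfficientSafeOpen(file) as sf_tsr:
            shard_keys.update(sf_tsr.keys())
    with MemoryEfficientSafeOpen(merged_file) as sf_tsr:
        merged_keys = set(sf_tsr.keys())
    assert merged_keys == shard_keys, f"Key mismatch: {shard_keys ^ merged_keys}"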


if __name__ == "__main__":
    safetensor_files = get_safetensor_files("./shards")
    print(f"The following shards/chunks will be merged: {safetensor_files}")
    # Choose between memory-efficient and standard loading:
    # use_memory_efficient=True uses the MemoryEfficientSafeOpen class above;
    # False uses safetensors' safe_open and preserves the shards' metadata.
    # Note: the output is written into the scanned directory, so re-running
    # the script would pick up the merged file as a shard.
    merge_safetensor_files(
        safetensor_files,
        output_file="./shards/GNER-T5-xxl.safetensors",
        use_memory_efficient=True,
    )
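    # Optionally sanity-check the merged file against the shards using the
    # hypothetical verify_merge helper sketched above:
    # verify_merge(safetensor_files, "./shards/GNER-T5-xxl.safetensors")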