Created
July 23, 2025 14:31
-
-
Save marduk191/dcd8df6d02ee5e7dda762665dfee7848 to your computer and use it in GitHub Desktop.
For merging shards named like "model-00001-of-00005.safetensors" into a single file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Safetensors deshard | |
| https://github.com/marduk191 | |
| """ | |
| import sys | |
| import os | |
| import json | |
| import struct | |
| import torch | |
| from safetensors import safe_open | |
| from safetensors.torch import save_file | |
class MemoryEfficientSafeOpen:
    """Minimal streaming reader for .safetensors files.

    Reads the JSON header once, then seeks directly to each tensor's byte
    range on demand, so only one tensor's data is held in memory at a time.
    File-level metadata ("__metadata__") is parsed into ``self.header`` but
    not exposed through :meth:`keys`.
    """

    def __init__(self, filename):
        """Parse the header of *filename* and keep the file open for reads."""
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release the handle, even if an exception is in flight.
        self.file.close()

    def keys(self):
        """Return all tensor names, excluding the reserved "__metadata__" entry."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        """Load and return the tensor named *key* as a torch.Tensor.

        Raises:
            KeyError: if *key* is not present in the file header.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        if offset_start == offset_end:
            # Zero-length tensor: nothing to read from disk.
            tensor_bytes = None
        else:
            # Data section begins after the 8-byte length prefix + JSON header.
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # safetensors layout: 8-byte little-endian header size, then UTF-8 JSON.
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Turn raw bytes + per-tensor header metadata into a typed, shaped tensor."""
        shape = metadata["shape"]
        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # torch.frombuffer requires writable memory
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
        # Handle float8 first: it may be unavailable in older PyTorch, and
        # _convert_float8 produces the specific error message for that case.
        if metadata["dtype"] in ("F8_E5M2", "F8_E4M3"):
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)
        dtype = self._get_torch_dtype(metadata["dtype"])
        # Reinterpret the raw bytes as the target dtype, then reshape.
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Map a safetensors dtype string (e.g. "F32") to a torch.dtype.

        Raises:
            ValueError: for dtype strings this reader does not support.
                (Previously returned None, which later failed with an opaque
                TypeError inside Tensor.view.)
        """
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # float8 types exist only in recent PyTorch versions.
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        try:
            return dtype_map[dtype_str]
        except KeyError:
            raise ValueError(f"Unsupported dtype: {dtype_str}") from None

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """Reinterpret raw bytes as a float8 tensor, if this PyTorch supports it."""
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def get_safetensor_files(directory):
    """Recursively collect every `.safetensors` file under a directory.

    Args:
        directory (str): Root directory to search.

    Returns:
        list: Paths of all `.safetensors` files found, in os.walk order.
    """
    return [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(directory)
        for name in names
        if name.endswith(".safetensors")
    ]
def merge_safetensor_files(sftsr_files, output_file="model.safetensors", use_memory_efficient=True):
    """Merge multiple `.safetensors` shard files into a single file.

    Args:
        sftsr_files (list): Paths of the shard files to merge.
        output_file (str): Path for the merged output file.
        use_memory_efficient (bool): Use MemoryEfficientSafeOpen (streams one
            tensor at a time) instead of safetensors.safe_open.

    Raises:
        ValueError: if *sftsr_files* is empty — previously this silently
            wrote an empty model file.

    Note: if the same tensor name appears in several shards, the last shard
    in the list wins.
    """
    if not sftsr_files:
        raise ValueError("No .safetensors files to merge")
    tensors = {}
    metadata = None
    for file in sftsr_files:
        if use_memory_efficient:
            with MemoryEfficientSafeOpen(file) as sf_tsr:
                # Fix: this path previously discarded file-level metadata;
                # recover it from the raw header of the first shard that has it.
                if metadata is None:
                    metadata = sf_tsr.header.get("__metadata__")
                for layer in sf_tsr.keys():
                    tensors[layer] = sf_tsr.get_tensor(layer)
        else:
            with safe_open(file, framework="pt") as sf_tsr:
                if metadata is None:
                    metadata = sf_tsr.metadata()
                for layer in sf_tsr.keys():
                    tensors[layer] = sf_tsr.get_tensor(layer)
    save_file(tensors, output_file, metadata)
if __name__ == "__main__":
    shard_dir = "./shards"
    output_path = "./shards/GNER-T5-xxl.safetensors"
    safetensor_files = get_safetensor_files(shard_dir)
    # Fix: the output file lives inside the shard directory, so a second run
    # would pick up the previously merged file as another "shard" and merge
    # it back in; exclude the output path from the inputs.
    safetensor_files = [
        f for f in safetensor_files
        if os.path.abspath(f) != os.path.abspath(output_path)
    ]
    print(f"The following shards/chunks will be merged : {safetensor_files}")
    # use_memory_efficient=True streams one tensor at a time via
    # MemoryEfficientSafeOpen; set it to False to use safetensors.safe_open
    # (which also preserves metadata via its native metadata() API).
    merge_safetensor_files(
        safetensor_files,
        output_file=output_path,
        use_memory_efficient=True,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment