Skip to content

Instantly share code, notes, and snippets.

@kohya-ss
Last active August 23, 2024 03:41
Show Gist options
  • Save kohya-ss/fa4b7ae7119c10850ae7d70c90a59277 to your computer and use it in GitHub Desktop.
メインメモリを消費しないsafetensorsファイル読み込み・保存
# License: Apache 2.0
import io
import struct
import json
import torch
class MemoryEfficientSafeOpen:
    """Read tensors from a .safetensors file without loading the whole file
    into main memory at once.

    Each ``get_tensor`` call seeks to and reads only that tensor's bytes.
    Loading the ``__metadata__`` entry is not supported.
    """

    def __init__(self, filename):
        """Open *filename* and parse its safetensors header."""
        self.filename = filename
        # Keep a single handle for the object's lifetime; the header is read
        # through it as well, so the file is opened exactly once (previously
        # it was opened a second time just to read the header).
        self.file = open(filename, "rb")
        try:
            self.header, self.header_size = self._read_header()
        except Exception:
            # do not leak the handle if the header is truncated or malformed
            self.file.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        """Return the names of all tensors stored in the file."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        """Read and deserialize the tensor named *key*.

        Raises:
            KeyError: if *key* is not present in the file.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None  # zero-element tensor: no payload bytes to read
        else:
            # file layout is [8-byte header size][header JSON][tensor data];
            # data_offsets are relative to the start of the data section
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # header layout: little-endian uint64 length, then that many bytes of JSON
        self.file.seek(0)
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable for frombuffer
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # float8 is handled separately because older PyTorch builds lack the dtypes
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # reinterpret the raw bytes as the target dtype, then reshape
        dtype = self._get_torch_dtype(metadata["dtype"])
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Map a safetensors dtype string to the corresponding torch.dtype.

        Raises:
            ValueError: for unknown dtype strings (instead of returning None,
                which would surface later as a confusing ``view(None)`` error).
        """
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types only if this PyTorch build provides them
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        if dtype_str not in dtype_map:
            raise ValueError(f"Unsupported dtype: {dtype_str}")
        return dtype_map[dtype_str]

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """Reinterpret raw bytes as a float8 tensor, if the dtype is available."""
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
# License: Apache 2.0
import unittest
import torch
import tempfile
import os
from safetensors import safe_open
from safetensors.torch import save_file
from mem_eff_safeopen import MemoryEfficientSafeOpen
class TestMemoryEfficientSafeOpen(unittest.TestCase):
    """Check MemoryEfficientSafeOpen against the official safetensors reader."""

    def setUp(self):
        # Cover every supported dtype plus shape edge cases
        # (empty tensor, 0-dim scalar).
        self.test_tensors = {
            "float32": torch.randn(10, 20).float(),
            "float16": torch.randn(5, 15).half(),
            "int64": torch.randint(-100, 100, (8, 12)).long(),
            "bool": torch.randint(0, 2, (6, 6)).bool(),
            "empty": torch.empty(0, 10),
            "scalar": torch.tensor(3.14),
        }
        if hasattr(torch, "bfloat16"):
            self.test_tensors["bfloat16"] = torch.randn(7, 9).to(torch.bfloat16)
        if hasattr(torch, "float8_e5m2"):
            self.test_tensors["float8_e5m2"] = torch.randn(4, 8).to(torch.float8_e5m2)
        if hasattr(torch, "float8_e4m3fn"):
            self.test_tensors["float8_e4m3fn"] = torch.randn(3, 7).to(torch.float8_e4m3fn)

    def test_tensor_loading(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name
        try:
            # 1. create a .safetensors file with the official writer
            save_file(self.test_tensors, tmp_filename)

            # 2. load with both the official reader and MemoryEfficientSafeOpen
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                official_tensors = {key: f.get_tensor(key) for key in f.keys()}
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                efficient_tensors = {key: f.get_tensor(key) for key in f.keys()}

            # 3. compare every tensor
            for key in self.test_tensors.keys():
                dtype = self.test_tensors[key].dtype
                if "float8" in str(dtype):
                    # torch.allclose does not support float8, so compare element-wise
                    for a, b in zip(official_tensors[key].view(-1), efficient_tensors[key].view(-1)):
                        self.assertAlmostEqual(a.item(), b.item(), delta=1e-2)
                else:
                    self.assertTrue(torch.allclose(official_tensors[key], efficient_tensors[key], atol=1e-5, rtol=1e-3))
                self.assertEqual(official_tensors[key].shape, efficient_tensors[key].shape)
                self.assertEqual(official_tensors[key].dtype, efficient_tensors[key].dtype)
        finally:
            os.unlink(tmp_filename)

    def test_memory_efficiency(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name
        try:
            # create many large tensors to make the memory difference visible
            num_tensors = 100
            large_tensors = {f"large_{i}": torch.randn(1000, 1000) for i in range(num_tensors)}
            save_file(large_tensors, tmp_filename)

            # measure memory usage (crude RSS-based approach)
            import psutil
            import gc

            process = psutil.Process()

            def get_memory_usage():
                return process.memory_info().rss / 1024 / 1024  # in MB

            # official safetensors load, measured against its own baseline
            gc.collect()
            mem_before_official = get_memory_usage()
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    t = t.mul(2)  # do some work so the tensor is actually materialized
                    del t
            gc.collect()
            mem_after_official = get_memory_usage()

            # MemoryEfficientSafeOpen load, measured against its own baseline.
            # BUG FIX: the original reused a single `mem_before`, reassigning it
            # here, so the official delta above was computed against the WRONG
            # (second) baseline.
            gc.collect()
            mem_before_efficient = get_memory_usage()
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    t = t.mul(2)  # already materialized by the read
                    del t
            gc.collect()
            mem_after_efficient = get_memory_usage()

            # the memory-efficient reader should grow RSS less
            self.assertLess(
                mem_after_efficient - mem_before_efficient,
                mem_after_official - mem_before_official,
            )
        finally:
            os.unlink(tmp_filename)
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
# License: Apache 2.0
import torch
import json
import struct
from typing import Dict, Any
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """Save *tensors* to *filename* in safetensors format without building the
    whole serialized payload in main memory.

    Tensors are streamed to disk one at a time; CUDA tensors are copied to the
    CPU per tensor rather than all at once.

    Args:
        tensors: mapping of tensor name to tensor (CPU or CUDA).
        metadata: optional mapping stored under ``__metadata__``; keys must be
            strings, non-string values are converted with a warning.

    Raises:
        ValueError: if a tensor has a dtype unsupported by the format, or a
            metadata key is not a string.
    """
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
    }
    # float8 dtypes exist only in newer PyTorch builds; adding them
    # conditionally avoids a spurious `None` key in the mapping
    if hasattr(torch, "float8_e5m2"):
        _TYPES[torch.float8_e5m2] = "F8_E5M2"
    if hasattr(torch, "float8_e4m3fn"):
        _TYPES[torch.float8_e4m3fn] = "F8_E4M3"
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        """Coerce metadata values to strings as the format requires."""
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    # build the JSON header with per-tensor byte ranges
    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.dtype not in _TYPES:
            # fail up-front with a clear message instead of a bare KeyError
            raise ValueError(f"Unsupported dtype for safetensors: {v.dtype}")
        if v.numel() == 0:  # empty tensor occupies no payload bytes
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    # pad the header with spaces so the data section starts on an _ALIGN boundary
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)
        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.dim() == 0:
                # scalars cannot be byte-viewed directly; add a dimension
                v = v.unsqueeze(0)
            if v.is_cuda:
                # stream GPU tensor to disk, copying to CPU one tensor at a time
                with torch.cuda.device(v.device):
                    v.contiguous().view(torch.uint8).cpu().numpy().tofile(f)
            else:
                # CPU tensor: write raw bytes directly
                v.contiguous().view(torch.uint8).numpy().tofile(f)
# Usage example
if __name__ == "__main__":
    # Build a couple of example tensors on the GPU and save them with metadata.
    example_metadata = {"model_type": "example", "version": "1.0"}
    example_tensors = {
        "weight": torch.randn(1000, 1000, device="cuda"),
        "bias": torch.randn(1000, device="cuda"),
    }
    mem_eff_save_file(example_tensors, "model.safetensors", example_metadata)
# License: Apache 2.0
import unittest
import torch
import os
import tempfile
from safetensors.torch import load_file as official_load_file
from safetensors import safe_open
from mem_eff_save_file import mem_eff_save_file # あなたの実装
class TestCompatibilityWithOfficialSafetensors(unittest.TestCase):
    """Verify that files written by mem_eff_save_file load correctly with the
    official safetensors implementation."""

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        for file in os.listdir(self.temp_dir):
            os.remove(os.path.join(self.temp_dir, file))
        os.rmdir(self.temp_dir)

    def assert_tensors_equal(self, tensor1, tensor2):
        """Assert two tensors are element-wise close."""
        self.assertTrue(torch.allclose(tensor1, tensor2, rtol=1e-5, atol=1e-8), f"Tensors are not equal: {tensor1} vs {tensor2}")

    def test_compatibility_cpu_tensor(self):
        tensor = torch.randn(100, 100)
        tensors = {"test": tensor}
        file_path = os.path.join(self.temp_dir, "custom_cpu.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_not_contiguous_cpu_tensor(self):
        tensor = torch.randn(100, 100)
        tensor = tensor[:, ::2]  # strided slice forces a non-contiguous layout
        tensors = {"test": tensor}
        assert not tensor.is_contiguous(), "Tensor must not be contiguous"
        file_path = os.path.join(self.temp_dir, "custom_not_contiguous_cpu.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_compatibility_gpu_tensor(self):
        tensor = torch.randn(100, 100, device="cuda")
        tensors = {"test": tensor}
        file_path = os.path.join(self.temp_dir, "custom_gpu.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key].cpu(), loaded_tensors[key])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_compatibility_not_contiguous_gpu_tensor(self):
        tensor = torch.randn(100, 100, device="cuda")
        tensor = tensor[:, ::2]  # strided slice forces a non-contiguous layout
        tensors = {"test": tensor}
        assert not tensor.is_contiguous(), "Tensor must not be contiguous"
        file_path = os.path.join(self.temp_dir, "custom_not_contiguous_gpu.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key].cpu(), loaded_tensors[key])

    def test_compatibility_multiple_tensors(self):
        tensors = {"weight": torch.randn(100, 100), "bias": torch.randn(100)}
        file_path = os.path.join(self.temp_dir, "custom_multiple.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_with_empty_tensors(self):
        tensors = {"empty": torch.tensor([]), "zero_dim": torch.tensor(1)}
        file_path = os.path.join(self.temp_dir, "custom_empty.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_different_dtypes(self):
        tensors = {
            "float32": torch.randn(10, 10, dtype=torch.float32),
            "float16": torch.randn(10, 10, dtype=torch.float16),
            "int32": torch.randint(0, 10, (10, 10), dtype=torch.int32),
        }
        file_path = os.path.join(self.temp_dir, "custom_dtypes.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])
            self.assertEqual(tensors[key].dtype, loaded_tensors[key].dtype)

    def test_compatibility_with_metadata(self):
        tensor = torch.randn(10, 10)
        tensors = {"test": tensor}
        metadata = {"model_type": "test", "version": "1.0"}
        file_path = os.path.join(self.temp_dir, "custom_metadata.safetensors")
        mem_eff_save_file(tensors, file_path, metadata)
        # NOTE: the redundant method-local `from safetensors import safe_open`
        # was removed; safe_open is already imported at module level.
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])
        # load metadata from .safetensors with the official implementation
        with safe_open(file_path, framework="pt") as f:
            official_metadata = f.metadata()
        self.assertEqual(metadata, official_metadata)

    def test_compatibility_with_metadata_not_str_to_str(self):
        tensor = torch.randn(10, 10)
        tensors = {"test": tensor}
        metadata = {"model_type": "test", "version": 1.0}
        file_path = os.path.join(self.temp_dir, "custom_metadata_not_str_to_str.safetensors")
        mem_eff_save_file(tensors, file_path, metadata)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])
        # non-string metadata values must round-trip as their str() form
        with safe_open(file_path, framework="pt") as f:
            official_metadata = f.metadata()
        self.assertEqual({"model_type": "test", "version": "1.0"}, official_metadata)

    def test_large_model_compatibility(self):
        # simulate a large model with many sizeable layers
        large_tensors = {f"layer_{i}": torch.randn(1000, 1000) for i in range(10)}
        file_path = os.path.join(self.temp_dir, "large_model.safetensors")
        mem_eff_save_file(large_tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(large_tensors.keys()), set(loaded_tensors.keys()))
        for key in large_tensors:
            self.assert_tensors_equal(large_tensors[key], loaded_tensors[key])
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
@kohya-ss
Copy link
Author

お役に立ったようで幸いです。読み込み時のメモリ消費を削減するスクリプトもありますので、必要なら公開いたします。ご連絡いただければ幸いです。/ I'm glad it was helpful. I also have a script to reduce memory consumption when loading, so I'll release it if necessary. Please let me know.

@kohya-ss
Copy link
Author

せっかくなので読み込みも付けておきました。各ファイルにライセンスも追記しました。動作は無保証ですので、ご理解の上ご利用ください。 / I've added a loading feature as well. I've also included license information in each file. Please note that this code comes with no warranty - use it at your own discretion.

@Stella2211
Copy link

ありがとうございます!早速組み込ませていただきました! / Thank you very much! I have incorporated it!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment