Last active: August 23, 2024 03:41
Loading and saving safetensors files without consuming extra main memory
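For orientation: a safetensors file starts with an 8-byte little-endian header size, followed by a JSON header mapping each tensor name to its dtype, shape, and byte offsets, followed by the raw tensor data. Both scripts below work with this layout directly. A minimal sketch that dumps just the header of an existing file (the file name is a placeholder):

import json
import struct

def print_safetensors_header(path):
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]  # u64, little-endian
        header = json.loads(f.read(header_size).decode("utf-8"))
    print(json.dumps(header, indent=2))

print_safetensors_header("model.safetensors")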
mem_eff_safeopen.py:
# License: Apache 2.0
import struct
import json

import torch


class MemoryEfficientSafeOpen:
    # does not support metadata loading
    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")

        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]

        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # adjust offset by header size
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)

        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # process float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # convert to the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if available
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            # # convert to float16 if float8 is not supported
            # print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
            # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
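A minimal usage sketch for the loader above, assuming a file named "model.safetensors" exists:

from mem_eff_safeopen import MemoryEfficientSafeOpen

with MemoryEfficientSafeOpen("model.safetensors") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)  # reads only this tensor's byte range from disk
        print(key, tuple(tensor.shape), tensor.dtype)

Because each get_tensor call seeks to and reads a single tensor's byte range, peak memory stays near the size of the largest tensor rather than the whole file.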
Test for MemoryEfficientSafeOpen:
# License: Apache 2.0
import unittest
import torch
import tempfile
import os
from safetensors import safe_open
from safetensors.torch import save_file
from mem_eff_safeopen import MemoryEfficientSafeOpen


class TestMemoryEfficientSafeOpen(unittest.TestCase):
    def setUp(self):
        self.test_tensors = {
            "float32": torch.randn(10, 20).float(),
            "float16": torch.randn(5, 15).half(),
            "int64": torch.randint(-100, 100, (8, 12)).long(),
            "bool": torch.randint(0, 2, (6, 6)).bool(),
            "empty": torch.empty(0, 10),
            "scalar": torch.tensor(3.14),
        }
        if hasattr(torch, "bfloat16"):
            self.test_tensors["bfloat16"] = torch.randn(7, 9).to(torch.bfloat16)
        if hasattr(torch, "float8_e5m2"):
            self.test_tensors["float8_e5m2"] = torch.randn(4, 8).to(torch.float8_e5m2)
        if hasattr(torch, "float8_e4m3fn"):
            self.test_tensors["float8_e4m3fn"] = torch.randn(3, 7).to(torch.float8_e4m3fn)

    def test_tensor_loading(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name
        try:
            # 1. Create a .safetensors file for testing
            save_file(self.test_tensors, tmp_filename)

            # 2. Load with the official safetensors and with MemoryEfficientSafeOpen, then compare
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                official_tensors = {key: f.get_tensor(key) for key in f.keys()}
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                efficient_tensors = {key: f.get_tensor(key) for key in f.keys()}

            # 3. Compare each tensor
            for key in self.test_tensors.keys():
                dtype = self.test_tensors[key].dtype
                if "float8" in str(dtype):
                    # torch.allclose does not support float8 types, so compare element by element
                    for a, b in zip(official_tensors[key].view(-1), efficient_tensors[key].view(-1)):
                        self.assertAlmostEqual(a.item(), b.item(), delta=1e-2)
                else:
                    self.assertTrue(torch.allclose(official_tensors[key], efficient_tensors[key], atol=1e-5, rtol=1e-3))
                self.assertEqual(official_tensors[key].shape, efficient_tensors[key].shape)
                self.assertEqual(official_tensors[key].dtype, efficient_tensors[key].dtype)
        finally:
            os.unlink(tmp_filename)

    def test_memory_efficiency(self):
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp_filename = tmp.name
        try:
            # Create large tensors
            num_tensors = 100
            large_tensors = {f"large_{i}": torch.randn(1000, 1000) for i in range(num_tensors)}
            save_file(large_tensors, tmp_filename)

            # Measure memory usage (a simple approach)
            import psutil
            import gc

            process = psutil.Process()

            def get_memory_usage():
                return process.memory_info().rss / 1024 / 1024  # in MB

            # Load with the official safetensors
            gc.collect()
            mem_before = get_memory_usage()
            with safe_open(tmp_filename, framework="pt", device="cpu") as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    t = t.mul(2)  # perform an operation so the data is actually read into memory
                    del t
            gc.collect()
            mem_after_official = get_memory_usage()

            # Load with MemoryEfficientSafeOpen
            gc.collect()
            mem_before = get_memory_usage()
            with MemoryEfficientSafeOpen(tmp_filename) as f:
                for key in f.keys():
                    t = f.get_tensor(key)
                    t = t.mul(2)  # the data is already in memory at this point
                    del t
            gc.collect()
            mem_after_efficient = get_memory_usage()

            # Compare memory usage
            self.assertLess(mem_after_efficient - mem_before, mem_after_official - mem_before)
        finally:
            os.unlink(tmp_filename)


if __name__ == "__main__":
    unittest.main()
mem_eff_save_file.py:
# License: Apache 2.0
import torch
import json
import struct
from typing import Dict, Any


def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)  # pad with spaces so the data section starts on an _ALIGN-byte boundary

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)
        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # Direct GPU to disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)


# Usage example
if __name__ == "__main__":
    # Create some example tensors on GPU
    tensors = {"weight": torch.randn(1000, 1000, device="cuda"), "bias": torch.randn(1000, device="cuda")}
    metadata = {"model_type": "example", "version": "1.0"}
    mem_eff_save_file(tensors, "model.safetensors", metadata)
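A round-trip sketch combining the two scripts (the file name and tensor contents are illustrative):

import torch
from mem_eff_save_file import mem_eff_save_file
from mem_eff_safeopen import MemoryEfficientSafeOpen

tensors = {"weight": torch.randn(16, 16)}
mem_eff_save_file(tensors, "roundtrip.safetensors", {"note": "demo"})

with MemoryEfficientSafeOpen("roundtrip.safetensors") as f:
    restored = {key: f.get_tensor(key) for key in f.keys()}

assert torch.equal(tensors["weight"], restored["weight"])

Note that MemoryEfficientSafeOpen does not expose metadata, so the {"note": "demo"} entry can only be read back with the official safe_open.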
Test for mem_eff_save_file:
# License: Apache 2.0
import unittest
import torch
import os
import tempfile
from safetensors.torch import load_file as official_load_file
from safetensors import safe_open
from mem_eff_save_file import mem_eff_save_file  # your implementation


class TestCompatibilityWithOfficialSafetensors(unittest.TestCase):
    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

    def tearDown(self):
        for file in os.listdir(self.temp_dir):
            os.remove(os.path.join(self.temp_dir, file))
        os.rmdir(self.temp_dir)

    def assert_tensors_equal(self, tensor1, tensor2):
        self.assertTrue(torch.allclose(tensor1, tensor2, rtol=1e-5, atol=1e-8), f"Tensors are not equal: {tensor1} vs {tensor2}")

    def test_compatibility_cpu_tensor(self):
        tensor = torch.randn(100, 100)
        tensors = {"test": tensor}
        file_path = os.path.join(self.temp_dir, "custom_cpu.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_not_contiguous_cpu_tensor(self):
        tensor = torch.randn(100, 100)
        tensor = tensor[:, ::2]
        tensors = {"test": tensor}
        assert not tensor.is_contiguous(), "Tensor must not be contiguous"
        file_path = os.path.join(self.temp_dir, "custom_not_contiguous_cpu.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_compatibility_gpu_tensor(self):
        tensor = torch.randn(100, 100, device="cuda")
        tensors = {"test": tensor}
        file_path = os.path.join(self.temp_dir, "custom_gpu.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key].cpu(), loaded_tensors[key])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_compatibility_not_contiguous_gpu_tensor(self):
        tensor = torch.randn(100, 100, device="cuda")
        tensor = tensor[:, ::2]
        tensors = {"test": tensor}
        assert not tensor.is_contiguous(), "Tensor must not be contiguous"
        file_path = os.path.join(self.temp_dir, "custom_not_contiguous_gpu.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key].cpu(), loaded_tensors[key])

    def test_compatibility_multiple_tensors(self):
        tensors = {"weight": torch.randn(100, 100), "bias": torch.randn(100)}
        file_path = os.path.join(self.temp_dir, "custom_multiple.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_with_empty_tensors(self):
        tensors = {"empty": torch.tensor([]), "zero_dim": torch.tensor(1)}
        file_path = os.path.join(self.temp_dir, "custom_empty.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])

    def test_compatibility_different_dtypes(self):
        tensors = {
            "float32": torch.randn(10, 10, dtype=torch.float32),
            "float16": torch.randn(10, 10, dtype=torch.float16),
            "int32": torch.randint(0, 10, (10, 10), dtype=torch.int32),
        }
        file_path = os.path.join(self.temp_dir, "custom_dtypes.safetensors")
        mem_eff_save_file(tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])
            self.assertEqual(tensors[key].dtype, loaded_tensors[key].dtype)

    def test_compatibility_with_metadata(self):
        tensor = torch.randn(10, 10)
        tensors = {"test": tensor}
        metadata = {"model_type": "test", "version": "1.0"}
        file_path = os.path.join(self.temp_dir, "custom_metadata.safetensors")
        mem_eff_save_file(tensors, file_path, metadata)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])
        # load metadata from .safetensors with the official implementation
        with safe_open(file_path, framework="pt") as f:
            official_metadata = f.metadata()
        self.assertEqual(metadata, official_metadata)

    def test_compatibility_with_metadata_not_str_to_str(self):
        tensor = torch.randn(10, 10)
        tensors = {"test": tensor}
        metadata = {"model_type": "test", "version": 1.0}
        file_path = os.path.join(self.temp_dir, "custom_metadata_not_str_to_str.safetensors")
        mem_eff_save_file(tensors, file_path, metadata)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(tensors.keys()), set(loaded_tensors.keys()))
        for key in tensors:
            self.assert_tensors_equal(tensors[key], loaded_tensors[key])
        # load metadata from .safetensors with the official implementation
        with safe_open(file_path, framework="pt") as f:
            official_metadata = f.metadata()
        self.assertEqual({"model_type": "test", "version": "1.0"}, official_metadata)

    def test_large_model_compatibility(self):
        # Simulate a large model
        large_tensors = {f"layer_{i}": torch.randn(1000, 1000) for i in range(10)}
        file_path = os.path.join(self.temp_dir, "large_model.safetensors")
        mem_eff_save_file(large_tensors, file_path)
        loaded_tensors = official_load_file(file_path)
        self.assertEqual(set(large_tensors.keys()), set(loaded_tensors.keys()))
        for key in large_tensors:
            self.assert_tensors_equal(large_tensors[key], loaded_tensors[key])


if __name__ == "__main__":
    unittest.main()
I'm glad it was helpful. I also have a script that reduces memory consumption when loading, so I can release it if needed. Please let me know.
While I was at it, I've added the loading code as well, and appended the license to each file. The code comes with no warranty, so please use it with that understanding.
Thank you very much! I've already incorporated it!
I combined this code with Kijai's fp8 conversion code to create a memory-efficient fp8 conversion script. Thank you for sharing the code!
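For illustration only, a hypothetical sketch of such a pipeline (this is an assumption, not the commenter's or Kijai's actual script; a real fp8 conversion would typically apply per-tensor scaling, which a plain cast omits, and it assumes a PyTorch build with float8 support):

import torch
from mem_eff_safeopen import MemoryEfficientSafeOpen
from mem_eff_save_file import mem_eff_save_file

def convert_to_fp8(src_path, dst_path):
    # Hypothetical sketch: read tensors one at a time to keep peak read memory low,
    # cast floating-point tensors to float8_e4m3fn, and save the result.
    # Note the converted tensors are still collected in memory before saving.
    converted = {}
    with MemoryEfficientSafeOpen(src_path) as f:
        for key in f.keys():
            t = f.get_tensor(key)
            if t.dtype in (torch.float32, torch.float16, torch.bfloat16):
                t = t.to(torch.float8_e4m3fn)  # plain cast, no scaling
            converted[key] = t
    mem_eff_save_file(converted, dst_path)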