Created
December 12, 2024 17:45
-
-
Save d4l3k/b68094d649a076384967788c9b0a5f08 to your computer and use it in GitHub Desktop.
torch.save/load benchmark and streaming implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass
from io import BufferedIOBase
from typing import Tuple
import pickle
import struct
import tempfile
import time

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
@dataclass
class TensorMetadata:
    """
    Picklable stand-in for a tensor inside the serialized metadata header.

    `write_state_dict` replaces each tensor leaf with one of these; the raw
    storage bytes follow the pickled header on the wire, and `read_state_dict`
    uses these fields to rebuild a tensor view over those bytes. The field
    set/order is part of the wire format shared by writer and reader.
    """

    # Size in bytes of the tensor's underlying untyped storage.
    nbytes: int
    # Element dtype used to reinterpret the raw storage bytes.
    dtype: torch.dtype
    # Offset (in elements) of the tensor's first element within its storage.
    storage_offset: int
    # Tensor shape, as returned by Tensor.size().
    size: Tuple[int, ...]
    # Per-dimension strides in elements, as returned by Tensor.stride().
    stride: Tuple[int, ...]
def write_state_dict(state_dict: object, f: BufferedIOBase) -> None:
    """
    Serialize ``state_dict`` to the file-like object ``f``.

    Optimized to minimize memory copies: the structure of the state_dict is
    pickled separately (via pytree) while each tensor's underlying storage is
    streamed directly to ``f``. Loading therefore never needs the entire
    serialized state_dict resident in memory at once, which makes this
    significantly faster than torch.save/torch.load for large state_dicts.

    Wire format:
      - 8 bytes: little-endian int64 length of the pickled header
      - header: pickled (leaf values with tensors replaced by TensorMetadata,
        pytree spec)
      - raw storage bytes of tensor 0 .. tensor N, in leaf order
    """
    # Flatten the arbitrarily-nested state_dict into a flat list of leaves
    # plus a spec describing the original Python structure, so the tensors
    # can be handled with a simple loop.
    leaves, spec = tree_flatten(state_dict)

    pending_storages = []
    header_leaves = []
    for leaf in leaves:
        if not isinstance(leaf, torch.Tensor):
            header_leaves.append(leaf)
            continue
        untyped = leaf.untyped_storage()
        pending_storages.append(untyped)
        # Record just enough metadata to rebuild a view over the raw bytes
        # on the reading side.
        header_leaves.append(
            TensorMetadata(
                nbytes=untyped.nbytes(),
                dtype=leaf.dtype,
                storage_offset=leaf.storage_offset(),
                size=leaf.size(),
                stride=leaf.stride(),
            )
        )

    header = pickle.dumps((header_leaves, spec))
    f.write(struct.pack("<q", len(header)))
    f.write(header)

    for untyped in pending_storages:
        # Stream the raw storage buffer straight into the file-like object.
        # Positional args: (f, is_real_file, save_size, element_size).
        untyped._write_file(f, False, False, 1)
def read_state_dict(f: BufferedIOBase) -> object:
    """
    Inverse of ``write_state_dict``: rebuild a state_dict from ``f``.

    See ``write_state_dict`` for the wire format. Reads the fixed-size
    header length, unpickles the header, then consumes each tensor's raw
    storage bytes directly off the wire.
    """
    (header_len,) = struct.unpack("<q", f.read(8))
    header_leaves, spec = pickle.loads(f.read(header_len))

    rebuilt = []
    for leaf in header_leaves:
        if not isinstance(leaf, TensorMetadata):
            rebuilt.append(leaf)
            continue
        # The exact byte count is known up front, so the storage can be
        # read in a single call without any framing/scanning.
        raw = f.read(leaf.nbytes)
        # frombuffer takes ownership of the (normally immutable) bytes
        # object; that's fine here since only this tensor ever uses it.
        flat = torch.frombuffer(raw, dtype=leaf.dtype)
        rebuilt.append(
            torch.as_strided(
                flat,
                size=leaf.size,
                stride=leaf.stride,
                storage_offset=leaf.storage_offset,
            )
        )
    return tree_unflatten(rebuilt, spec)
def main() -> None:
    """Benchmark write/read_state_dict against torch.save/torch.load.

    Builds a ~16 GB state_dict of 1 GiB float32 chunks, then times ten
    round trips of each serialization path through a temporary file.
    """
    # Trigger torch.frombuffer's one-time ownership UserWarning up front so
    # it doesn't pollute the benchmark output.
    torch.frombuffer(b"1234", dtype=torch.float32)

    print("creating state dict...")
    state_dict = {}
    # NOTE: the original comment claimed 64MB, but 1024**3 bytes is 1 GiB.
    chunk_size = 1024 * 1024 * 1024  # 1 GiB per chunk
    total_size = 16 * 1000 * 1000 * 1000  # ~16 GB total
    for i in range(0, total_size, chunk_size):
        # float32 is 4 bytes, so chunk_size // 4 elements == chunk_size bytes.
        state_dict[f"chunk_{i}"] = torch.zeros(chunk_size // 4, dtype=torch.float32)

    print("starting benchmark...")
    for i in range(10):
        print(f"iteration {i}")
        with tempfile.TemporaryFile() as fp:
            start = time.perf_counter()
            write_state_dict(state_dict, fp)
            print(f"write_state_dict took {time.perf_counter() - start} seconds")
            fp.seek(0)
            start = time.perf_counter()
            read_state_dict(fp)
            print(f"read_state_dict took {time.perf_counter() - start} seconds")
        with tempfile.TemporaryFile() as fp:
            start = time.perf_counter()
            torch.save(state_dict, fp)
            print(f"torch.save took {time.perf_counter() - start} seconds")
            fp.seek(0)
            start = time.perf_counter()
            # weights_only=True avoids arbitrary-code execution via pickle.
            torch.load(fp, weights_only=True)
            print(f"torch.load took {time.perf_counter() - start} seconds")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment