Last active
November 1, 2024 12:48
-
-
Save Stella2211/10f5bd870387ec1ddb9932235321068e to your computer and use it in GitHub Desktop.
メモリ効率のいいfp8化スクリプト。 / Memory-efficient fp8 conversion script.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from pathlib import Path | |
import torch | |
from tqdm import tqdm | |
import struct | |
from typing import Dict, Any | |
import sys | |
# Command-line arguments: argv[1] is the source .safetensors model, argv[2] the
# destination path. Abort with a usage message when either is missing.
if len(sys.argv) < 3:
    print("Usage: mem_eff_fp8_convert.py {fp16 model path} {output path}")
    sys.exit(1)
path, output = sys.argv[1], sys.argv[2]
model_file = Path(path)  # not referenced elsewhere in this script
class MemoryEfficientSafeOpen:
    """Lazily read tensors from a ``.safetensors`` file.

    Only the JSON header is parsed at construction time; each tensor's bytes
    are read from disk when :meth:`get_tensor` is called. The ``__metadata__``
    header entry is not exposed by this reader. Use it as a context manager
    (or call :meth:`close`) so the file handle is released.
    """

    def __init__(self, filename):
        """Parse the header of *filename* and open it for later tensor reads."""
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying file handle; calling it more than once is harmless."""
        self.file.close()

    def keys(self):
        """Return the tensor names in header order, excluding ``__metadata__``."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        """Read and return the tensor stored under *key*.

        Raises:
            KeyError: if *key* is not a tensor entry of the file.
            ValueError: if the file is truncated or the tensor's dtype is not
                supported by this reader / PyTorch build.
        """
        # "__metadata__" lives in the header but is not a tensor; reject it here
        # instead of failing later on its missing "data_offsets" field.
        if key == "__metadata__" or key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # Offsets are relative to the data section, which follows the
            # 8-byte header-length prefix and the JSON header itself.
            self.file.seek(self.header_size + 8 + offset_start)
            expected = offset_end - offset_start
            tensor_bytes = self.file.read(expected)
            if len(tensor_bytes) != expected:
                raise ValueError(
                    f"Tensor '{key}' is truncated: expected {expected} bytes, "
                    f"got {len(tensor_bytes)}"
                )
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        """Return ``(header_dict, header_size)`` parsed from the start of the file."""
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Reinterpret raw bytes as a tensor with the header's dtype and shape."""
        dtype_str = metadata["dtype"]
        shape = metadata["shape"]
        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            # torch.frombuffer shares memory with its buffer; ``bytes`` is
            # immutable, so hand it a writable bytearray copy.
            byte_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8)
        if dtype_str in ("F8_E5M2", "F8_E4M3"):
            return self._convert_float8(byte_tensor, dtype_str, shape)
        dtype = self._get_torch_dtype(dtype_str)
        if dtype is None:
            # Without this check, view(None) would fail with an unhelpful TypeError.
            raise ValueError(f"Unsupported dtype in safetensors header: {dtype_str}")
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        """Map a safetensors dtype string to a torch dtype, or ``None`` if unknown."""
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # float8 dtypes only exist in newer PyTorch builds.
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        """View *byte_tensor* as the requested float8 dtype, or raise if unavailable."""
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """Write *tensors* to *filename* in safetensors format, one tensor at a time.

    The header is built from tensor shapes/dtypes first. Each tensor's raw
    bytes are then streamed to disk individually, so no serialized copy of
    the whole mapping is built in memory.

    Args:
        tensors: Mapping from tensor name to tensor; iteration order is the
            on-disk order. Tensors may live on any device and may require grad.
        filename: Output path (overwritten).
        metadata: Optional metadata stored under ``__metadata__``. Non-string
            values are converted with ``str()`` after printing a warning.

    Raises:
        ValueError: if a metadata key is not a string, or a tensor's dtype has
            no safetensors name in this PyTorch build. Both checks run before
            the output file is opened.

    NOTE(review): ``ndarray.tofile`` writes native byte order; safetensors is
    little-endian, so this assumes a little-endian host.
    """
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
    }
    # float8 dtypes only exist in newer PyTorch builds; register them only when
    # present so that ``None`` never becomes a key of the map.
    if hasattr(torch, "float8_e5m2"):
        _TYPES[torch.float8_e5m2] = "F8_E5M2"
    if hasattr(torch, "float8_e4m3fn"):
        _TYPES[torch.float8_e4m3fn] = "F8_E4M3"
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        """Return a copy of *metadata* with every value coerced to ``str``."""
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    def dtype_name(name: str, tensor: torch.Tensor) -> str:
        """Return the safetensors dtype string for *tensor* or raise ValueError."""
        try:
            return _TYPES[tensor.dtype]
        except KeyError:
            raise ValueError(
                f"Tensor '{name}' has dtype {tensor.dtype}, which cannot be stored in safetensors"
            ) from None

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        # Empty tensors get size 0, i.e. data_offsets [offset, offset].
        size = v.numel() * v.element_size()
        header[k] = {"dtype": dtype_name(k, v), "shape": list(v.shape), "data_offsets": [offset, offset + size]}
        offset += size

    hjson = json.dumps(header).encode("utf-8")
    # Pad with spaces so the data section (after the 8-byte length prefix and
    # the header) starts on an _ALIGN-byte boundary.
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)
        for v in tensors.values():
            if v.numel() == 0:
                continue
            if v.dim() == 0:  # scalars need a dimension for view(torch.uint8)
                v = v.unsqueeze(0)
            # detach(): numpy() refuses tensors that require grad.
            # cpu(): copies GPU/other-device tensors to host memory; a no-op on CPU.
            v.detach().cpu().contiguous().view(torch.uint8).numpy().tofile(f)
# Header-only helper: fetch the "__metadata__" section of a safetensors file.
def read_safetensors_metadata(path: str):
    """Return the ``__metadata__`` mapping of the safetensors file at *path*.

    The file starts with an 8-byte little-endian header length followed by a
    UTF-8 JSON header; only that header is read. Returns ``{}`` when the
    header carries no ``__metadata__`` entry.
    """
    with open(path, 'rb') as stream:
        header_len = int.from_bytes(stream.read(8), 'little')
        raw_header = stream.read(header_len)
    parsed = json.loads(raw_header.decode('utf-8'))
    return parsed.get('__metadata__', {})
# Main conversion: cast every floating-point tensor of the input model to
# float8_e4m3fn and write the result with the original metadata preserved.
if not hasattr(torch, "float8_e4m3fn"):
    # Fail fast instead of crashing on the first tensor with an AttributeError.
    print("Error: this PyTorch build has no torch.float8_e4m3fn; upgrade PyTorch to use this script.")
    sys.exit(1)

metadata = read_safetensors_metadata(path)
print(json.dumps(metadata, indent=4))  # show the source file's metadata

sd_pruned = {}
with MemoryEfficientSafeOpen(path) as reader:
    keys = reader.keys()
    for key in tqdm(keys):  # one tensor in memory at a time from the reader
        tensor = reader.get_tensor(key)
        # Only floating-point tensors are cast; an fp8 cast would corrupt
        # integer/bool tensors, so those are stored unchanged.
        sd_pruned[key] = tensor.to(torch.float8_e4m3fn) if tensor.is_floating_point() else tensor

# save the converted state dict; existing metadata keys override "format"
mem_eff_save_file(sd_pruned, output, metadata={"format": "pt", **metadata})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Good work!