Skip to content

Instantly share code, notes, and snippets.

@danieldk
Created July 24, 2024 11:30
Show Gist options
  • Save danieldk/dea909991793a75be5ea6ed58f1e3c8c to your computer and use it in GitHub Desktop.
Save danieldk/dea909991793a75be5ea6ed58f1e3c8c to your computer and use it in GitHub Desktop.
# Author: Daniel de Kok
# Usage: python shard.py --safetensors-path /fsx/danieldk/4bit-gptq-instruct/gptq_model-4bit-128g.safetensors --framework torch --output-path /fsx/danieldk/4bit-gptq-instruct/gptq-sharded
import argparse
import safetensors
import huggingface_hub
def get_args():
    """Parse the command-line arguments of the sharding script.

    Returns:
        argparse.Namespace with the attributes ``safetensors_path``,
        ``framework`` and ``output_path`` (all strings).

    Raises:
        SystemExit: raised by argparse (exit status 2) when a required
            option is missing or the command line is otherwise invalid.
    """
    parser = argparse.ArgumentParser(
        description="Re-save a single safetensors checkpoint as a sharded checkpoint."
    )
    # Every option is mandatory: the script has no usable default for any of
    # them, so reject an incomplete command line here with a clear message
    # instead of passing None further down.
    parser.add_argument(
        "--safetensors-path",
        type=str,
        required=True,
        help="path of the safetensors file to read",
    )
    parser.add_argument(
        "--framework",
        type=str,
        required=True,
        help="framework name handed to safetensors.safe_open (e.g. torch)",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="location the sharded checkpoint is written to",
    )
    return parser.parse_args()
class FakeStateDict:
    """Minimal state-dict look-alike backed by a tensor store.

    The wrapped ``weights`` object must provide ``keys()`` and
    ``get_tensor(name)``. Only ``items()`` and subscript access are
    implemented; every access is forwarded to ``get_tensor`` instead of
    being served from a copy kept by this class.
    """

    def __init__(self, weights):
        # Keep a reference to the underlying store; nothing is read yet.
        self.weights = weights

    def items(self):
        """Yield ``(name, tensor)`` pairs in the order given by ``keys()``."""
        store = self.weights
        for name in store.keys():
            tensor = store.get_tensor(name)
            yield name, tensor

    def __getitem__(self, item):
        """Return the tensor stored under the name ``item``."""
        store = self.weights
        return store.get_tensor(item)
if __name__ == "__main__":
    args = get_args()
    # Open the checkpoint as a context manager so the handle is released
    # once saving has finished, and also when saving raises.
    with safetensors.safe_open(
        args.safetensors_path, framework=args.framework
    ) as weights:
        # FakeStateDict stands in for a real dict of tensors; it forwards
        # items() and subscript access to the open safetensors handle.
        huggingface_hub.save_torch_state_dict(FakeStateDict(weights), args.output_path)  # type: ignore
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment