Simple reshard
# Simple Llama resharding
# Tested on Llama 3.1 70B only
# Shards are released from memory as soon as we're done with each
import torch
from pathlib import Path
from tqdm import tqdm

checkpoint = Path("original")
output_dir = Path("resharded")
output_dir.mkdir(parents=True, exist_ok=True)

# Number of original checkpoints to join into each new shard
# (with 8 original shards: use 2 to create 4 shards, 4 to create 2, and 8 to create 1)
consolidate = 2
num_shards = 8  # must be divisible by consolidate

# Model dimension ("dim" in params.json)
shard_dim = 8192

for s1 in tqdm(range(0, num_shards, consolidate)):
    resharded = torch.load(checkpoint / f"consolidated.{s1:02}.pth", weights_only=True)
    for s2 in range(s1 + 1, s1 + consolidate):
        new_shard = torch.load(checkpoint / f"consolidated.{s2:02}.pth", weights_only=True)
        current_keys = set(resharded.keys())
        for key, tensor in new_shard.items():
            # New key -> add it
            if key not in current_keys:
                resharded[key] = tensor
                continue
            # Unidimensional tensors (norms) are replicated across shards,
            # so keep the copy we already have
            if tensor.ndim == 1:
                continue
            # Concatenate two-dimensional tensors along the sharded dimension,
            # i.e. the one that does not equal the model dim (8192)
            assert tensor.ndim == 2
            cat_dim = 1 if tensor.shape[0] == shard_dim else 0
            resharded[key] = torch.cat([resharded[key], tensor], dim=cat_dim)
        del new_shard
    # Save the consolidated shard and free memory before the next group
    torch.save(resharded, output_dir / f"consolidated.{s1 // consolidate:02}.pth")
    del resharded
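
Why the cat_dim heuristic works: Llama's tensor-parallel layout splits column-parallel weights along dim 0 and row-parallel weights along dim 1, so in any 2-D shard the dimension that equals the model dim is the replicated one and the other is the sharded one. The toy illustration below is not part of the gist; the shapes are made up and shard_dim is scaled down from 8192 so it runs instantly.

# Toy demonstration of the concat-dimension heuristic (illustrative shapes only)
import torch

shard_dim = 8  # stands in for the real model dim of 8192

# Column-parallel weight: each shard holds a slice of dim 0,
# so dim 1 equals shard_dim and shards are joined along dim 0
a, b = torch.zeros(3, 8), torch.zeros(3, 8)
cat_dim = 1 if a.shape[0] == shard_dim else 0
assert torch.cat([a, b], dim=cat_dim).shape == (6, 8)

# Row-parallel weight: each shard holds a slice of dim 1,
# so dim 0 equals shard_dim and shards are joined along dim 1
a, b = torch.zeros(8, 3), torch.zeros(8, 3)
cat_dim = 1 if a.shape[0] == shard_dim else 0
assert torch.cat([a, b], dim=cat_dim).shape == (8, 6)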
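
A quick sanity check one might run afterwards (a sketch, not part of the gist): since 2-D tensors are only concatenated, their total element count should be identical before and after resharding, while 1-D tensors are replicated per file and excluded from the count. The directory names and shard counts below assume the defaults from the script above.

# Sanity check: 2-D parameter counts match between original and resharded files
import torch
from pathlib import Path

def total_2d_params(folder: Path, n_files: int) -> int:
    total = 0
    for i in range(n_files):
        shard = torch.load(folder / f"consolidated.{i:02}.pth", weights_only=True)
        total += sum(t.numel() for t in shard.values() if t.ndim == 2)
        del shard  # free each shard before loading the next
    return total

assert total_2d_params(Path("original"), 8) == total_2d_params(Path("resharded"), 4)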