@pcuenca
Last active November 28, 2024 22:26
Simple reshard
# Simple Llama resharding
# Tested on Llama 3.1 70B only
# Shards are released from memory as soon as we're done with each

import torch
from pathlib import Path
from tqdm import tqdm

checkpoint = Path("original")
output_dir = Path("resharded")
output_dir.mkdir(parents=True, exist_ok=True)

# Number of checkpoints to join into each new shard
# (use 2 to create 4 shards, 4 to create 2, 8 to create 1)
consolidate = 2
num_shards = 8

# Model dimension, from params.json
shard_dim = 8192

for s1 in tqdm(range(0, num_shards, consolidate)):
    resharded = torch.load(checkpoint / f"consolidated.{s1:02}.pth", weights_only=True)
    for s2 in range(s1 + 1, s1 + consolidate):
        new_shard = torch.load(checkpoint / f"consolidated.{s2:02}.pth", weights_only=True)
        current_keys = set(resharded.keys())
        for key, tensor in new_shard.items():
            # New key -> add it
            if key not in current_keys:
                resharded[key] = tensor
                continue
            # Unidimensional tensors already present are replicated, not sharded: skip
            if tensor.ndim == 1:
                continue
            # Concatenate two-dimensional tensors along the sharded dimension,
            # i.e. the one that is not the model dimension (shard_dim)
            assert tensor.ndim == 2
            cat_dim = 1 if tensor.shape[0] == shard_dim else 0
            resharded[key] = torch.cat([resharded[key], tensor], dim=cat_dim)
        del new_shard
    # Save the consolidated shard
    torch.save(resharded, output_dir / f"consolidated.{s1 // consolidate:02}.pth")
    del resharded
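
A quick sanity check after resharding can catch concatenation mistakes before you try to load the model. The sketch below is not part of the original gist: the directory names and the per_key_numel helper are assumptions, and it only verifies that per-key element counts match between the two directories (counting replicated 1-D tensors once, since the script above keeps a single copy of those).

# Sanity-check sketch (assumed paths: "original" with 8 shards, "resharded" with 4)
import torch
from pathlib import Path

def per_key_numel(shard_dir, num_files):
    """Sum element counts per key across shards; count replicated 1-D tensors once."""
    totals = {}
    for i in range(num_files):
        shard = torch.load(Path(shard_dir) / f"consolidated.{i:02}.pth", weights_only=True)
        for key, tensor in shard.items():
            if tensor.ndim == 1:
                totals[key] = tensor.numel()  # replicated across shards, count once
            else:
                totals[key] = totals.get(key, 0) + tensor.numel()
        del shard  # release each shard before loading the next
    return totals

# With consolidate = 2, the 8 original shards become 4 resharded files
assert per_key_numel("original", 8) == per_key_numel("resharded", 4)

This loads one shard at a time, mirroring the memory behavior of the resharding loop itself.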