# script to decompose/recompose a llama model into a different number of shards
# note that it holds roughly two full copies of the model in CPU memory
import glob
import json
import os
import sys

import torch

if len(sys.argv) != 4:
    print('usage: %s <new-shards> <input-model-path> <output-model-path>' % sys.argv[0], file=sys.stderr)
    sys.exit(1)

num_shards = int(sys.argv[1])
input_model_dir = sys.argv[2]
output_model_dir = sys.argv[3]

with open(os.path.join(input_model_dir, 'params.json'), 'r') as fp:
    params = json.loads(fp.read())

assert params['dim'] % num_shards == 0, "number of shards needs to divide parameter dimension %d" % params['dim']

print('loading...')
# sort the checkpoint paths so shards are merged in rank order
# (glob.glob returns files in arbitrary order, which would scramble the weights)
checkpoints = [
    torch.load(path, map_location=torch.device('cpu'))
    for path in sorted(glob.glob(os.path.join(input_model_dir, '*.pth')))
]

# how each parameter is partitioned across model-parallel ranks:
# ColumnParallelLinear weights are split along dim 0, ParallelEmbedding and
# RowParallelLinear weights along dim 1, and None means the tensor is replicated
layer_kind = {
    'tok_embeddings': 'ParallelEmbedding',
    'output': 'ColumnParallelLinear',
    'attention.wq': 'ColumnParallelLinear',
    'attention.wk': 'ColumnParallelLinear',
    'attention.wv': 'ColumnParallelLinear',
    'attention.wo': 'RowParallelLinear',
    'feed_forward.w1': 'ColumnParallelLinear',
    'feed_forward.w2': 'RowParallelLinear',
    'feed_forward.w3': 'ColumnParallelLinear',
    'attention_norm': None,
    'ffn_norm': None,
    'norm': None,
    'rope.freqs': None,
}

output = [dict() for x in range(num_shards)]

print('converting...')
for key in checkpoints[0].keys():
    tensors = [m[key] for m in checkpoints]
    print(key)
    print(' in shapes=', [p.shape for p in tensors])
    for pattern, kind in layer_kind.items():
        if key.replace('.weight', '').endswith(pattern):
            print(' kind=', kind)
            if kind == 'ColumnParallelLinear':
                # merge along dim 0, then re-split into num_shards equal slices
                with torch.no_grad():
                    merged = torch.cat(tensors, 0)
                    slice_size = merged.shape[0] // num_shards
                    for rank in range(num_shards):
                        output[rank][key] = merged[slice_size * rank: slice_size * (rank + 1), :].clone().detach()
            elif kind in ('ParallelEmbedding', 'RowParallelLinear'):
                # merge along dim 1, then re-split into num_shards equal slices
                with torch.no_grad():
                    merged = torch.cat(tensors, 1)
                    slice_size = merged.shape[1] // num_shards
                    for rank in range(num_shards):
                        output[rank][key] = merged[:, slice_size * rank: slice_size * (rank + 1)].clone().detach()
            else:
                # replicated tensors (norms, rope frequencies) are copied to every shard
                for rank in range(num_shards):
                    output[rank][key] = tensors[0]
            print(' out shapes=', [output[rank][key].shape for rank in range(num_shards)])
            print()
            break
    else:
        raise Exception('parameter name not recognized: %s' % key)

print('saving...')
os.makedirs(output_model_dir, exist_ok=True)
with open(os.path.join(output_model_dir, 'params.json'), 'w') as fp:
    fp.write(json.dumps(params))
for rank in range(num_shards):
    print(' ', rank)
    torch.save(output[rank], os.path.join(output_model_dir, 'consolidated.%02d.pth' % rank))
print('done.')
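
For reference, an example invocation, assuming the script above is saved as reshard.py (the file name and paths are placeholders):

# split the original 8-shard 65B checkpoint into 2 shards
python reshard.py 2 /data/LLaMA/65B /data/LLaMA/65Bx2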
Why is the resharded file so much larger? The 13B model has 2 checkpoints totaling 26 GB; after consolidating into 1 file, it jumps to 39 GB.
This script doesn't work for the 65B model for some reason. It completes the process, but the resulting shards output strange tokens when run. Resharding the 13B-30B models with this script works fine. Not sure what makes 65B special and causes the resharded model to run incorrectly.
I tried sharding 65B into 2x and 4x for execution on 4x A100 80GB without success. The 13B-30B models shard without issue.
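
In case it helps with debugging, here is a rough sanity check, as a sketch rather than part of the gist: a hypothetical check_reshard.py that merges both the original and the resharded checkpoints back into full tensors using the same split rules as the script above and compares them (it needs enough CPU RAM for two full copies of the model):

# check_reshard.py (hypothetical): merge two checkpoint directories back into
# full tensors and verify that they contain identical weights
import glob
import os
import sys

import torch

COLUMN_KEYS = ('output', 'attention.wq', 'attention.wk', 'attention.wv',
               'feed_forward.w1', 'feed_forward.w3')             # split along dim 0
ROW_KEYS = ('tok_embeddings', 'attention.wo', 'feed_forward.w2')  # split along dim 1

def merge(model_dir):
    shards = [torch.load(p, map_location='cpu')
              for p in sorted(glob.glob(os.path.join(model_dir, '*.pth')))]
    merged = {}
    for key in shards[0]:
        base = key.replace('.weight', '')
        if any(base.endswith(k) for k in COLUMN_KEYS):
            merged[key] = torch.cat([s[key] for s in shards], 0)
        elif any(base.endswith(k) for k in ROW_KEYS):
            merged[key] = torch.cat([s[key] for s in shards], 1)
        else:
            merged[key] = shards[0][key]  # replicated tensors
    return merged

original = merge(sys.argv[1])   # e.g. the original model directory
resharded = merge(sys.argv[2])  # e.g. the resharded model directory
for key in original:
    if key not in resharded or not torch.equal(original[key], resharded[key]):
        print('MISMATCH:', key)
print('check finished')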
Interestingly, I have a similar problem (garbage tokens predicted) with the 30B model. I'm using a single A100 GPU with 80 GB of VRAM. The loaded model uses around 67 GB with a batch size of 2.
Except for the largest model, the consolidated files are larger than the sharded ones (x1 stands for 1 shard):
25G /data/LLaMA/13B
37G /data/LLaMA/13Bx1
61G /data/LLaMA/30B
76G /data/LLaMA/30Bx1
122G /data/LLaMA/65B
122G /data/LLaMA/65Bx1
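
If it helps track down the size difference, the dtypes and total tensor size stored in each checkpoint file can be checked with a small inspection snippet like the one below (a sketch, not part of the gist):

# print per-file tensor dtypes and total size for a checkpoint directory
import glob
import os
import sys
from collections import Counter

import torch

for path in sorted(glob.glob(os.path.join(sys.argv[1], '*.pth'))):
    shard = torch.load(path, map_location='cpu')
    tensors = [t for t in shard.values() if torch.is_tensor(t)]
    dtypes = Counter(str(t.dtype) for t in tensors)
    total_bytes = sum(t.numel() * t.element_size() for t in tensors)
    print('%s  %s  %.1f GB' % (path, dict(dtypes), total_bytes / 2**30))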
I have succeeded in consolidating the 30B model with this script, though:
https://github.com/randaller/llama-chat/blob/main/merge-weights.py
How long will it take to run on an A100?
Worked great! Thank you a bunch!
I tested with the 65B model split into two shards and it worked fine.