from time import perf_counter

import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError

from flash_attn_time import benchmark_one


def benchmark():
    # Sweep Llama 1 and Llama 2 checkpoints quantized with QuIP# at 2 bits (E8P)
    # and 4 bits (E8PRVQ), timing one generation run per configuration.
    for llama in (1, 2):
        sizes = (7, 13, 30, 65) if llama == 1 else (7, 13, 70)
        for size in sizes:
            for rate, method in ((2, "E8P"), (4, "E8PRVQ")):
                if method is None:
                    # method is never None with the tuple above; this branch builds
                    # the repo id for an unquantized base checkpoint.
                    publisher = "relaxml" if llama == 1 else "meta-llama"
                    repo_id = f"Llama-{llama}-{size}b-hf"
                else:
                    publisher = "relaxml"
                    repo_id = f"Llama-{llama}-{size}b-{method}-{rate}Bit"
                print(">", repo_id)
                # Only run the configurations in this allow-list; everything else is skipped.
                if (llama, size, rate) not in {(2, 7, 2),
                                               (2, 7, 4),
                                               (2, 13, 2),
                                               (2, 13, 4),
                                               (2, 70, 2),
                                               (1, 30, 2),
                                               (1, 30, 4),
                                               (1, 65, 2)}:
                    print("Skip")
                    continue
                try:
                    snapshot_path = snapshot_download(f"{publisher}/{repo_id}")
                except RepositoryNotFoundError:
                    print("404")
                    continue
                model_name = f"meta-llama/Llama-2-{size}b-hf" if llama == 2 else f"relaxml/Llama-1-{size}b-hf"
                try:
                    start_time = perf_counter()
                    benchmark_one(model_name, snapshot_path, method is not None)
                    end_time = perf_counter()
                except torch.cuda.OutOfMemoryError:
                    print("OOM")
                    continue
                print("Elapsed", end_time - start_time)


if __name__ == "__main__":
    benchmark()
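For reference, benchmark_one (defined in flash_attn_time.py below) can also be timed for a single checkpoint without running the full sweep. A minimal sketch, assuming the relaxml/Llama-2-7b-E8P-2Bit checkpoint from the allow-list above:

from time import perf_counter
from huggingface_hub import snapshot_download
from flash_attn_time import benchmark_one

# Download (or reuse the cached copy of) one QuIP#-quantized checkpoint.
snapshot_path = snapshot_download("relaxml/Llama-2-7b-E8P-2Bit")
start = perf_counter()
benchmark_one("meta-llama/Llama-2-7b-hf", snapshot_path, True)  # True => quantized
print("Elapsed", perf_counter() - start)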
import json
import os

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel, Block
from flash_attn.models.llama import llama_config_to_gpt2_config, remap_state_dict_hf_llama
from flash_attn.modules.embedding import GPT2Embeddings
from flash_attn.modules.mlp import GatedMlp
from flash_attn.ops.triton.layer_norm import RMSNorm
from flash_attn.utils.benchmark import pytorch_profiler
from flash_attn.utils.pretrained import state_dict_from_pretrained

from lib import codebook
from lib.linear.fused_quantized_linear import FusedQuantizedLinear
from lib.linear.quantized_linear import QuantizedLinear

device = "cuda"
dtype = torch.float16


def benchmark_one(model_name, quip_hf, quantized):
    # Build the flash-attn GPT config from the HF Llama config.
    llama_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config = llama_config_to_gpt2_config(llama_config)
    config.use_flash_attn = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    # pretrained_state_dict = remap_state_dict_hf_llama(state_dict_from_pretrained(model_name), config)

    def remap_state_dict(state_dict):
        # Rename QuIP#/HF-style keys to the flash-attn GPT naming scheme.
        frozen_keys = list(state_dict.keys())
        for key in frozen_keys:
            new_key = key.replace(
                'model', 'transformer'
            ).replace(
                'self_attn.qkv_proj', 'mixer.Wqkv'
            ).replace(
                'self_attn.o_proj', 'mixer.out_proj'
            ).replace(
                'mlp.upgate_proj', 'mlp.fc1'
            ).replace(
                'mlp.down_proj', 'mlp.fc2'
            ).replace(
                'input_layernorm', 'norm1'
            ).replace(
                'post_attention_layernorm', 'norm2'
            ).replace(
                'embed_tokens', 'embeddings.word_embeddings'
            )
            if new_key.endswith('Wqkv.fuse_scales'):
                # Expand the per-projection (q, k, v) scales to per-channel scales for the fused QKV.
                head_dim = llama_config.hidden_size // llama_config.num_attention_heads
                device = state_dict[key].device
                dtype = state_dict[key].dtype
                fuse_scales = torch.concat([
                    state_dict[key][0] * torch.ones(llama_config.num_attention_heads * head_dim, device=device, dtype=dtype),
                    state_dict[key][1] * torch.ones(llama_config.num_key_value_heads * head_dim, device=device, dtype=dtype),
                    state_dict[key][1] * torch.ones(llama_config.num_key_value_heads * head_dim, device=device, dtype=dtype),
                ], dim=0)
                state_dict[new_key] = fuse_scales
            elif new_key.endswith('fc1.fuse_scales'):
                # Expand the (up, gate) scales to per-channel scales for the fused up/gate projection.
                device = state_dict[key].device
                fuse_scales = torch.concat([
                    state_dict[key][0] * torch.ones(llama_config.intermediate_size, device=device, dtype=dtype),
                    state_dict[key][1] * torch.ones(llama_config.intermediate_size, device=device, dtype=dtype),
                ], dim=0)
                state_dict[new_key] = fuse_scales
            else:
                if new_key == 'transformer.norm.weight':
                    new_key = 'transformer.ln_f.weight'
                state_dict[new_key] = state_dict[key]
            if new_key != key:
                del state_dict[key]
        return state_dict

    from lib.utils.unsafe_import import model_from_hf_path
    # Load the QuIP# checkpoint and keep only its state dict.
    m = model_from_hf_path(quip_hf, use_cuda_graph=False)[0].state_dict()
    model = GPTLMHeadModel(config, device='meta', dtype=dtype)
    if quantized:
        m = remap_state_dict(m)
        quip_params = json.load(open(os.path.join(quip_hf, 'config.json')))['quip_params']

        def replace_linear(module):
            # Swap every nn.Linear (except lm_head) for the matching QuIP# quantized layer;
            # fused projections (Wqkv, fc1) get FusedQuantizedLinear.
            for name, child in module.named_children():
                if isinstance(child, nn.Linear) and name != 'lm_head':
                    if name.endswith('Wqkv') or name.endswith('fc1'):
                        ql = FusedQuantizedLinear(
                            -1, (child.out_features,), True,
                            child.in_features,
                            child.out_features,
                            quip_params['codesz'],
                            quip_params.get('packsz', 1),
                            quip_params.get('pack_out', False),
                            quip_params['idx_dtype'],
                            quip_params.get('codebook_version', 0),
                            rank=quip_params['lora_rank'],
                            rescale_WH=quip_params['rescale_WH'],
                            resid_scale_override=quip_params.get('resid_scale_override', -1)
                        )
                    else:
                        ql = QuantizedLinear(
                            child.in_features,
                            child.out_features,
                            quip_params['codesz'],
                            quip_params.get('packsz', 1),
                            quip_params.get('pack_out', False),
                            quip_params['idx_dtype'],
                            quip_params.get('codebook_version', 0),
                            rank=quip_params['lora_rank'],
                            rescale_WH=quip_params['rescale_WH'],
                            resid_scale_override=quip_params.get('resid_scale_override', -1)
                        )
                    ql.codebook_id.copy_(codebook.get_id(quip_params['codebook']))
                    setattr(module, name, ql)
                else:
                    replace_linear(child)

        replace_linear(model)
    model.load_state_dict(m, strict=False, assign=True)
    # model.load_state_dict(pretrained_state_dict)

    def replace_meta(module):
        # Any parameter still on the meta device after loading gets a random CUDA placeholder,
        # which is fine for speed benchmarking.
        for name, child in module.named_children():
            if isinstance(child, (GPTModel, GPT2Embeddings, nn.ModuleList, Block, GatedMlp)):
                replace_meta(child)
            elif isinstance(child, nn.Embedding):
                child.weight = nn.Parameter(torch.randn_like(child.weight, device="cuda"))
            elif isinstance(child, RotaryEmbedding):
                child.inv_freq = torch.randn_like(child.inv_freq, device="cuda")
            elif isinstance(child, RMSNorm):
                child.weight = nn.Parameter(torch.randn_like(child.weight, device="cuda"))
            elif isinstance(child, nn.Linear):
                child.weight = nn.Parameter(torch.randn_like(child.weight, device="cuda"))

    replace_meta(model)
    model = model.cuda()
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
    input_ids = tokenizer.encode("a", return_tensors="pt").to(device)
    max_length = input_ids.shape[-1] + 1000
    # Generate ~1000 tokens with CUDA graphs enabled; enable_timing reports the decoding time.
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )


if __name__ == "__main__":
    model_name = "relaxml/Llama-1-30b-hf"
    quip_hf = '/share/desa/nfs01/qs234/huggingface/hub/models--relaxml--Llama-1-30b-E8P-2Bit/snapshots/42807d6d30647886bfc77072871a960a89919f46/'
    # torch.cuda.memory._record_memory_history(enabled='all')
    benchmark_one(model_name, quip_hf, True)
    # from pickle import dump
    # s = torch.cuda.memory._snapshot()
    # with open(f"snapshot.pickle", "wb") as f:
    #     dump(s, f)
    #
    # torch.cuda.memory._record_memory_history(enabled=None)
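The commented-out lines in the __main__ block above sketch CUDA memory profiling around the run. A minimal self-contained version of that flow (the local snapshot path is a placeholder; _record_memory_history and _snapshot are the same private torch.cuda.memory APIs referenced in those comments):

import torch
from pickle import dump
from flash_attn_time import benchmark_one

torch.cuda.memory._record_memory_history(enabled='all')
benchmark_one("relaxml/Llama-1-30b-hf", "/path/to/Llama-1-30b-E8P-2Bit/snapshot", True)  # placeholder path
with open("snapshot.pickle", "wb") as f:
    dump(torch.cuda.memory._snapshot(), f)  # inspect with PyTorch's memory_viz tool
torch.cuda.memory._record_memory_history(enabled=None)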
absl-py==2.1.0
accelerate==0.26.1
aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
attrs==23.2.0
cachetools==5.3.2
certifi==2023.11.17
chardet==5.2.0
charset-normalizer==3.3.2
click==8.1.7
colorama==0.4.6
DataProperty==1.0.1
datasets==2.16.1
dill==0.3.7
distro==1.9.0
einops==0.7.0
evaluate==0.4.1
fast-hadamard-transform==1.0.1
filelock==3.13.1
flash-attn==2.4.2
frozenlist==1.4.1
fsspec==2023.10.0
fused-dense-lib==0.0.0
glog==0.3.1
h11==0.14.0
httpcore==1.0.2
httpx==0.26.0
huggingface-hub==0.20.3
icdiff==2.0.7
idna==3.6
Jinja2==3.1.3
joblib==1.3.2
jsonlines==4.0.0
lm-eval==0.3.0
lxml==5.1.0
MarkupSafe==2.1.4
mbstrdecoder==1.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
numexpr==2.8.8
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.535.133
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
nvitop==1.3.2
openai==1.9.0
packaging==23.2
pandas==2.2.0
pathvalidate==3.2.0
peft==0.7.1
pillow==10.2.0
portalocker==2.8.2
primefac==2.0.12
psutil==5.9.8
pyarrow==15.0.0
pyarrow-hotfix==0.6
pybind11==2.11.1
pycountry==23.12.11
pydantic==2.5.3
pydantic_core==2.14.6
pytablewriter==1.2.0
python-dateutil==2.8.2
python-gflags==3.1.2
pytz==2023.3.post1
PyYAML==6.0.1
quiptools-cuda==0.0.0
regex==2023.12.25
requests==2.31.0
responses==0.18.0
rouge-score==0.1.2
sacrebleu==1.5.0
safetensors==0.4.2
scikit-learn==1.4.0
scipy==1.12.0
sentencepiece==0.1.99
six==1.16.0
sniffio==1.3.0
sqlitedict==2.1.0
sympy==1.12
tabledata==1.3.3
tabulate==0.9.0
tcolorpy==0.1.4
termcolor==2.4.0
threadpoolctl==3.2.0
tokenizers==0.15.1
torch==2.1.2
torchvision==0.16.2
tqdm==4.66.1
tqdm-multiprocess==0.0.11
transformers==4.37.1
triton==2.1.0
typepy==1.3.2
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.1.0
xxhash==3.4.1
yarl==1.9.4
zstandard==0.22.0