Created: August 7, 2024 08:57
An example of how to convert a trained EAGLE checkpoint to a vLLM-compatible version.
import json

import torch
from safetensors.torch import load_file, save_file

# Load the trained EAGLE draft checkpoint and the last shard of the base model,
# which contains the lm_head weights that the EAGLE checkpoint does not ship.
ckpt = torch.load("EAGLE-LLaMA3-Instruct-8B/pytorch_model.bin")
ref_ckpt = load_file("Meta-Llama-3-8B-Instruct/model-00004-of-00004.safetensors")

# Copy the base model's lm_head into the draft checkpoint and re-save it as safetensors.
ckpt['lm_head.weight'] = ref_ckpt['lm_head.weight']
save_file(ckpt, "EAGLE-LLaMA3-Instruct-8B/model.safetensors")

# Wrap the original config under a "model" key and mark the model type as "eagle"
# so that vLLM recognizes the checkpoint.
with open("EAGLE-LLaMA3-Instruct-8B/config.json") as rf:
    cfg = json.load(rf)
cfg = {"model_type": "eagle", "model": cfg}
with open("EAGLE-LLaMA3-Instruct-8B/config.json", "w") as wf:
    json.dump(cfg, wf)

# Afterwards, delete EAGLE-LLaMA3-Instruct-8B/pytorch_model.bin; only model.safetensors is needed.
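Once converted, the checkpoint can be used as the draft model for speculative decoding in vLLM. The snippet below is only a rough sketch of such usage, not part of the original gist: it assumes a vLLM release whose LLM constructor accepts speculative_model and num_speculative_tokens keyword arguments, and the token count of 5 is an arbitrary example; check the documentation of your vLLM version.

from vllm import LLM, SamplingParams

# Sketch: serve the base model with the converted EAGLE checkpoint as the draft model.
# Keyword arguments assume a vLLM version exposing speculative decoding via
# `speculative_model` / `num_speculative_tokens`.
llm = LLM(
    model="Meta-Llama-3-8B-Instruct",
    speculative_model="EAGLE-LLaMA3-Instruct-8B",
    num_speculative_tokens=5,
)
outputs = llm.generate(["What is speculative decoding?"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)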
@abhigoyal1997
Thank you for sharing your script. As discussed in vllm-project/vllm#11126, "eagle_fc_bias" also needs to be added to the configuration. This is particularly relevant for the EAGLE-llama2-chat-7B model, which includes fc_bias, whereas EAGLE-LLaMA3-Instruct-8B does not.
I suggest adding the following diff to your gist:
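The diff itself is not reproduced here. As a hedged sketch of what such a change might look like, based only on the description above, the config-rewrite step could set the flag from the checkpoint contents; the key name "fc.bias" and this detection logic are assumptions, not the commenter's actual diff.

# Sketch only: set eagle_fc_bias according to whether the draft checkpoint's fc layer has a bias.
# EAGLE-llama2-chat-7B ships an fc bias; EAGLE-LLaMA3-Instruct-8B does not.
has_fc_bias = 'fc.bias' in ckpt  # assumption: the bias tensor is stored under this key
cfg = {"model_type": "eagle", "model": cfg, "eagle_fc_bias": has_fc_bias}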