An example of how to convert a trained EAGLE checkpoint to a vLLM-compatible version
import json
import torch
from safetensors.torch import load_file, save_file
# Load the trained EAGLE head and the reference Llama 3 shard that contains lm_head.weight.
ckpt = torch.load("EAGLE-LLaMA3-Instruct-8B/pytorch_model.bin")
ref_ckpt = load_file("Meta-Llama-3-8B-Instruct/model-00004-of-00004.safetensors")

# Copy the target model's lm_head weight into the EAGLE checkpoint and save it as safetensors.
ckpt['lm_head.weight'] = ref_ckpt['lm_head.weight']
save_file(ckpt, "EAGLE-LLaMA3-Instruct-8B/model.safetensors")

# Wrap the original EAGLE config so vLLM recognizes the checkpoint as an EAGLE model.
with open("EAGLE-LLaMA3-Instruct-8B/config.json") as rf:
    cfg = json.load(rf)
cfg = {"model_type": "eagle", "model": cfg}
with open("EAGLE-LLaMA3-Instruct-8B/config.json", "w") as wf:
    json.dump(cfg, wf)
# delete EAGLE-LLaMA3-Instruct-8B/pytorch_model.bin
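As a quick sanity check (not part of the original gist), the converted files can be reloaded to confirm that the copied lm_head weight and the wrapped config are in place before pointing vLLM at the directory:

import json
from safetensors.torch import load_file

# Reload the converted checkpoint and confirm the lm_head weight was copied in.
converted = load_file("EAGLE-LLaMA3-Instruct-8B/model.safetensors")
assert "lm_head.weight" in converted, "lm_head.weight missing from converted checkpoint"

# Confirm the config was wrapped into the {"model_type": "eagle", "model": ...} structure.
with open("EAGLE-LLaMA3-Instruct-8B/config.json") as rf:
    cfg = json.load(rf)
assert cfg.get("model_type") == "eagle"
assert "model" in cfg  # original EAGLE config nested under "model"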
llsj14 commented Dec 19, 2024

@abhigoyal1997
Thank you for sharing your script. As discussed in issue vllm-project/vllm#11126, it is also necessary to add "eagle_fc_bias" to the configuration. This is particularly relevant for the EAGLE-llama2-chat-7B model, which includes an fc bias, whereas the EAGLE-LLaMA3-Instruct-8B model does not.

I suggest adding the following diff to your gist:

with open("./config.json") as rf:
    cfg = json.load(rf)

+ if "fc.bias" in ckpt:
+     cfg["eagle_fc_bias"] = True

cfg = {"model_type": "eagle", "model": cfg}
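For completeness, here is a sketch of the config-writing step with the suggested check folded in (the EAGLE-llama2-chat-7B paths are illustrative; ckpt is the state dict loaded earlier in the gist with torch.load):

# ckpt is the EAGLE state dict loaded earlier via torch.load(...)
with open("EAGLE-llama2-chat-7B/config.json") as rf:
    cfg = json.load(rf)

# EAGLE-llama2-chat-7B ships a bias for its fc layer; record it so vLLM builds the layer with a bias.
if "fc.bias" in ckpt:
    cfg["eagle_fc_bias"] = True

cfg = {"model_type": "eagle", "model": cfg}
with open("EAGLE-llama2-chat-7B/config.json", "w") as wf:
    json.dump(cfg, wf)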
