llama2 chat prompt format reverse engineering
# this is adapted from https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L213
# the tokenizer is replaced with ord() to make it easier to see what's actually happening
from typing_extensions import TypedDict, Literal
from typing import List, Optional

Role = Literal["system", "user", "assistant"]
class Message(TypedDict):
    role: Role
    content: str

class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required

class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required

Dialog = List[Message]
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
def encode(text, bos, eos):
    s = '<s>' if bos else ''
    s += text
    s += '</s>' if eos else ''
    return [ord(x) for x in s]
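# note: in the real generation.py this helper is the SentencePiece tokenizer call, and
# bos/eos add the BOS/EOS token IDs rather than the literal strings "<s>" and "</s>";
# keeping them as plain text (and tokenizing with ord()) keeps the assembled prompt readable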
def chat_completion(
    dialogs: List[Dialog],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
) -> List[List[int]]:  # adapted: returns the assembled prompt token lists, not ChatPredictions
    #if max_gen_len is None:
    #    max_gen_len = self.model.params.max_seq_len - 1
    prompt_tokens = []
    for dialog in dialogs:
        if dialog[0]["role"] != "system":
            dialog = [
                {
                    "role": "system",
                    "content": DEFAULT_SYSTEM_PROMPT,
                }
            ] + dialog
        dialog = [
            {
                "role": dialog[1]["role"],
                "content": B_SYS
                + dialog[0]["content"]
                + E_SYS
                + dialog[1]["content"],
            }
        ] + dialog[2:]
        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
            [msg["role"] == "assistant" for msg in dialog[1::2]]
        ), (
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
        )
        dialog_tokens: List[int] = sum(
            [
                encode(
                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                    bos=True,
                    eos=True,
                )
                for prompt, answer in zip(
                    dialog[::2],
                    dialog[1::2],
                )
            ],
            [],
        )
        assert (
            dialog[-1]["role"] == "user"
        ), f"Last message must be from user, got {dialog[-1]['role']}"
        dialog_tokens += encode(
            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
            bos=True,
            eos=False,
        )
        prompt_tokens.append(dialog_tokens)
    return prompt_tokens
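# What the function above boils down to (the reverse-engineered llama2 chat template):
#  - the system prompt is not a separate turn: it is folded into the first user
#    message as "<<SYS>>\n{system}\n<</SYS>>\n\n{first user message}"
#  - every completed user/assistant exchange becomes its own BOS/EOS-wrapped segment:
#    "<s>[INST] {user} [/INST] {assistant} </s>"
#  - the final, unanswered user message is encoded as "<s>[INST] {user} [/INST]" with
#    no closing "</s>", so the model's reply continues from there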
# a single user turn
d1 = [Message(role="user", content="<prompt>")]
p1 = chat_completion([d1])
print(''.join([chr(x) for x in p1[0]]))

# a user / assistant / user dialog
d2 = [Message(role="user", content="<prompt>"), Message(role="assistant", content="<answer>"), Message(role="user", content="<prompt-second>")]
p2 = chat_completion([d2])
print(''.join([chr(x) for x in p2[0]]))
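For reference, running the two examples above prints roughly the following, with the default system prompt abbreviated to {DEFAULT_SYSTEM_PROMPT} and the newlines coming from B_SYS/E_SYS:

p1 (single user turn):

<s>[INST] <<SYS>>
{DEFAULT_SYSTEM_PROMPT}
<</SYS>>

<prompt> [/INST]

p2 (user / assistant / user):

<s>[INST] <<SYS>>
{DEFAULT_SYSTEM_PROMPT}
<</SYS>>

<prompt> [/INST] <answer> </s><s>[INST] <prompt-second> [/INST]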