import numpy as np
import openai
import scipy.special
import tiktoken

def get_top_chat_logprobs(
    model: str,
    messages: list[dict[str, str]],
    seed: int = 42,
    n_logprobs: int = 20,
) -> dict[int, tuple[float, str, str]]:
    """
    Returns a dict mapping token_idx to (logprob, token_str, system_fingerprint)
    for the top n_logprobs tokens of the first completion token.

    Supports at most 305 logprobs. Requests for 290-305 logprobs are flaky
    (they sometimes succeed and sometimes error), and requests for 306 or more
    always error (unless the OpenAI API changes).
    """
    tokenizer = tiktoken.encoding_for_model(model)
    assert 1 <= n_logprobs <= tokenizer.n_vocab
    client = openai.Client()

    def query_api(**kwargs):
        # max_tokens=1: we only need the distribution over the first
        # completion token. temperature=0 plus a fixed seed keeps repeated
        # calls as deterministic as the API allows.
        return client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_tokens=1,
            n=1,
            logprobs=True,
            top_logprobs=5,
            seed=seed,
            **kwargs,
        )

    def get_token_idx(top_logprob) -> int:
        # Special tokens such as <|end|> cannot be round-tripped through
        # encode_single_token, so map them to the tokenizer's EOT id.
        if top_logprob.token in ("<|end|>", "<|endoftext|>"):
            return tokenizer.eot_token
        return tokenizer.encode_single_token(bytes(top_logprob.bytes))
    base_resp = query_api()
    # Seed the result with the unbiased top 5. Maps token_idx to
    # (logprob, token_str, system_fingerprint).
    logprob_dict: dict[int, tuple[float, str, str]] = {
        get_token_idx(top_logprob): (
            top_logprob.logprob,
            top_logprob.token,
            base_resp.system_fingerprint,
        )
        for top_logprob in base_resp.choices[0].logprobs.content[0].top_logprobs
    }

    BIAS = -100
    while len(logprob_dict) < n_logprobs:
        # Log of the total probability mass already collected...
        log_masked_sum = scipy.special.logsumexp(
            [logprob for logprob, _, _ in logprob_dict.values()]
        )
        # ...and of the remaining mass, 1 - exp(log_masked_sum).
        unmasked_sum = -scipy.special.expm1(log_masked_sum)
        log_unmasked_sum = np.log(unmasked_sum)
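        # Derivation of the correction applied below (added for clarity):
        # with logit bias B on every masked token, the biased probability of
        # an unmasked token t is
        #     p_biased(t) = p(t) / (e^B * masked_sum + unmasked_sum),
        # so the true log-probability is recovered as
        #     log p(t) = log p_biased(t)
        #              + logaddexp(log_masked_sum + B, log_unmasked_sum).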
        resp = query_api(
            logit_bias={token_idx: BIAS for token_idx in logprob_dict}
        )
        for top_logprob in resp.choices[0].logprobs.content[0].top_logprobs:
            if len(logprob_dict) >= n_logprobs:
                break
            token_idx = get_token_idx(top_logprob)
            # Undo the logit bias to recover the true log-probability.
            true_logprob = top_logprob.logprob + np.logaddexp(
                log_masked_sum + BIAS, log_unmasked_sum
            )
            logprob_dict[token_idx] = (
                true_logprob,
                top_logprob.token,
                resp.system_fingerprint,
            )
        print(len(logprob_dict))  # Progress: number of tokens collected.

    return logprob_dict
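

# Minimal usage sketch (added for illustration; the model name and prompt are
# placeholders, and OPENAI_API_KEY is assumed to be set in the environment):
if __name__ == "__main__":
    top = get_top_chat_logprobs(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "The capital of France is"}],
        n_logprobs=25,
    )
    # Sort by descending log-probability for display.
    for token_idx, (logprob, token_str, _fingerprint) in sorted(
        top.items(), key=lambda item: -item[1][0]
    ):
        print(f"{token_idx:>7}  {token_str!r:<20}  p={np.exp(logprob):.5f}")
    # Sanity check: the recovered probabilities should sum to at most ~1
    # (small drift is possible if the backend changes between calls).
    print("total mass:", sum(np.exp(lp) for lp, _, _ in top.values()))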