Last active
November 17, 2023 23:55
-
-
Save brandonwillard/e1f41053c599bb584d4b922251cd96f5 to your computer and use it in GitHub Desktop.
Computing sequence probabilities in Outlines
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import outlines.models as models | |
from outlines.text.generate.regex import choice | |
from outlines.text.generate.continuation import continuation | |
from outlines.text.generate.sample import greedy | |
def make_greedy_tracker(generator):
    """Instrument ``generator`` to record the log-probability of greedy sequences.

    The generator is modified in place:

    * ``generator.sampler`` is replaced by a wrapper around ``greedy`` that
      adds the log-probability of every chosen token to
      ``generator.running_sequence_log_prob``.
    * ``generator.postprocess_completions`` is wrapped so that, once the
      original post-processing has run, the accumulated value is moved to
      ``generator.last_sequence_log_prob`` and the running total is reset
      to ``0.0``.

    Parameters
    ----------
    generator
        An object exposing a ``postprocess_completions`` method and whose
        ``sampler`` attribute may be overwritten.

    Returns
    -------
    The same ``generator`` object, now carrying the two tracking attributes.
    ``last_sequence_log_prob`` is ``None`` until the first completion has
    been post-processed.
    """
    import types

    generator.last_sequence_log_prob = None
    generator.running_sequence_log_prob = 0.0

    def tracking_greedy(
        logits: torch.DoubleTensor, samples: int, *_
    ) -> torch.DoubleTensor:
        """Pick tokens with ``greedy`` and accumulate their log-probabilities."""
        next_token_ids = greedy(logits, samples)

        # log_softmax is numerically safer than log(softmax(...)), which can
        # underflow to log(0) for very unlikely tokens.
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        if tuple(next_token_ids.shape) == (*logits.shape[:-1], 1):
            # One chosen token per row: gather selects row i's own token, so
            # a batch of B sequences yields B log-probabilities instead of
            # the B x B cross-product that plain column indexing produces.
            # The final squeeze keeps a 0-dim result for a single sequence.
            token_log_probs = log_probs.gather(-1, next_token_ids).squeeze()
        else:
            # NOTE(review): any other sampler output shape (presumably
            # `samples > 1`) keeps the original indexing unchanged -- confirm
            # the intended semantics before relying on it.
            token_log_probs = log_probs[:, next_token_ids.squeeze()].squeeze()

        generator.running_sequence_log_prob = (
            generator.running_sequence_log_prob + token_log_probs
        )
        return next_token_ids

    generator.sampler = tracking_greedy

    # Bound method captured before it is overridden on the instance below.
    old_postprocess_completions = generator.postprocess_completions

    def new_postprocess_completions(self, *args, **kwargs):
        """Run the original post-processing, then publish and reset the total."""
        res = old_postprocess_completions(*args, **kwargs)
        self.last_sequence_log_prob = self.running_sequence_log_prob
        # Reset the sequence log-probability for the next generation call.
        self.running_sequence_log_prob = 0.0
        return res

    generator.postprocess_completions = types.MethodType(
        new_postprocess_completions, generator
    )

    return generator
# Demo: compare the log-probability of a free-form greedy continuation with
# that of a regex-constrained choice, both generated by GPT-2.
model = models.transformers("gpt2")

# Unconstrained continuation, instrumented to record its log-probability.
generator = continuation(model, max_tokens=50)
generator = make_greedy_tracker(generator)

# Generation constrained to one of two colour patterns, instrumented likewise.
color_patterns = ["[Bb]lue", "[Rr]ed"]
choice_generator = choice(model, color_patterns, max_tokens=50)
choice_generator = make_greedy_tracker(choice_generator)

prompt = "Which color do you prefer: blue or red?"

# --- Free-form continuation of the prompt ---
sequence = generator(prompt)
print(sequence)
#
#
# The answer is blue.
#
# The color of the car is the color of the car.
#
# The color of the car is the color of the car.
#
# The color of the car is the color of the car.

print(generator.last_sequence_log_prob)
# tensor(-44.7725)

# --- Same question with the two colours swapped ---
sequence = generator("Which color do you prefer: red or blue?")
print(sequence)
# The answer is: red.
#
# The red color is the color of the color of the car. It's the color of the car that's the most important.
#
# The red color is the color of the car that's the

print(generator.last_sequence_log_prob)
# tensor(-69.0348)

# --- Constrained choice for the original prompt ---
sequence = choice_generator(prompt)
print(sequence)
# Blue

print(choice_generator.last_sequence_log_prob)
# tensor(-0.9262)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment