Created November 2, 2024 10:02
Save vndee/3f1eef379bb2f8882532c0e2ac8ade45 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import matplotlib.pyplot as plt | |
import numpy as np | |
def _softmax_with_temperature(logits, temperature):
    """Return softmax(logits / temperature) as a 1-D float array.

    The maximum scaled logit is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps small temperatures (which
    produce very large scaled logits) from overflowing ``exp`` to inf/nan.

    Raises:
        ValueError: if ``temperature`` is not strictly positive (this also
            rejects NaN) or if ``logits`` is empty.
    """
    # `not > 0` instead of `<= 0` so that NaN is rejected too.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    scaled = np.asarray(logits, dtype=float) / temperature
    if scaled.size == 0:
        raise ValueError("logits must be non-empty")
    exps = np.exp(scaled - scaled.max())
    return exps / exps.sum()


def plot_temperature_effects(logits, temperatures):
    """Plot how temperature scaling reshapes a softmax distribution over tokens.

    One line is drawn per temperature T, showing softmax(logits / T) against
    the token index. Lower T concentrates probability on the largest logit;
    higher T spreads it more evenly.

    Args:
        logits: Sequence of raw scores, one per token.
        temperatures: Iterable of strictly positive temperatures to compare.

    Raises:
        ValueError: if any temperature is not > 0, or if ``logits`` is empty.
    """
    x = np.arange(len(logits))
    # Validate and compute every curve before opening a figure, so invalid
    # input doesn't leave an empty figure behind.
    curves = [(temp, _softmax_with_temperature(logits, temp)) for temp in temperatures]

    plt.figure(figsize=(12, 6))
    for temp, probs in curves:
        plt.plot(x, probs, label=f'T={temp}', marker='o')
    plt.xlabel('Token Index')
    plt.ylabel('Probability')
    plt.title('Effect of Temperature on Token Probabilities')
    plt.legend()
    plt.grid(True)
    plt.show()
# Example usage: compare the same logits under progressively higher
# temperatures (low T sharpens the distribution, high T flattens it).
temperatures = [0.1, 0.5, 1.0, 2.0]
logits = [5.0, 2.0, 1.0, 0.0, -1.0]  # example logits, one per token
plot_temperature_effects(logits=logits, temperatures=temperatures)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.