Skip to content

Instantly share code, notes, and snippets.

@vndee
Created November 2, 2024 10:02
Show Gist options
  • Save vndee/3f1eef379bb2f8882532c0e2ac8ade45 to your computer and use it in GitHub Desktop.
Save vndee/3f1eef379bb2f8882532c0e2ac8ade45 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
def softmax(logits):
    """Return the softmax of *logits* as a 1-D NumPy array of probabilities.

    The maximum logit is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing to
    ``inf`` (and the ratio to ``nan``) when the logits are large, e.g. after
    dividing by a small temperature such as 0.1.

    Args:
        logits: Sequence of real-valued scores.

    Returns:
        ``np.ndarray`` of floats, same length as *logits*, non-negative and
        summing to 1.

    Raises:
        ValueError: If *logits* is empty.
    """
    z = np.asarray(logits, dtype=float)
    if z.size == 0:
        raise ValueError("softmax requires at least one logit")
    exps = np.exp(z - z.max())
    return exps / exps.sum()


def plot_temperature_effects(logits, temperatures):
    """Visualize how different temperatures affect the probability distribution.

    For each temperature ``T`` the logits are rescaled to ``logits / T`` and
    passed through :func:`softmax`. Values of ``T < 1`` sharpen the
    distribution toward the largest logit; values of ``T > 1`` flatten it
    toward uniform. One line per temperature is drawn against the token
    index, and the figure is displayed with ``plt.show()``.

    Args:
        logits: Sequence of raw scores, one per token.
        temperatures: Iterable of strictly positive temperatures; each one
            becomes a separate labelled line.

    Raises:
        ValueError: If *logits* is empty or any temperature is not > 0.
    """
    scores = np.asarray(logits, dtype=float)
    if scores.size == 0:
        raise ValueError("logits must contain at least one value")

    temps = list(temperatures)
    # T == 0 is a division by zero and T < 0 inverts the token ranking.
    # ``not t > 0`` also rejects NaN. Validate before opening a figure.
    bad = [t for t in temps if not t > 0]
    if bad:
        raise ValueError(f"temperatures must be > 0, got {bad}")

    plt.figure(figsize=(12, 6))
    x = np.arange(scores.size)
    for temp in temps:
        probs = softmax(scores / temp)
        plt.plot(x, probs, label=f'T={temp}', marker='o')

    plt.xlabel('Token Index')
    plt.ylabel('Probability')
    plt.title('Effect of Temperature on Token Probabilities')
    plt.legend()
    plt.grid(True)
    plt.show()
# Demo: five token scores with one clearly dominant logit, swept across
# temperatures below 1 (sharper), equal to 1 (plain softmax) and above 1 (flatter).
logits, temperatures = (
    [5.0, 2.0, 1.0, 0.0, -1.0],  # raw score per token
    [0.1, 0.5, 1.0, 2.0],        # one plotted line per temperature
)
plot_temperature_effects(logits=logits, temperatures=temperatures)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment