Skip to content

Instantly share code, notes, and snippets.

@mlabonne
Last active June 4, 2023 12:19
Show Gist options
  • Save mlabonne/d6f97827b7369f133a43adb32e1759e0 to your computer and use it in GitHub Desktop.
Save mlabonne/d6f97827b7369f133a43adb32e1759e0 to your computer and use it in GitHub Desktop.
Shows how temperature affects the probability distribution produced by the softmax function.
import numpy as np
import matplotlib.pyplot as plt
def softmax(x, temperature=1.0):
    """Return the temperature-scaled softmax of ``x`` along axis 0.

    Each entry is ``exp(x_i / T) / sum_j exp(x_j / T)``. A temperature below
    1 sharpens the distribution toward the largest logit; a temperature above
    1 flattens it toward uniform.

    Args:
        x: Logits (array-like). Normalisation is done along axis 0, so a 1-D
            input yields a single distribution and a 2-D input yields one
            distribution per column.
        temperature: Strictly positive scaling factor ``T``. Defaults to 1.0.

    Returns:
        A float ``ndarray`` with the same shape as ``x`` whose entries along
        axis 0 sum to 1.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    scaled = np.asarray(x, dtype=float) / temperature
    if scaled.size == 0:
        # Nothing to normalise; max() below would fail on an empty array.
        return scaled
    # Subtracting the per-column max cancels out in the ratio, so the result
    # is unchanged, but it keeps exp() from overflowing to inf for large
    # logits or small temperatures (inf / inf would give nan).
    scaled = scaled - scaled.max(axis=0, keepdims=True)
    e_x = np.exp(scaled)
    return e_x / e_x.sum(axis=0, keepdims=True)
logits = np.array([1.5, -1.8, 0.9, -3.2])
temperatures = [1.0, 0.5, 0.1]
class_labels = ["A", "B", "C", "D"]

# Set up the matplotlib figure: one panel per temperature. squeeze=False keeps
# ``axes`` 2-D even for a single temperature, so the loop below works for any
# length of ``temperatures``.
fig, axes = plt.subplots(
    1, len(temperatures), figsize=(5 * len(temperatures), 5), squeeze=False
)

for ax, temp in zip(axes[0], temperatures):
    # Compute probabilities for this temperature
    probabilities = softmax(logits, temp)
    positions = range(len(logits))

    # Bar chart, with each probability written on top of its bar
    ax.bar(positions, probabilities, tick_label=class_labels)
    for pos, prob in zip(positions, probabilities):
        ax.text(pos, prob, f"{prob:.2f}", ha="center", va="bottom")

    ax.set_title(f"Temperature = {temp}")
    ax.set_xlabel("Class")
    ax.set_ylabel("Probability")

    # Remove upper and right borders
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment