Last active
June 4, 2023 12:19
-
-
Save mlabonne/d6f97827b7369f133a43adb32e1759e0 to your computer and use it in GitHub Desktop.
Shows how the temperature parameter affects the probabilities produced by the softmax function.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import matplotlib.pyplot as plt | |
def softmax(x, temperature=1.0, axis=0):
    """Return the temperature-scaled softmax of ``x`` along ``axis``.

    Each logit is divided by ``temperature`` before exponentiation, so
    temperatures below 1 sharpen the distribution toward the largest logit
    and temperatures above 1 flatten it toward uniform.

    Args:
        x: Array-like of logits.
        temperature: Strictly positive scaling factor (default 1.0).
        axis: Axis along which the probabilities sum to 1 (default 0, which
            matches the previous behaviour).

    Returns:
        An ndarray of probabilities with the same shape as ``x``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    z = np.asarray(x) / temperature
    if z.size == 0:
        # Nothing to normalise; max() below would fail on an empty array.
        return z
    # Subtracting the max leaves the result mathematically unchanged but keeps
    # np.exp from overflowing for large logits or very small temperatures.
    z = z - z.max(axis=axis, keepdims=True)
    e_z = np.exp(z)
    return e_z / e_z.sum(axis=axis, keepdims=True)
# Logits for four classes, and the temperatures to compare side by side.
logits = np.array([1.5, -1.8, 0.9, -3.2])
temperatures = [1.0, 0.5, 0.1]
class_labels = ["A", "B", "C", "D"]  # one label per logit

# Set up the matplotlib figure: one 5-inch-wide panel per temperature.
# squeeze=False keeps `axes` 2-D even if only one temperature is listed.
fig, axes = plt.subplots(
    1, len(temperatures), figsize=(5 * len(temperatures), 5), squeeze=False
)

positions = np.arange(len(logits))
for ax, temp in zip(axes[0], temperatures):
    # Compute probabilities
    probabilities = softmax(logits, temp)

    # Bar chart
    ax.bar(positions, probabilities, tick_label=class_labels)

    # Write the values on top of each bar
    for pos, prob in zip(positions, probabilities):
        ax.text(pos, prob, f"{prob:.2f}", ha='center', va='bottom')

    # Set title
    ax.set_title(f'Temperature = {temp}')

    # Set x and y labels
    ax.set_xlabel('Class')
    ax.set_ylabel('Probability')

    # Same y-scale on every panel, with headroom above 1.0 so the label of a
    # near-certain bar is not clipped by the axes.
    ax.set_ylim(0, 1.1)

    # Remove upper and right borders
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment