Softmax calculation in Python in Japanese
import torch
from transformers import top_k_top_p_filtering

logits = torch.tensor([2.0, 1.8, 1.5, 1.0, 0.5])
# Candidate words (Japanese): "do one's best", "grow", "learn", "health", "new"
words = ["頑張る", "成長する", "学ぶ", "健康", "新しい"]

def get_probabilities_hf(temperature, k=None, p=None):
    scaled_scores = logits / temperature

    # When k or p is None, fall back to values that disable filtering
    if k is None:
        k = 0    # top_k=0 means no Top-K filtering
    if p is None:
        p = 1.0  # top_p=1.0 means no Top-P filtering

    # Add a batch dimension, since the filter expects 2-D logits
    scaled_scores = scaled_scores.unsqueeze(0)  # shape: [1, vocab_size]

    # Filter the logits. Note: top_k_top_p_filtering has been removed from
    # recent transformers releases, so this import needs an older version.
    filtered_logits = top_k_top_p_filtering(
        scaled_scores, top_k=k, top_p=p, filter_value=float('-inf')
    )

    # Drop the batch dimension again
    filtered_logits = filtered_logits.squeeze(0)  # shape: [vocab_size]

    probabilities = torch.softmax(filtered_logits, dim=0)

    # Print each word with its probability and a simple text bar chart
    for word, prob in zip(words, probabilities):
        bar_length = int(prob.item() * 20)
        bar = "#" * bar_length
        print(f"{word:<10} ({prob.item() * 100:.1f}%) {bar}")
    print("\n")

# Test runs
print("Regular sampling:")
get_probabilities_hf(temperature=1.0)

print("Top-K sampling (K=3):")
get_probabilities_hf(temperature=1.0, k=3)

print("Top-P sampling (P=0.5):")
get_probabilities_hf(temperature=1.0, p=0.5)
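
Since top_k_top_p_filtering is gone from current transformers releases, here is a minimal alternative sketch of the same idea using the library's TopKLogitsWarper and TopPLogitsWarper classes, the building blocks model.generate() applies during sampling. This assumes a recent transformers version that exports the warpers at the top level; get_probabilities_warpers and dummy_ids are names introduced here for illustration, not part of the original gist.

import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

logits = torch.tensor([2.0, 1.8, 1.5, 1.0, 0.5])
words = ["頑張る", "成長する", "学ぶ", "健康", "新しい"]

def get_probabilities_warpers(temperature, k=None, p=None):
    # The warpers expect scores shaped [batch, vocab_size]
    scores = (logits / temperature).unsqueeze(0)

    # input_ids is part of the warper call signature but is not used by
    # these two warpers, so a dummy batch of token ids is sufficient
    dummy_ids = torch.zeros((1, 1), dtype=torch.long)

    if k is not None:
        scores = TopKLogitsWarper(top_k=k)(dummy_ids, scores)
    if p is not None:
        scores = TopPLogitsWarper(top_p=p)(dummy_ids, scores)

    probabilities = torch.softmax(scores.squeeze(0), dim=0)
    for word, prob in zip(words, probabilities):
        bar = "#" * int(prob.item() * 20)
        print(f"{word:<10} ({prob.item() * 100:.1f}%) {bar}")
    print()

print("Top-K sampling (K=3):")
get_probabilities_warpers(temperature=1.0, k=3)
print("Top-P sampling (P=0.5):")
get_probabilities_warpers(temperature=1.0, p=0.5)

Chaining the warpers this way mirrors how a LogitsProcessorList applies them in order during generation; filtered entries become -inf and therefore 0 after softmax, just as in the gist above.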