Skip to content

Instantly share code, notes, and snippets.

@takahirom
Created November 9, 2024 13:23
Show Gist options
  • Save takahirom/448d2c1e3abb0f76a9bbe4b2983052e7 to your computer and use it in GitHub Desktop.
Softmax calculation in Python in Japanese
import torch
from transformers import top_k_top_p_filtering
# Raw model scores (logits) for a toy 5-word vocabulary, and the matching words.
logits = torch.tensor([2.0, 1.8, 1.5, 1.0, 0.5])
words = ["頑張る", "成長する", "学ぶ", "健康", "新しい"]


def get_probabilities_hf(temperature, k=None, p=None):
    """Apply temperature scaling plus optional Top-K / Top-P (nucleus)
    filtering to the module-level ``logits``, print a bar chart of the
    resulting distribution, and return the probability tensor.

    Args:
        temperature: Softmax temperature; lower values sharpen the distribution.
        k: Keep only the ``k`` highest-scoring tokens. ``None`` or ``0``
           disables Top-K filtering (matching the original defaults).
        p: Keep the smallest set of tokens whose cumulative probability
           exceeds ``p``. ``None`` or ``1.0`` disables Top-P filtering.

    Returns:
        A 1-D tensor of probabilities over ``words``; filtered-out entries
        are exactly 0.

    NOTE: ``transformers.top_k_top_p_filtering`` was deprecated and then
    removed from the transformers library (v4.26), so the filtering is
    implemented directly with torch here, reproducing that function's
    semantics for the 1-D case.
    """
    scaled_scores = logits / temperature
    filter_value = float("-inf")

    # --- Top-K: mask out everything strictly below the k-th largest logit ---
    if k is not None and k > 0:
        kth_value = torch.topk(scaled_scores, min(k, scaled_scores.size(-1))).values[-1]
        scaled_scores = scaled_scores.masked_fill(scaled_scores < kth_value, filter_value)

    # --- Top-P (nucleus): mask tokens outside the smallest set whose
    # cumulative probability exceeds p ---
    if p is not None and p < 1.0:
        sorted_logits, sorted_indices = torch.sort(scaled_scores, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > p
        # Shift right so the first token that crosses the threshold is still
        # kept (same behavior as the removed Hugging Face helper).
        sorted_remove[1:] = sorted_remove[:-1].clone()
        sorted_remove[0] = False
        # Scatter the removal mask back to the original token order.
        remove_mask = torch.zeros_like(sorted_remove)
        remove_mask[sorted_indices] = sorted_remove
        scaled_scores = scaled_scores.masked_fill(remove_mask, filter_value)

    probabilities = torch.softmax(scaled_scores, dim=0)

    # Display each word with its probability and a proportional bar
    # (20 chars == 100%).
    for word, prob in zip(words, probabilities):
        bar_length = int(prob.item() * 20)
        bar = "#" * bar_length
        print(f"{word:<10} ({prob.item() * 100:.1f}%) {bar}")
    print("\n")
    return probabilities
# Demo run: compare plain sampling with Top-K and Top-P (nucleus) sampling.
demo_cases = [
    ("通常のサンプリング:", {}),
    ("Top-Kサンプリング (K=3):", {"k": 3}),
    ("Top-Pサンプリング (P=0.5):", {"p": 0.5}),
]
for heading, extra_kwargs in demo_cases:
    print(heading)
    get_probabilities_hf(temperature=1.0, **extra_kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment