Created
November 9, 2024 13:21
-
-
Save takahirom/6ec47fb00ab3fcc8289f281a3fc912cc to your computer and use it in GitHub Desktop.
Softmax calculation in Kotlin (comments and word labels in Japanese)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Example raw scores (logits) for the candidate words below, listed from highest to lowest. */
val logits: Array<Double> = arrayOf(2.0, 1.8, 1.5, 1.0, 0.5)

/**
 * Candidate words; `words[i]` is the label for `logits[i]`.
 * (Glosses: "do one's best", "grow", "learn", "health", "new".)
 */
val words: Array<String> = arrayOf("頑張る", "成長する", "学ぶ", "健康", "新しい")
/**
 * Turns the script-level [logits] into probabilities and prints one line per entry of [words].
 * Each line shows the word, its probability as a percentage, and a bar of `#` characters.
 *
 * Processing order: temperature scaling, then optional top-k, then optional top-p, then softmax.
 *
 * @param temperature divisor applied to every logit; must be strictly positive.
 *   Values above 1 flatten the distribution and values below 1 sharpen it.
 * @param k when non-null and > 0, top-k filtering keeps only the k highest scores.
 *   `null` or a non-positive value disables top-k.
 * @param p when non-null and < 1.0, top-p filtering is applied with cutoff p.
 *   `null` or a value >= 1.0 disables top-p.
 * @return the final probabilities, in the same order as [logits] and [words].
 * @throws IllegalArgumentException if [temperature] is zero, negative, or NaN.
 */
fun getProbabilitiesHf(temperature: Double, k: Int? = null, p: Double? = null): DoubleArray {
    // Dividing by 0, a negative number, or NaN yields Infinity/NaN scores.
    // Softmax would then silently print NaN for every word, so reject such values up front.
    require(temperature > 0.0) { "temperature must be strictly positive, got $temperature" }

    val topK = k ?: 0   // null -> 0, which disables top-k filtering
    val topP = p ?: 1.0 // null -> 1.0, which disables top-p filtering

    // Temperature scaling.
    val scaledScores = logits.map { it / temperature }.toDoubleArray()
    println("scaledScores:" + scaledScores.joinToString())

    // Top-k runs first; top-p then operates on the already top-k-filtered scores.
    val topKFilteredScores = if (topK > 0) topKFiltering(scaledScores, topK) else scaledScores
    val topPFilteredScores = if (topP < 1.0) topPFiltering(topKFilteredScores, topP) else topKFilteredScores

    val probabilities = softmax(topPFilteredScores)

    // One line per word: label, percentage, and a bar of (prob * 20) '#' characters.
    probabilities.forEachIndexed { i, prob ->
        val bar = "#".repeat((prob * 20).toInt())
        println(String.format("%-10s (%.1f%%) %s", words[i], prob * 100, bar))
    }
    println()
    return probabilities
}
/**
 * Converts raw scores into a probability distribution: `exp(x_i) / sum_j exp(x_j)`.
 *
 * The maximum score is subtracted before exponentiating, so the largest term is `exp(0) = 1`.
 * This keeps `exp` from overflowing on large scores without changing the mathematical result.
 * When the maximum is finite, entries equal to negative infinity get probability 0 and the
 * outputs sum to 1 (up to rounding). An empty input yields an empty output.
 *
 * @param logits raw scores; not modified.
 * @return a new array of the same size holding the normalized probabilities.
 */
fun softmax(logits: DoubleArray): DoubleArray {
    val shift = logits.maxOrNull() ?: 0.0
    val weights = DoubleArray(logits.size) { i -> Math.exp(logits[i] - shift) }
    val total = weights.sum()
    for (i in weights.indices) {
        weights[i] = weights[i] / total
    }
    return weights
}
/**
 * Top-k filtering: masks with negative infinity every score strictly below the k-th largest score.
 *
 * Scores tied with the k-th largest are kept, so more than [k] entries may survive.
 * The input order is preserved.
 *
 * @param logits scores to filter; not modified.
 * @param k how many top scores to keep. When `k <= 0` the input array itself is returned.
 *   When `k` exceeds the array size, nothing is masked.
 * @return a new array in which filtered-out entries are [Double.NEGATIVE_INFINITY].
 */
fun topKFiltering(logits: DoubleArray, k: Int): DoubleArray {
    if (k <= 0) return logits
    // k-th largest value (index k - 1 in descending order).
    // Negative infinity when k > size, which means "keep everything".
    val kthLargest = logits.sortedArrayDescending().getOrElse(k - 1) { Double.NEGATIVE_INFINITY }
    return DoubleArray(logits.size) { i ->
        val score = logits[i]
        if (score < kthLargest) Double.NEGATIVE_INFINITY else score
    }
}
/**
 * Nucleus (top-p) filtering: keeps the smallest run of highest-scoring entries whose cumulative
 * softmax probability reaches [p], and masks every other entry with negative infinity.
 *
 * Entries tied with the last kept score are also kept. If the cumulative probability never
 * reaches [p], nothing is masked. The input order is preserved.
 *
 * @param logits scores to filter; not modified.
 * @param p cumulative-probability cutoff. When `p >= 1.0` the input array itself is returned.
 * @return a new array in which filtered-out entries are [Double.NEGATIVE_INFINITY].
 */
fun topPFiltering(logits: DoubleArray, p: Double): DoubleArray {
    if (p >= 1.0) return logits

    val descending = logits.sortedArrayDescending()
    val descendingProbs = softmax(descending)

    // Running total of probability mass over the descending-sorted entries.
    var mass = 0.0
    val cumulative = descendingProbs.map { prob ->
        mass += prob
        mass
    }

    // Keep entries up to and including the first position where the mass reaches p.
    // If it never does, keep every entry.
    val firstReaching = cumulative.indexOfFirst { it >= p }
    val keepCount = if (firstReaching < 0) descending.size else firstReaching + 1

    val cutoff = descending.getOrElse(keepCount - 1) { Double.NEGATIVE_INFINITY }
    return DoubleArray(logits.size) { i ->
        val score = logits[i]
        if (score < cutoff) Double.NEGATIVE_INFINITY else score
    }
}
// Show the raw logit for every candidate word.
for ((index, score) in logits.withIndex()) {
    println(words[index] + " : " + score)
}

/** Prints [title], then runs [demo], which prints one probability table. */
fun section(title: String, demo: () -> Unit) {
    println(title)
    demo()
}

section("通常のサンプリング:") { getProbabilitiesHf(temperature = 1.0) }
section("通常のサンプリング(temperature = 2.0):") { getProbabilitiesHf(temperature = 2.0) }
section("通常のサンプリング(temperature = 0.1):") { getProbabilitiesHf(temperature = 0.1) }
section("Top-Kサンプリング (K=3):") { getProbabilitiesHf(temperature = 1.0, k = 3) }
section("Top-Pサンプリング (P=0.5):") { getProbabilitiesHf(temperature = 1.0, p = 0.5) }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment