Skip to content

Instantly share code, notes, and snippets.

@takahirom
Created November 9, 2024 13:21
Show Gist options
  • Save takahirom/6ec47fb00ab3fcc8289f281a3fc912cc to your computer and use it in GitHub Desktop.
Save takahirom/6ec47fb00ab3fcc8289f281a3fc912cc to your computer and use it in GitHub Desktop.
Softmax calculation in Kotlin in Japanese
// Raw, unnormalized scores for each candidate word. logits[i] is the score of words[i];
// the printing code indexes words by logits position, so keep both arrays aligned.
val logits: Array<Double> = arrayOf(
    2.0,
    1.8,
    1.5,
    1.0,
    0.5,
)
val words: Array<String> = arrayOf(
    "頑張る",
    "成長する",
    "学ぶ",
    "健康",
    "新しい",
)
/**
 * Prints the probability of each entry in the top-level `words` array after processing the
 * top-level `logits`. The steps run in this order: temperature scaling, optional Top-K filtering,
 * optional Top-P filtering, then softmax.
 *
 * The temperature-scaled scores are printed first. Then each word is printed on its own line,
 * with its probability as a percentage and a bar of `#` characters (20 for probability 1.0).
 * A blank line follows the output.
 *
 * @param temperature divisor applied to every logit; must be positive. Values above 1.0 flatten
 *   the distribution and values below 1.0 sharpen it.
 * @param k keep only the scores at least as large as the k-th largest (ties at the cut-off
 *   survive). `null` or a value <= 0 disables Top-K filtering.
 * @param p keep the highest scores whose cumulative probability first reaches p.
 *   `null` or a value >= 1.0 disables Top-P filtering.
 * @throws IllegalArgumentException if [temperature] is zero, negative, or NaN.
 */
fun getProbabilitiesHf(temperature: Double, k: Int? = null, p: Double? = null) {
    // A zero temperature divides by zero (softmax then computes Inf - Inf = NaN), and a negative
    // one inverts the ranking of the words. NaN also fails this comparison.
    require(temperature > 0.0) { "temperature must be positive, but was $temperature" }

    // Scale the scores by the temperature.
    val scaledScores = logits.map { it / temperature }.toDoubleArray()
    println("scaledScores:" + scaledScores.joinToString())

    // Top-K filtering; a missing or non-positive k means "no filtering".
    val topKFilteredScores = if (k != null && k > 0) topKFiltering(scaledScores, k) else scaledScores

    // Top-P filtering; a missing p or p >= 1.0 means "no filtering".
    val topPFilteredScores =
        if (p != null && p < 1.0) topPFiltering(topKFilteredScores, p) else topKFilteredScores

    // Turn the filtered scores into probabilities (filtered entries become 0).
    val probabilities = softmax(topPFilteredScores)

    // One line per word: label, percentage, and a proportional bar.
    for (i in probabilities.indices) {
        val prob = probabilities[i]
        val bar = "#".repeat((prob * 20).toInt())
        // Locale.ROOT keeps '.' as the decimal separator whatever the JVM's default locale is.
        println(String.format(java.util.Locale.ROOT, "%-10s (%.1f%%) %s", words[i], prob * 100, bar))
    }
    println()
}
/**
 * Converts raw scores into probabilities that sum to 1, preserving the input order.
 *
 * The largest score is subtracted from every entry before exponentiating, which avoids
 * overflow for large scores and cancels out during normalization. Entries equal to
 * -Infinity receive probability 0. An empty input produces an empty result.
 */
fun softmax(logits: DoubleArray): DoubleArray {
    val peak = logits.maxOrNull() ?: 0.0
    val weights = DoubleArray(logits.size) { idx -> Math.exp(logits[idx] - peak) }
    val total = weights.sum()
    return DoubleArray(weights.size) { idx -> weights[idx] / total }
}
/**
 * Top-K filtering: keeps every score that is at least the k-th largest and replaces the
 * others with -Infinity, preserving the original order.
 *
 * The cut is made by value, so entries tied with the k-th largest score all survive and more
 * than [k] entries can remain. Returns [logits] itself when [k] <= 0. When [k] exceeds the
 * array size, nothing is filtered but a new array is still returned.
 */
fun topKFiltering(logits: DoubleArray, k: Int): DoubleArray {
    if (k <= 0) return logits
    val descending = logits.sortedArrayDescending()
    // The k-th largest value is the cut line; with fewer than k entries nothing is cut.
    val cutoff = if (k <= descending.size) descending[k - 1] else Double.NEGATIVE_INFINITY
    return DoubleArray(logits.size) { idx ->
        val score = logits[idx]
        if (score < cutoff) Double.NEGATIVE_INFINITY else score
    }
}
/**
 * Top-P (nucleus) filtering. Ranks the scores from highest to lowest and sums their softmax
 * probabilities in that order. Every score ranked at or above the first point where the sum
 * reaches [p] is kept; all other scores are replaced with -Infinity. The original order is
 * preserved.
 *
 * The cut is made by value, so entries tied with the last kept score also survive.
 * Returns [logits] itself when [p] >= 1.0. If the running total never reaches [p]
 * (for example because of rounding), nothing is filtered.
 */
fun topPFiltering(logits: DoubleArray, p: Double): DoubleArray {
    if (p >= 1.0) return logits
    val ranked = logits.sortedArrayDescending()
    // Running totals in descending-score order: [p0, p0 + p1, p0 + p1 + p2, ...].
    val cumulative = softmax(ranked).runningReduce { acc, prob -> acc + prob }
    val firstReaching = cumulative.indexOfFirst { it >= p }
    val keepCount = if (firstReaching == -1) ranked.size else firstReaching + 1
    val cutoff = ranked.getOrNull(keepCount - 1) ?: Double.NEGATIVE_INFINITY
    return DoubleArray(logits.size) { idx ->
        val score = logits[idx]
        if (score < cutoff) Double.NEGATIVE_INFINITY else score
    }
}
// Show each candidate word next to its raw logit.
for ((i, score) in logits.withIndex()) {
    println("${words[i]} : $score")
}

// Run each sampling configuration in turn, printing its heading first.
listOf<Pair<String, () -> Unit>>(
    "通常のサンプリング:" to { getProbabilitiesHf(temperature = 1.0) },
    "通常のサンプリング(temperature = 2.0):" to { getProbabilitiesHf(temperature = 2.0) },
    "通常のサンプリング(temperature = 0.1):" to { getProbabilitiesHf(temperature = 0.1) },
    "Top-Kサンプリング (K=3):" to { getProbabilitiesHf(temperature = 1.0, k = 3) },
    "Top-Pサンプリング (P=0.5):" to { getProbabilitiesHf(temperature = 1.0, p = 0.5) },
).forEach { (heading, runDemo) ->
    println(heading)
    runDemo()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment