Skip to content

Instantly share code, notes, and snippets.

@ppsdatta
Created April 6, 2023 15:25
Show Gist options
  • Save ppsdatta/bd38e7f0507b9c4ff1a7db8be544a94c to your computer and use it in GitHub Desktop.
Save ppsdatta/bd38e7f0507b9c4ff1a7db8be544a94c to your computer and use it in GitHub Desktop.
A simple Perceptron in Scala
import java.util.Random
/** A single-layer perceptron with a step activation, trained with the
  * perceptron learning rule.
  *
  * Training never modifies the `weights` array given to the constructor; it
  * returns a new weight vector that the caller passes to `predict`.
  *
  * @param weights      initial weight vector: one entry per input feature plus
  *                     one trailing entry that multiplies the bias input
  * @param learningRate step size applied to every weight correction
  * @param bias         constant appended to each input row as the bias input
  */
class Perceptron(weights: Array[Double], learningRate: Double = 0.01, bias: Double = 1.0) {

  /** Trains on raw feature rows with the default epoch count of `fitWithBias`.
    *
    * @param input  feature rows, without the bias input
    * @param labels expected output (0 or 1) for each row of `input`
    * @return the trained weight vector
    */
  def fit(input: Array[Array[Double]], labels: Array[Double]): Array[Double] =
    fitWithBias(input, labels)

  /** Appends the bias input to every row of `data`, then trains for `epochs` passes.
    *
    * @param data   feature rows, without the bias input
    * @param labels expected output for each row of `data`
    * @param epochs number of full passes over the data
    * @return the trained weight vector
    */
  def fitWithBias(data: Array[Array[Double]], labels: Array[Double], epochs: Int = 1000): Array[Double] =
    train(addBias(data), labels, epochs)

  /** Runs `epochs` training passes, starting from the constructor weights.
    *
    * @param data   rows that already contain the bias input
    * @param labels expected output for each row of `data`
    * @param epochs number of passes; zero or fewer returns the initial weights
    * @return the weight vector after the last pass
    */
  def train(data: Array[Array[Double]], labels: Array[Double], epochs: Int): Array[Double] =
    (0 until epochs).foldLeft(weights)((current, _) => trainIteration(current, data, labels))

  /** Performs one pass over the data, correcting the weights after every row.
    *
    * @param weights weight vector at the start of the pass
    * @param data    rows that already contain the bias input
    * @param labels  expected output for each row of `data`
    * @return the weight vector after the pass
    * @throws IllegalArgumentException if `data` and `labels` differ in length,
    *                                  or a row's length differs from the weights'
    */
  def trainIteration(weights: Array[Double], data: Array[Array[Double]], labels: Array[Double]): Array[Double] = {
    require(
      data.length == labels.length,
      s"expected one label per row, got ${data.length} rows and ${labels.length} labels"
    )
    data.zip(labels).foldLeft(weights) { case (current, (row, label)) =>
      // Predict with the weights as corrected so far in this pass, not with the
      // start-of-pass vector; otherwise earlier corrections are ignored by later rows.
      val error = label - activate(dotProduct(row, current))
      current.zip(row).map { case (w, x) => w + x * learningRate * error }
    }
  }

  /** Step activation: 1.0 for non-negative input, 0.0 otherwise. */
  def activate(x: Double): Double =
    if (x >= 0) 1.0 else 0.0

  /** Dot product of two vectors.
    *
    * @throws IllegalArgumentException if the vectors differ in length
    */
  def dotProduct(v1: Array[Double], v2: Array[Double]): Double = {
    // zip would silently drop the surplus entries (e.g. the bias weight).
    require(
      v1.length == v2.length,
      s"dotProduct needs equal-length vectors, got ${v1.length} and ${v2.length}"
    )
    v1.zip(v2).map { case (a, b) => a * b }.sum
  }

  /** Appends the bias input to every row of `data`. */
  def addBias(data: Array[Array[Double]]): Array[Array[Double]] =
    data.map(addRowBias)

  /** Returns a copy of `input` with the bias input appended as its last element. */
  def addRowBias(input: Array[Double]): Array[Double] =
    input :+ bias

  /** Classifies a single raw feature row.
    *
    * @param weights trained weight vector, including the trailing bias weight
    * @param input   feature row, without the bias input
    * @return 1.0 or 0.0
    */
  def predict(weights: Array[Double], input: Array[Double]): Double =
    activate(dotProduct(weights, addRowBias(input)))
}
// Shared random source for weight initialisation.
val rnd = new Random()

/** Builds `n` initial weights, each drawn from `rnd.nextDouble()`.
  *
  * The no-argument `nextDouble()` is used instead of `nextDouble(0, 1)`:
  * the two-argument overload only exists from JDK 17, while the no-argument
  * form covers the same [0, 1) range on every JDK.
  *
  * @param n number of weights to generate; zero or fewer yields an empty array
  * @return an array of `n` random weights
  */
def randomWeights(n: Int): Array[Double] =
  Array.fill(n)(rnd.nextDouble())
// Truth-table inputs: each row is one (x1, x2) pair.
val input: Array[Array[Double]] = Array(
  Array(1.0, 0.0),
  Array(1.0, 1.0),
  Array(0.0, 1.0),
  Array(0.0, 0.0)
)

// Expected output for each row of `input`; swap in a commented line to try another gate.
val labels: Array[Double] = Array(0.0, 1.0, 0.0, 0.0) // AND
//val labels: Array[Double] = Array(1, 1, 1, 0) // OR
//val labels: Array[Double] = Array(1, 0, 1, 0) // XOR

// Three random starting weights for the two-column input rows.
val weights = randomWeights(3)
val p = new Perceptron(weights)
val ws = p.fit(input, labels)

// Query the trained weights for every input combination.
p.predict(ws, Array(0.0, 1.0))
p.predict(ws, Array(1.0, 1.0))
p.predict(ws, Array(1.0, 0.0))
p.predict(ws, Array(0.0, 0.0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment