Created
April 6, 2023 15:25
-
-
Save ppsdatta/bd38e7f0507b9c4ff1a7db8be544a94c to your computer and use it in GitHub Desktop.
A simple Perceptron in Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Random | |
/** A single-layer perceptron binary classifier with a step activation.
  *
  * Every input row is extended with a constant bias feature (see `addRowBias`), so the
  * weight vector holds one weight per feature plus a final weight for the bias term.
  *
  * @param weights      initial weights (features + 1); the array itself is never mutated
  * @param learningRate step size applied to every weight update
  * @param bias         constant appended to each input row as the bias feature
  */
class Perceptron(weights: Array[Double], learningRate: Double = 0.01, bias: Double = 1.0) {

  /** Trains on rows without a bias column and returns the learned weights.
    *
    * @param input  feature rows, each with `weights.length - 1` values
    * @param labels target per row; the step activation only ever predicts 0.0 or 1.0
    * @param epochs number of full passes over the data
    * @return trained weights, with the bias weight last
    * @throws IllegalArgumentException if row/label counts or row widths do not match
    */
  def fit(input: Array[Array[Double]], labels: Array[Double], epochs: Int = 1000): Array[Double] =
    fitWithBias(input, labels, epochs)

  /** Appends the bias feature to every row of `data`, then trains for `epochs` passes. */
  def fitWithBias(data: Array[Array[Double]], labels: Array[Double], epochs: Int = 1000): Array[Double] =
    train(addBias(data), labels, epochs)

  /** Runs `epochs` passes of `trainIteration`, starting from the constructor weights.
    *
    * `data` must already include the bias column. A non-positive `epochs` returns the
    * initial weights unchanged.
    */
  def train(data: Array[Array[Double]], labels: Array[Double], epochs: Int): Array[Double] =
    (0 until epochs).foldLeft(weights)((current, _) => trainIteration(current, data, labels))

  /** One pass of the online perceptron rule over `data` (bias column included).
    *
    * Rows are visited in order, and each prediction uses the weights produced by the
    * previous row's update: `w_i += learningRate * (label - prediction) * x_i`.
    *
    * @param weights weights at the start of the pass (not mutated)
    * @return a new weight array after updating on every row
    * @throws IllegalArgumentException if row/label counts or row widths do not match
    */
  def trainIteration(weights: Array[Double], data: Array[Array[Double]], labels: Array[Double]): Array[Double] = {
    requireShapes(weights, data, labels)
    data.zip(labels).foldLeft(weights) { case (current, (row, label)) =>
      // Predict with the running weights `current`, not the pass's starting `weights`;
      // otherwise every row in the pass would be scored against stale weights.
      val prediction = activate(dotProduct(row, current))
      val err = label - prediction
      current.zip(row).map { case (w, x) => w + x * learningRate * err }
    }
  }

  /** Step activation: 1.0 when `x >= 0`, otherwise 0.0 (NaN also maps to 0.0). */
  def activate(x: Double): Double =
    if (x >= 0) 1.0 else 0.0

  /** Dot product of `v1` and `v2`; if their lengths differ, extra trailing elements are ignored. */
  def dotProduct(v1: Array[Double], v2: Array[Double]): Double =
    v1.zip(v2).map { case (a, b) => a * b }.sum

  /** Appends the bias feature to every row, returning new arrays. */
  def addBias(data: Array[Array[Double]]): Array[Array[Double]] =
    data.map(row => addRowBias(row))

  /** Returns a copy of `input` with `bias` appended as its last element. */
  def addRowBias(input: Array[Double]): Array[Double] =
    input :+ bias

  /** Classifies one raw (bias-free) feature row with the given weights.
    *
    * @return 1.0 or 0.0
    * @throws IllegalArgumentException if `weights` does not have exactly one more element than `input`
    */
  def predict(weights: Array[Double], input: Array[Double]): Double = {
    require(weights.length == input.length + 1,
      s"expected ${input.length + 1} weights (features + bias), got ${weights.length}")
    activate(dotProduct(weights, addRowBias(input)))
  }

  // Fails fast on shape mismatches that `zip` would otherwise silently truncate.
  private def requireShapes(w: Array[Double], data: Array[Array[Double]], labels: Array[Double]): Unit = {
    require(data.length == labels.length,
      s"data has ${data.length} rows but labels has ${labels.length} entries")
    require(data.forall(_.length == w.length),
      s"every row (bias included) must have ${w.length} values to match the weights")
  }
}
/** Default random source for `randomWeights`; unseeded, so each run differs. */
val rnd = new Random()

/** Returns `n` weights drawn uniformly from [0, 1); a non-positive `n` yields an empty array.
  *
  * Uses the no-argument `Random.nextDouble()`, which exists on every JDK; the
  * `nextDouble(origin, bound)` overload is only available from Java 17 onward.
  *
  * @param n   number of weights to generate
  * @param rng random source; pass e.g. `new Random(42L)` for reproducible runs
  */
def randomWeights(n: Int, rng: Random = rnd): Array[Double] =
  Array.fill(n)(rng.nextDouble())
// Truth tables over the inputs (1,0), (1,1), (0,1), (0,0); enable exactly one `labels` line.
val input: Array[Array[Double]] = Array(Array(1, 0), Array(1, 1), Array(0, 1), Array(0, 0))
val labels: Array[Double] = Array(0, 1, 0, 0) // AND
//val labels: Array[Double] = Array(1, 1, 1, 0) // OR
// XOR is not linearly separable: whatever weights training ends with, a single perceptron
// misclassifies at least one of the four points.
//val labels: Array[Double] = Array(1, 0, 1, 0) // XOR
val weights = randomWeights(3) // two input features + one bias weight
val p = new Perceptron(weights)
val ws = p.fit(input, labels)
// Print the predictions so results are visible when run as a script, not only in a worksheet.
Seq(Array(0.0, 1.0), Array(1.0, 1.0), Array(1.0, 0.0), Array(0.0, 0.0)).foreach { x =>
  println(s"${x.mkString("(", ", ", ")")} -> ${p.predict(ws, x)}")
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment