Created
May 22, 2024 03:48
-
-
Save TruncatedDinoSour/a258567b1676c795fc39c6d9ea2c613b to your computer and use it in GitHub Desktop.
A machine learning model example in Java: prediction of f(x) = x/pi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * The #root:ari.lt people on Matrix and I are messing around with
 * Java, and I went from Hello World to Fibonacci to Machine Learning.
 * Enjoy!
 *
 * License: Unlicense
 */
import java.util.Random; | |
/**
 * A tiny single-parameter model {@code f(x) = x * p0} trained with gradient
 * descent on samples of {@code x / pi}, so that {@code p0} approaches
 * {@code 1 / pi}.
 */
public class ML {
    /** Gradient-descent learning rate: the factor applied to the gradient each step. */
    public static double rate = 0.001;
    /** Number of gradient-descent steps performed during training. */
    public static int epochs = 1024;

    /**
     * Computes the mean squared error of the model {@code f(x) = x * p0}
     * over a dataset.
     *
     * <p>This method has no side effects. It is called several times per
     * epoch (including twice inside {@link #loss_deriv}), so it must not print.
     *
     * @param t  training pairs; {@code t[i][0]} is the input and
     *           {@code t[i][1]} the expected output
     * @param p0 the model parameter
     * @return the mean of {@code (t[i][0] * p0 - t[i][1])^2}; NaN when
     *         {@code t} is empty (0 / 0)
     */
    public static double loss(double[][] t, double p0) {
        double sum = 0.0;
        for (double[] sample : t) {
            double delta = sample[0] * p0 - sample[1]; /* f(x) = x * p0 */
            sum += delta * delta;
        }
        return sum / t.length;
    }

    /**
     * Estimates d(loss)/d(p0) with a central finite difference
     * {@code (L(p0 + e) - L(p0 - e)) / (2e)}.
     *
     * <p>A forward difference {@code (L(p0 + e) - L(p0)) / e} is biased by
     * O(e). For this quadratic loss that bias would make gradient descent
     * settle {@code e / 2} below the true minimum. The central difference is
     * exact up to rounding for a quadratic. Very small {@code e} increases
     * rounding error from the subtraction.
     *
     * @param t  training pairs, as for {@link #loss}
     * @param p0 the point at which to estimate the derivative
     * @param e  the finite-difference step; must be non-zero
     * @return the derivative estimate
     * @throws IllegalArgumentException if {@code e} is zero
     */
    public static double loss_deriv(double[][] t, double p0, double e) {
        if (e == 0.0) {
            throw new IllegalArgumentException("finite-difference step e must be non-zero");
        }
        return (loss(t, p0 + e) - loss(t, p0 - e)) / (2.0 * e);
    }

    /**
     * Runs {@link #epochs} steps of gradient descent on {@code p0}, printing
     * the parameter and loss before each step.
     *
     * @param t  training pairs, as for {@link #loss}
     * @param p0 the initial parameter
     * @param e  the finite-difference step used to estimate the gradient
     * @return the trained parameter
     */
    static double train(double[][] t, double p0, double e) {
        for (int epoch = 0; epoch < epochs; ++epoch) {
            System.out.printf("%d: p0 = %f (loss = %f), e = %f%n", epoch, p0,
                    loss(t, p0), e);
            // Step against the gradient. e stays fixed: it is only the
            // finite-difference step, never the parameter update itself.
            p0 -= rate * loss_deriv(t, p0, e);
        }
        return p0;
    }

    /**
     * Generates 16 random samples of {@code x / pi} for x in [0, 20), trains
     * {@code p0} from a random start in [0, 10), and prints predictions for
     * {@code 1..16}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double[][] t = new double[16][2];
        double e = 0.001;
        Random rand = new Random();

        System.out.println("Doing epic training data generation...");
        for (int idx = 0; idx < t.length; ++idx) {
            // Inputs stay below 20, so mean(x^2) <= 400. Each gradient step
            // then scales the error by 1 - 2 * rate * mean(x^2), which is
            // in (0.2, 1] for rate = 0.001: stable, no oscillation.
            t[idx][0] = rand.nextDouble() * 20.0;
            t[idx][1] = t[idx][0] / Math.PI;
            System.out.printf("%d: %f => %f%n", idx, t[idx][0], t[idx][1]);
        }

        double p0 = rand.nextDouble() * 10.0;
        System.out.printf("Learning for %d epochs...%n", epochs);
        p0 = train(t, p0, e);

        System.out.printf("Final: p0 = %f (loss = %f), e = %f%n", p0,
                loss(t, p0), e);
        for (int idx = 1; idx <= 16; ++idx) {
            System.out.printf("%d / pi = %f (expected %f)%n", idx, idx * p0,
                    idx / Math.PI);
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment