While trying to wrap my head around the basics of machine learning, I found that this example really helped me get a small taste of how machine learning frameworks like TensorFlow work their magic under the hood.
use rand::random;

/// In this example, we'll attempt to train a single-neuron model to fit
/// a straight line. The slope of this line is 2, and it can be expressed
/// by the equation y = 2x. We'll use gradient descent to find the slope
/// of the line, and then we'll use that slope to predict the output for
/// a given input.
///
/// The model won't be able to perfectly fit the line, but it will be
/// able to get very close. The idea here is that we want the model to
/// be able to _approximate_ the line.
static TRAIN: [[f64; 2]; 6] = [
    [0.0, 0.0],
    [1.0, 2.0],
    [2.0, 4.0],
    [4.0, 8.0],
    [5.0, 10.0],
    [6.0, 12.0],
];
/// Cost takes in a given weight and returns a number representing the
/// cost of the model. The cost is another way of saying: how accurate
/// is the model at predicting Y, given X?
///
/// The cost is calculated by iterating through every X value, multiplying
/// it by the weight, and then subtracting the actual Y value. This gives
/// us the difference between the predicted Y and the actual Y. We then
/// square this difference to get rid of the negative sign, and to make
/// the cost "more expensive" if the difference is large, i.e. we magnify
/// predictions that are farther from the actual value.
fn cost(w: f64) -> f64 {
    // Start with a cost of 0, indicating a perfect model.
    let mut cost = 0.0;
    // Iterate through our training data, one step at a time.
    for step in TRAIN.iter() {
        // Predict a Y value by multiplying the X value by the weight.
        let x = step[0];
        let y = x * w;
        // Calculate the loss as the difference between the predicted Y
        // and the actual Y.
        let loss = y - step[1];
        // Square the loss and add it to the total cost.
        cost += loss * loss;
    }
    cost
}
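
// A quick aside, added purely for illustration: for this particular
// training set the cost has a simple closed form. Every row satisfies
// y = 2x, so the loss for one row is x*w - 2x = x*(w - 2), and
//
//     cost(w) = (w - 2)^2 * (0 + 1 + 4 + 16 + 25 + 36) = 82 * (w - 2)^2
//
// which is a parabola with its minimum at exactly w = 2. For example,
// cost(2.0) is 0.0 and cost(1.0) is 82.0. Gradient descent below is just
// a way of sliding down this parabola without knowing its formula.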
/// Optimize is responsible for finding the minimum cost by changing the
/// weight. It does this using a simple implementation of gradient descent.
/// Gradient descent is a way of finding the minimum of a function by
/// taking the derivative of the function at a given point, and then
/// moving in the direction of the negative of the derivative. This is
/// repeated until the derivative is (close to) 0, indicating that we
/// have reached a minimum.
///
/// In other words, optimize is responsible for determining the rate of
/// change of the cost function with respect to the weight, and then
/// moving the weight in the direction of the negative of that rate of
/// change.
fn optimize(learning_rate: f64, w: &mut f64) {
    // A tiny step size used to numerically approximate the derivative.
    let eps = 1e-4;
    // Approximate the derivative of the cost with respect to the weight
    // using a forward finite difference: how much does the cost change
    // when we nudge the weight by eps?
    let d = (cost(*w + eps) - cost(*w)) / eps;
    // Move the weight a small step in the direction that lowers the cost.
    *w -= learning_rate * d;
}
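
// To make the update concrete (an illustrative walk-through, not extra
// logic): suppose the weight is currently 1.0. Then cost(1.0) = 82.0 and
// cost(1.0001) ≈ 81.9836, so the estimated derivative is roughly
// (81.9836 - 82.0) / 1e-4 ≈ -164. The update w -= 1e-4 * (-164) nudges the
// weight up by about 0.016, toward the true slope of 2. Repeated thousands
// of times, these small nudges converge on w ≈ 2.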
fn main() {
    // Initialize the weight to a random value. This value is going to
    // be very bad at predicting Y given an X initially, but we will try
    // to optimize it.
    let mut w: f64 = random();
    let learning_rate = 1e-4;
    // Optimize the weight 100,000 times. After each iteration, the weight
    // will be slightly better at predicting Y given an X.
    for _ in 0..100000 {
        optimize(learning_rate, &mut w);
    }
    // Try predicting the output for the training data; notice how we're
    // able to predict the output to within roughly 1/10,000th precision.
    // We've successfully trained a single-neuron model that can fit a line!
    for step in TRAIN.iter() {
        let x = step[0];
        let y = x * w;
        println!("x: {}, y: {}", x, y);
    }
}
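
One detail worth calling out: the optimize function above approximates the derivative numerically with a finite difference, whereas frameworks like TensorFlow compute gradients analytically through automatic differentiation. For this tiny model the exact gradient is easy to write by hand, since the derivative of (x·w - y)² with respect to w is 2·x·(x·w - y). Here's a minimal sketch of that variant (the name optimize_analytic is made up for illustration):

fn optimize_analytic(learning_rate: f64, w: &mut f64) {
    let current = *w;
    // Exact gradient of the summed squared error with respect to the weight:
    // d/dw of (x*w - y)^2 is 2*x*(x*w - y), summed over every training step.
    let d: f64 = TRAIN
        .iter()
        .map(|step| 2.0 * step[0] * (step[0] * current - step[1]))
        .sum();
    // Same update rule as before, just with an exact derivative instead of
    // a numerical approximation.
    *w -= learning_rate * d;
}

Swapping this in for optimize in the training loop should converge to the same w ≈ 2, just without the two extra cost evaluations per step.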