Skip to content

Instantly share code, notes, and snippets.

@dansalias
Created December 19, 2024 15:59
Show Gist options
  • Save dansalias/9e91399d0dea3a9e02c359a4274be6e2 to your computer and use it in GitHub Desktop.
Save dansalias/9e91399d0dea3a9e02c359a4274be6e2 to your computer and use it in GitHub Desktop.
Backpropagation in Rust
use crate::layer::Layer;
use crate::network::Network;
/// Scales the error signal `delta` at `output_layer` by the derivative of the
/// layer's activation functions.
///
/// If the layer has no activation functions, `delta` is returned unchanged.
/// Otherwise the activation functions are folded from last to first, each
/// `derivative` call receiving the previous result (starting from
/// `output_layer.get_activations()`). The folded values are then multiplied
/// element-wise with `delta`. If the two vectors differ in length, the result
/// is truncated to the shorter one.
///
/// NOTE(review): the derivative is evaluated on the layer's *activations*
/// (post-activation values), which presumably relies on each `derivative`
/// being expressible in terms of the function's output (true for ReLU).
/// With more than one activation function the derivatives are composed rather
/// than multiplied. Confirm this is intended.
fn get_d_output(
    output_layer: &Layer,
    delta: Vec<f64>,
) -> Vec<f64> {
    let functions = &output_layer.activation_functions;
    // Identity output: nothing to differentiate through.
    if functions.is_empty() {
        return delta;
    }
    let derivatives = functions
        .iter()
        .rfold(output_layer.get_activations(), |acc, function| {
            function.derivative(acc)
        });
    derivatives
        .iter()
        .zip(&delta)
        .map(|(z, d)| z * d)
        .collect()
}
/// Computes the per-neuron weight and bias gradients for the layer fed by
/// `input_layer`.
///
/// `delta` is expected to hold one error value per neuron of the receiving
/// layer. For each value `d` the result holds `(activations * d, d)`:
/// - the weight gradient is `input_layer`'s activations scaled by `d`;
/// - the bias gradient is `d` itself.
fn get_layer_gradient(
    input_layer: &Layer,
    delta: &[f64],
) -> Vec<(Vec<f64>, f64)> {
    // The input activations are the same for every neuron, so fetch them once
    // instead of recomputing (and reallocating) them for each delta entry.
    let activations = input_layer.get_activations();
    delta
        .iter()
        .map(|&d| (activations.iter().map(|a| a * d).collect(), d))
        .collect()
}
/// Back-propagates `delta` through the weights of `output_layer`, producing
/// the error signal for the layer that feeds it.
///
/// `delta` is expected to hold one entry per neuron of `output_layer`. The
/// result has one entry per *input* of `output_layer`, i.e. the length of the
/// neurons' weight vectors. Entry `i` is `sum_j weights_j[i] * delta[j]`, taken
/// over the neurons `j` of `output_layer`. An empty layer yields an empty
/// vector.
///
/// # Panics
///
/// Panics if some neuron has fewer weights than the layer's first neuron.
fn get_next_delta(
    output_layer: &Layer,
    delta: &[f64],
) -> Vec<f64> {
    // The size of the next delta is the layer's fan-in, not its neuron count.
    // Ranging over `delta` instead drops inputs when the layer widens toward
    // the output and indexes past `weights` when it narrows.
    let input_count = output_layer
        .neurons
        .iter()
        .next()
        .map_or(0, |neuron| neuron.weights.len());
    (0..input_count)
        .map(|i| {
            output_layer
                .neurons
                .iter()
                .zip(delta)
                .map(|(neuron, d)| neuron.weights[i] * d)
                .sum()
        })
        .collect()
}
/// Computes the gradient of every weight layer in `network` by
/// back-propagating `error` from the output layer towards the input.
///
/// `error` is the error signal at the network's output and is expected to hold
/// one value per output neuron. The result has one entry per pair returned by
/// `network.get_layer_pairs()`, in that same order. Each entry is the list of
/// per-neuron `(weight gradients, bias gradient)` tuples for that layer.
pub fn get_gradient(
    network: &Network,
    error: Vec<f64>,
) -> Vec<Vec<(Vec<f64>, f64)>> {
    let layer_pairs = network.get_layer_pairs();
    let mut gradient: Vec<Vec<(Vec<f64>, f64)>> = Vec::new();
    let mut delta = error;
    // Walk from the output layer backwards. Collect with `push` (O(1)) and
    // reverse once at the end, rather than `insert(0, ..)` (O(n) per layer).
    for (input_layer, output_layer) in layer_pairs.iter().rev() {
        let d_output = get_d_output(output_layer, delta);
        gradient.push(get_layer_gradient(input_layer, &d_output));
        delta = get_next_delta(output_layer, &d_output);
    }
    gradient.reverse();
    gradient
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::Relu;
    use std::rc::Rc;

    /// A single 2 -> 2 weight layer without activation functions: each output
    /// error scales the input activations (weight gradient) and is itself the
    /// bias gradient.
    #[test]
    fn gets_gradient() {
        let mut net = Network::new(vec![(2, vec![]), (2, vec![])]);
        net.set_parameters(vec![vec![
            (vec![1.0, 1.5], 0.5),
            (vec![0.5, 1.0], 0.5),
        ]]);
        net.get_output(&[1.0, 2.0]);

        let actual = get_gradient(&net, vec![0.5, 2.0]);
        let expected = vec![vec![
            (vec![0.5, 1.0], 0.5),
            (vec![2.0, 4.0], 2.0),
        ]];
        assert_eq!(actual, expected);
    }

    /// Two weight layers (3 -> 2 -> 2) with a ReLU on the hidden layer. The
    /// expected output-layer weight gradients imply hidden activations of
    /// `[2.0, 0.0]`, so the second hidden neuron's row in the first layer's
    /// gradient is all zeros.
    #[test]
    fn gets_gradient_deep() {
        let mut net = Network::new(vec![
            (3, vec![]),
            (2, vec![Rc::new(Relu)]),
            (2, vec![]),
        ]);
        net.set_parameters(vec![
            vec![
                (vec![0.5, 1.0, 1.0], 0.5),
                (vec![1.0, 0.0, 0.0], 0.5),
            ],
            vec![
                (vec![0.5, 1.0], 0.5),
                (vec![1.0, 1.0], 0.5),
            ],
        ]);
        net.get_output(&[-1.0, 1.0, 1.0]);

        let actual = get_gradient(&net, vec![1.0, 0.5]);
        let first_layer = vec![
            (vec![-1.0, 1.0, 1.0], 1.0),
            (vec![0.0, 0.0, 0.0], 0.0),
        ];
        let second_layer = vec![
            (vec![2.0, 0.0], 1.0),
            (vec![1.0, 0.0], 0.5),
        ];
        assert_eq!(actual, vec![first_layer, second_layer]);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment