Last active
February 23, 2020 11:00
-
-
Save optozorax/a6315ab6fffb4450afb43392474e31de to your computer and use it in GitHub Desktop.
autograd_example.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use autograd as ag; | |
use autograd::ops::*; | |
use ag::tensor::Tensor; | |
use ag::ndarray::arr0; | |
use ag::ops::gradient_descent_ops::SGD; | |
fn value(a: &Tensor<f32>) -> f32 { | |
a.eval(&[]).unwrap().into_raw_vec()[0] | |
} | |
fn values(grads: &[Tensor<f32>]) -> Vec<f32> { | |
grads.iter().map(|g| value(g)).collect() | |
} | |
#[allow(clippy::many_single_char_names)] | |
fn main() { | |
let a = &ag::variable(arr0::<f32>(3.0)); // Triangle sides | |
let b = &ag::variable(arr0::<f32>(4.0)); | |
let c = &ag::variable(arr0::<f32>(5.0)); | |
let p = &(abs(a) + abs(b) + abs(c)); // Perimeter | |
let pp = &(p / 2.); // Half of perimeter | |
let s = &sqrt(pp*(pp-a)*(pp-b)*(pp-c)); // Area | |
let loss = &( // Loss consists of: , | |
pow(s - 6. * 6.0_f32.sqrt(), 2.) + // area must be 6*sqrt(6), | |
pow(p - 18.0, 2.) + // perimeter 18, | |
5.0 * pow(a - 5.0, 2.) // and `a` is 5. | |
); // Answer to this parameters is | |
// a=5, b=6, c=7. | |
let params = &[a, b, c]; | |
let grads = &ag::grad(&[loss], params); // Calculate gradient of this | |
dbg!(values(grads)); // Print gradient for current values | |
let adam = &SGD { lr: 0.1 }; // Create stochastic gradient descent optimizer | |
let update_ops = &adam.compute_updates(params, grads); | |
// Simulate 500 iterations of optimization | |
for _ in 0..500 { | |
ag::eval(update_ops, &[]); | |
} | |
// Print result values | |
dbg!(value(a)); | |
dbg!(value(b)); | |
dbg!(value(c)); | |
dbg!(value(p)); | |
dbg!(value(s)); | |
dbg!(value(loss)); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.