Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save lukemetz/9def09cee15a62996545 to your computer and use it in GitHub Desktop.
Save lukemetz/9def09cee15a62996545 to your computer and use it in GitHub Desktop.
Indico code samples
#pragma once
#include <armadillo>
#include <algorithm>
#include <random>
#include <cassert>
#include "net.hpp"
template <typename activation = Logistic, typename error = Squared_Error>
void train_online(FeedForward_Network<activation, error>& network,
arma::Mat<float> inputs, arma::Mat<float> targets, float learning_rate) {
for (int i = 0; i < targets.n_rows; ++i) {
calculate_activation(network, inputs.row(i));
backprop(network, targets.row(i), learning_rate);
}
}
template <typename activation, typename error>
void train_batch(FeedForward_Network<activation, error>& network,
arma::Mat<float> inputs, arma::Mat<float> targets, int batch_size, float learning_rate) {
network.resize_activation(batch_size);
int batches_in_train = targets.n_rows/batch_size - 1;
for (int i = 0; i < batches_in_train; ++i) {
arma::Mat<float> input_slice = inputs.rows(i*batch_size, (i+1) * batch_size-1);
calculate_activation(network, input_slice);
arma::Mat<float> target_slice = targets.rows(i*batch_size, (i+1) * batch_size-1);
backprop(network, target_slice, learning_rate);
}
}
//Randomize weights in a network.
//Every weight matrix is filled with samples from a normal distribution with
//mean 0 and the given standard deviation. last_weights is then set to an exact
//copy of weights. Backprop's momentum term is momentum * (weights - last_weights),
//so it starts at zero instead of adding an extra random step to the first update.
//The generator is default-constructed, so a network of a given shape always
//receives the same sequence of samples (reproducible across calls and runs).
template <typename activation, typename error>
void randomize(FeedForward_Network<activation, error>& network, float standard_deviation = 0.05) {
  std::default_random_engine generator;
  std::normal_distribution<float> distribution(0, standard_deviation);
  auto random_num = [&]() { return distribution(generator); };
  for (arma::uword i = 0; i < network.weights.size(); ++i) {
    network.weights[i].imbue(random_num);
    network.last_weights[i] = network.weights[i];
  }
}
//Back-propagate the error of the most recent forward pass (the values left in
//network.activations) and update every weight matrix with momentum.
//  target        - expected output, compared against network.activations.back()
//                  by error::error_dir (assumed to have the same shape — confirm)
//  learning_rate - step size; scaled by (1 - momentum) before being applied
//  momentum      - weight on the previous step, measured as
//                  weights[i] - last_weights[i]
template <typename arma_t, typename activation, typename error>
void backprop(FeedForward_Network<activation, error> &network,
              arma_t target, float learning_rate = 0.8f, float momentum = 0.8f) {
  // Output-layer delta: error derivative times activation derivative.
  network.deltas.back() = error::error_dir(target, network.activations.back())
                          % activation::activation_dir(network.activations.back());
  // Remaining deltas, propagated backwards through the transposed weights.
  // The start index is computed in signed arithmetic so that a network with
  // fewer than two delta entries just skips the loop, rather than relying on
  // an unsigned wrap-around being converted back to a negative int.
  for (int i = static_cast<int>(network.deltas.size()) - 2; i >= 0; --i) {
    network.deltas[i] = (network.deltas[i+1] * network.weights[i+1].t())
                        % activation::activation_dir(network.activations[i+1]);
  }
  // Update weights.
  for (arma::uword i = 0; i < network.weights.size(); ++i) {
    // Materialise both pieces as concrete matrices. Binding `auto &` to an
    // Armadillo expression temporary is ill-formed. Plain `auto` would keep an
    // expression object that refers to temporaries destroyed at the end of
    // the statement.
    const arma::Mat<float> standard_piece =
        (1 - momentum) * learning_rate * (network.deltas[i].t() * network.activations[i]).t();
    const arma::Mat<float> momentum_piece =
        momentum * (network.weights[i] - network.last_weights[i]);
    const arma::Mat<float> delta_weights = standard_piece + momentum_piece;
    network.last_weights[i] = network.weights[i];
    network.weights[i] += delta_weights;
  }
}
template <typename arma_t, typename activation, typename error>
void calculate_activation(FeedForward_Network<activation, error>& network,
arma_t input) {
network.activations[0] = input;
for(int i=1; i < network.activations.size(); ++i) {
network.activations[i] = network.activations[i-1] * network.weights[i-1];
network.activations[i] = activation::activation(network.activations[i]);
}
}
//Scoring function for classification.
inline double classify_percent_score(arma::Mat<float> result, arma::Mat<float> correct) {
assert(result.n_cols == correct.n_cols);
int num_correct = 0;
for (int i=0; i < result.n_rows; ++i) {
auto sort_vec = arma::sort_index(result.row(i), 1);
if (correct.row(i)[sort_vec[0]] == 1) {
num_correct += 1;
}
}
return static_cast<float>(num_correct) / static_cast<float>(result.n_rows);
}
//Scoring function that calculates the difference in squares between two matrices.
inline float squared_diff(arma::Mat<float> result, arma::Mat<float> correct) {
assert(result.n_cols == correct.n_cols);
auto error_diff = correct - result;
return arma::accu(error_diff % error_diff);
}
use libc::{c_long, size_t};
use std::c_str::CString;
pub use base::{PyObject, ToPyType, FromPyType, PyState, PyIterator};
pub use ffi::{PythonCAPI, PyObjectRaw};
pub use base::{PyError,
FromTypeConversionError,
ToTypeConversionError,
NullPyObject};
// Implements ToPyType and FromPyType for the primitive Rust type `$base_type`.
//   $cast_type - type the value is cast to (with `as`) before being passed to `$to`
//   $to        - PyState method that builds a raw Python object from a `$cast_type`
//   $back      - PyState method that reads the value back out of a raw object;
//                its result is cast to `$base_type` with `as`
//   $check     - PyState type-check method; a result > 0 is taken to mean the
//                object has the expected Python type
// NOTE(review): the `as` casts wrap/truncate silently. `$back`'s own error
// signalling is not inspected here (in the CPython C API, PyInt_AsLong and
// PyFloat_AsDouble return -1 and set an exception on failure). Confirm
// whether PyState handles that.
macro_rules! prim_pytype (
($base_type:ty, $cast_type:ty, $to:ident, $back:ident, $check:ident) => (
impl ToPyType for $base_type {
// Wrap the result of `$to` in a PyObject. Fails with ToTypeConversionError if
// the returned raw pointer is null or does not pass `$check`.
fn to_py_object<'a>(&self, state : &'a PyState) -> Result<PyObject<'a>, PyError> {
unsafe {
let raw = state.$to(*self as $cast_type);
if raw.is_not_null() && state.$check(raw) > 0 {
Ok(PyObject::new(state, raw))
} else {
Err(ToTypeConversionError)
}
}
}
}
impl FromPyType for $base_type {
// Extract a `$base_type` via `$back`. Fails with FromTypeConversionError if the
// object's raw pointer is null or it does not pass `$check`.
fn from_py_object(state : &PyState, py_object : PyObject) -> Result<$base_type, PyError> {
unsafe {
if py_object.raw.is_not_null() && state.$check(py_object.raw) > 0 {
Ok(state.$back(py_object.raw) as $base_type)
} else {
Err(FromTypeConversionError)
}
}
}
}
)
)
// Float types go through a C double via PyFloat_*: f32 is widened with `as`
// on the way in and narrowed with `as` on the way out.
prim_pytype!(f64, f64, PyFloat_FromDouble, PyFloat_AsDouble, PyFloat_Check)
prim_pytype!(f32, f64, PyFloat_FromDouble, PyFloat_AsDouble, PyFloat_Check)
// Integer types all go through c_long via the PyInt_* functions.
// NOTE(review): because the conversions use `as`, unsigned values above
// c_long's maximum (u64/uint), and values outside the target type's range on
// the way back (e.g. into u8), wrap instead of producing a conversion error.
// Confirm this is acceptable.
prim_pytype!(i64, c_long, PyInt_FromLong, PyInt_AsLong, PyInt_Check)
prim_pytype!(i32, c_long, PyInt_FromLong, PyInt_AsLong, PyInt_Check)
prim_pytype!(int, c_long, PyInt_FromLong, PyInt_AsLong, PyInt_Check)
prim_pytype!(uint, c_long, PyInt_FromLong, PyInt_AsLong, PyInt_Check)
prim_pytype!(u8, c_long, PyInt_FromLong, PyInt_AsLong, PyInt_Check)
prim_pytype!(u32, c_long, PyInt_FromLong, PyInt_AsLong, PyInt_Check)
prim_pytype!(u64, c_long, PyInt_FromLong, PyInt_AsLong, PyInt_Check)
#[cfg(test)]
mod test {
    use base::PyState;
    use super::{ToPyType, FromPyType, NoArgs};

    // Unwrap a Result inside a test, failing the test with the error's text.
    macro_rules! try_or_fail (
        ($e:expr) => (match $e { Ok(e) => e, Err(e) => fail!("{}", e) })
    )

    // Generate a #[test] named `$func_name` that converts the value 123 of
    // type `$t` to a Python object and back, and checks it round-trips.
    macro_rules! num_to_py_object_and_back (
        ($t:ty, $func_name:ident) => (
            #[test]
            fn $func_name() {
                let py = PyState::new();
                let value = 123i as $t;
                let py_object = try_or_fail!(value.to_py_object(&py));
                let returned = try_or_fail!(py.from_py_object::<$t>(py_object));
                assert_eq!(returned, 123i as $t);
            }
        )
    )

    // One round-trip test per type given a prim_pytype! implementation; each
    // test is named to_from_<type>.
    num_to_py_object_and_back!(f64, to_from_f64)
    num_to_py_object_and_back!(f32, to_from_f32)
    num_to_py_object_and_back!(i64, to_from_i64)
    num_to_py_object_and_back!(i32, to_from_i32)
    num_to_py_object_and_back!(int, to_from_int)
    num_to_py_object_and_back!(uint, to_from_uint)
    num_to_py_object_and_back!(u8, to_from_u8)
    num_to_py_object_and_back!(u32, to_from_u32)
    num_to_py_object_and_back!(u64, to_from_u64)
}
@lukemetz
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment