Skip to content

Instantly share code, notes, and snippets.

@PirosB3
Created November 30, 2014 11:31
Show Gist options
  • Save PirosB3/ef5a8aaa5a90301cdb1c to your computer and use it in GitHub Desktop.
Save PirosB3/ef5a8aaa5a90301cdb1c to your computer and use it in GitHub Desktop.
use std::io::BufferedReader;
use std::collections::HashMap;
use std::io::File;
use std::from_str;
use std::rand::distributions::{Exp, IndependentSample};
/// One labelled sample: a list of numeric features plus its numeric target.
///
/// `Clone` is derived because `LogisticRegression::from_data_entries` clones
/// the whole dataset vector it is given.
#[deriving(Clone)]
struct DataEntry {
// Feature values parsed from the input tokens (see `DataEntry::new`).
features: Vec<f64>,
// Numeric class label that the model output is compared against in `train`.
target: f64
}
/// A logistic-regression model together with the samples it is trained on.
struct LogisticRegression {
// Owned copy of the training samples iterated over by `train`.
dataset: Vec<DataEntry>,
// Model weights; `from_data_entries` allocates one more slot than the
// number of features in the first sample.
weights: Vec<f64>
}
/// Computes the linear response `weights[0] + sum(data_entry[i] * weights[i + 1])`.
///
/// `weights[0]` is the bias term; each remaining weight is the coefficient of
/// the feature at the same position in `data_entry`.
///
/// # Panics
///
/// Panics unless `weights` holds exactly one more element than `data_entry`
/// (the bias plus one coefficient per feature).
pub fn vector_vector_mul(data_entry: &[f64], weights: &[f64]) -> f64 {
    assert!(weights.len() == data_entry.len() + 1,
            "weights must hold one bias term plus one coefficient per feature");
    // Start from the bias, then pair feature i with weight i + 1 so the bias
    // is not multiplied in again and the last coefficient is not dropped.
    let mut total: f64 = weights[0];
    for (feature, weight) in data_entry.iter().zip(weights.iter().skip(1)) {
        total += *feature * *weight;
    }
    total
}
/// Logistic sigmoid: `1 / (1 + e^(-x))`.
///
/// Maps any real input into the open interval (0, 1), with `sigmoid(0.0) == 0.5`.
pub fn sigmoid(x: f64) -> f64 {
    // Natural exponential (`exp`), not base-2 (`exp2`): the logistic function
    // is defined in terms of e.
    1.0 / (1.0 + (-x).exp())
}
impl LogisticRegression {
    /// Builds a model over a copy of `dataset` with every weight set to 1.0.
    ///
    /// The weight vector has one slot for the bias followed by one slot per
    /// feature of the first entry.
    ///
    /// # Panics
    ///
    /// Panics if `dataset` is empty, since the number of features is taken
    /// from its first entry.
    pub fn from_data_entries(dataset: &Vec<DataEntry>) -> LogisticRegression {
        let first = dataset.get(0).expect("cannot build a model from an empty dataset");
        let len = first.features.len() + 1;
        let mut weights: Vec<f64> = Vec::new();
        for _ in range(0, len) {
            weights.push(1.0);
        }
        LogisticRegression {
            dataset: dataset.clone(),
            weights: weights
        }
    }

    /// Runs `epochs` passes of batch gradient ascent on the log-likelihood.
    ///
    /// Each pass accumulates, over the whole dataset, the residual
    /// `target - sigmoid(w . x)` multiplied by the matching input (1 for the
    /// bias, the feature value otherwise), then moves every weight by
    /// `0.001` times its own accumulated gradient. The summed residual is
    /// printed once per pass as a progress indicator.
    pub fn train(&mut self, epochs: int) {
        let learning_rate: f64 = 0.001;
        for _ in range(0i, epochs) {
            let mut error: f64 = 0.0;
            // One gradient slot per weight; slot 0 belongs to the bias.
            let mut gradient: Vec<f64> = self.weights.iter().map(|_| 0.0).collect();
            for ds in self.dataset.iter() {
                let features = ds.feature_vector();
                let prediction = sigmoid(vector_vector_mul(features, self.weights.as_slice()));
                let delta = ds.target - prediction;
                error += delta;
                // The bias sees a constant input of 1, so its gradient is the residual.
                gradient[0] += delta;
                for (slot, feature) in gradient.iter_mut().skip(1).zip(features.iter()) {
                    *slot += delta * *feature;
                }
            }
            // Step *towards* the targets: add the gradient of the log-likelihood,
            // each weight using its own component rather than a shared scalar.
            for (w, g) in self.weights.iter_mut().zip(gradient.iter()) {
                *w += learning_rate * *g;
            }
            println!("{}", error);
        }
    }
}
impl DataEntry {
    /// Creates an entry from textual feature tokens and an already-numeric target.
    ///
    /// Every token that parses as a number becomes a feature, in input order;
    /// tokens that do not parse are skipped without error.
    pub fn new(splitted_data: &Vec<&str>, target: f64) -> DataEntry {
        let mut parsed: Vec<f64> = Vec::new();
        for token in splitted_data.iter() {
            let maybe_value: Option<f64> = from_str(*token);
            match maybe_value {
                Some(value) => parsed.push(value),
                None => {}
            }
        }
        DataEntry { features: parsed, target: target }
    }

    /// Borrows the feature values as a slice.
    pub fn feature_vector(&self) -> &[f64] {
        self.features.as_slice()
    }
}
// Entry point: loads comma-separated samples from `iris.data`, keeps the
// rows labelled "Iris-setosa" or "Iris-versicolor", and trains a
// logistic-regression model on them for 100 epochs.
fn main() {
let path = Path::new("iris.data");
// The result of `File::open` is handed to the reader unchecked; the
// `unwrap` on each line below is the only point where a read result is inspected.
let mut file = BufferedReader::new(File::open(&path));
let mut dataset : Vec<DataEntry> = Vec::new();
{
// Next numeric label to hand out; bumped for every new class string seen.
let mut counter: f64 = 0.0;
// Class name -> numeric label, assigned in order of first appearance.
let mut target_map : HashMap<String, f64> = HashMap::new();
for line in file.lines() {
let data = line.unwrap();
let trimmed_data = data.trim();
let mut splitted_data : Vec<&str> = trimmed_data.as_slice().split(',').collect();
// The last comma-separated field is the class name; popping it leaves
// only the feature tokens in `splitted_data`.
let class_type = splitted_data.pop().unwrap().to_string();
// Look up the label for this class, registering a fresh one if needed.
// NOTE(review): labels are allocated for every distinct class string,
// including classes filtered out below, so the values used for training
// depend on the order classes appear in the file.
let class_target = {
if target_map.contains_key(&class_type) {
target_map[class_type]
} else {
target_map.insert(class_type.clone(), counter);
counter += 1.0;
target_map[class_type]
}
};
{
let kls = class_type.as_slice();
// Only these two classes are added to the training set.
if kls == "Iris-setosa" || kls == "Iris-versicolor" {
dataset.push(DataEntry::new(&splitted_data, class_target));
}
}
}
}
let mut log_regr = LogisticRegression::from_data_entries(&dataset);
log_regr.train(100);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment