Skip to content

Instantly share code, notes, and snippets.

@CVertex
Created July 5, 2011 09:31
Show Gist options
  • Save CVertex/1064551 to your computer and use it in GitHub Desktop.
Save CVertex/1064551 to your computer and use it in GitHub Desktop.
My first go as an SGD user
import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.L2;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;
import com.tdunning.ch16.CategoryFeatureEncoder;
import com.tdunning.ch16.Item;
/**
 * Toy example that trains a Mahout {@link OnlineLogisticRegression} model on a
 * handful of hand-written samples (height, weight, foot size) labelled male or
 * female, then prints the per-category scores for one unlabelled sample.
 */
public class PersonClassifier {

    /** Target category passed to the model for female samples. */
    private static final int FEMALE = 0;

    /** Target category passed to the model for male samples. */
    private static final int MALE = 1;

    /** Number of target categories the model is created with. */
    private static final int NUM_CATEGORIES = 2;

    /**
     * Size of the feature vector the model is created with.
     * NOTE(review): the encoders presumably place each named feature by
     * hashing, so with only three slots two features may share an index --
     * TODO confirm and consider a larger vector.
     */
    private static final int NUM_FEATURES = 3;

    /** Number of passes made over the whole training set. */
    private static final int TRAINING_PASSES = 1000;

    /** How many times the first male sample is repeated in the training set. */
    private static final int FIRST_SAMPLE_COPIES = 6;

    /** One sample: the label ({@code sex}) plus three numeric measurements. */
    private static class Person {
        final String sex;
        final double height;
        final double weight;
        final double footSize;

        Person(String sex, double height, double weight, double footSize) {
            this.sex = sex;
            this.height = height;
            this.weight = weight;
            this.footSize = footSize;
        }

        /**
         * Maps the label to the target category used for training.
         *
         * @return {@link #MALE} (1) when {@code sex} is {@code "M"},
         *         otherwise {@link #FEMALE} (0)
         */
        public int getSexCategoryNumber() {
            // equals() compares content; == only compared references and is
            // true solely for interned/literal strings. Constant-first form
            // also tolerates a null label.
            return "M".equals(sex) ? MALE : FEMALE;
        }
    }

    /**
     * Writes a person's numeric measurements into a feature vector. The sex is
     * not encoded as a feature: it is the target handed to
     * {@code model.train}.
     */
    private static class PersonEncoder {
        private final ContinuousValueEncoder height = new ContinuousValueEncoder("height");
        private final ContinuousValueEncoder weight = new ContinuousValueEncoder("weight");
        private final ContinuousValueEncoder footSize = new ContinuousValueEncoder("foot-size");

        /**
         * Adds the height, weight and foot size of {@code x} to {@code data}.
         *
         * @param x    the person whose measurements are encoded
         * @param data the vector that receives the encoded features
         */
        public void addToVector(Person x, Vector data) {
            // The cast on null selects the byte[] overload of addToVector;
            // only the numeric weight argument carries information here.
            height.addToVector((byte[]) null, x.height, data);
            weight.addToVector((byte[]) null, x.weight, data);
            footSize.addToVector((byte[]) null, x.footSize, data);
        }
    }

    /**
     * Builds the training set, trains the model and prints the classification
     * scores for one unlabelled person to standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            List<Person> trainingSet = buildTrainingSet();

            OnlineLogisticRegression model =
                    new OnlineLogisticRegression(NUM_CATEGORIES, NUM_FEATURES, new L2(1));
            // One encoder is enough; it is reused for every sample.
            PersonEncoder encoder = new PersonEncoder();

            train(model, encoder, trainingSet);

            // Unlabelled probe. An alternative probe (height 8, weight 180,
            // foot size 13) can be substituted here.
            Person dunno = new Person("", 5, 110, 5);
            Vector result = model.classifyFull(encode(encoder, dunno, model.numFeatures()));
            System.out.println(result);
        } catch (Exception e) {
            // Top-level boundary: report the failure and signal it through the
            // exit status instead of finishing as if the run had succeeded.
            e.printStackTrace();
            System.exit(1);
        }
    }

    /** Creates the labelled samples: nine male followed by four female. */
    private static List<Person> buildTrainingSet() {
        List<Person> people = new ArrayList<Person>();
        for (int i = 0; i < FIRST_SAMPLE_COPIES; i++) {
            people.add(new Person("M", 6, 180, 12));
        }
        people.add(new Person("M", 5.92, 190, 11));
        people.add(new Person("M", 5.58, 170, 12));
        people.add(new Person("M", 5.92, 165, 10));

        people.add(new Person("F", 5, 100, 6));
        people.add(new Person("F", 5.5, 150, 8));
        people.add(new Person("F", 5.42, 130, 7));
        people.add(new Person("F", 5.75, 120, 9));
        return people;
    }

    /**
     * Runs {@link #TRAINING_PASSES} passes over {@code people}. Each pass
     * covers both categories, so the online learner sees male and female
     * samples interleaved rather than every male pass followed by every
     * female pass, which would bias it toward the category seen last.
     */
    private static void train(OnlineLogisticRegression model, PersonEncoder encoder,
            List<Person> people) {
        for (int pass = 0; pass < TRAINING_PASSES; pass++) {
            for (Person p : people) {
                Vector features = encode(encoder, p, model.numFeatures());
                model.train(p.getSexCategoryNumber(), features);
            }
        }
    }

    /** Encodes one person into a fresh sparse vector of the given size. */
    private static Vector encode(PersonEncoder encoder, Person person, int numFeatures) {
        Vector features = new RandomAccessSparseVector(numFeatures);
        encoder.addToVector(person, features);
        return features;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment