Created
July 5, 2011 09:31
-
-
Save CVertex/1064551 to your computer and use it in GitHub Desktop.
My first go at an SGD user
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.tdunning.ch16.CategoryFeatureEncoder;
import com.tdunning.ch16.Item;
import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.L2;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;
public class PersonClassifier { | |
private static class Person { | |
public String sex; | |
public double height; | |
public double weight; | |
public double footSize; | |
public int getSexCategoryNumber() { | |
if (sex=="M") | |
return 1; | |
return 0; | |
} | |
} | |
private static class PersonEncoder { | |
CategoryFeatureEncoder sex = new CategoryFeatureEncoder("sex"); | |
ContinuousValueEncoder height = new ContinuousValueEncoder("height"); | |
ContinuousValueEncoder weight = new ContinuousValueEncoder("weight"); | |
ContinuousValueEncoder footSize = new ContinuousValueEncoder("foot-size"); | |
public void addToVector(Person x, Vector data) { | |
//sex.addToVector(x.getSexCategoryNumber(), data); | |
height.addToVector((byte[])null, x.height, data); | |
weight.addToVector((byte[])null, x.weight, data); | |
footSize.addToVector((byte[])null, x.footSize, data); | |
} | |
} | |
/** | |
* @param args | |
*/ | |
public static void main(String[] args) { | |
try { | |
ArrayList<Person> males = new ArrayList<Person>(); | |
ArrayList<Person> females = new ArrayList<Person>(); | |
// create the people | |
Person p1 = new Person(); | |
p1.sex = "M"; | |
p1.height = 6; | |
p1.weight = 180; | |
p1.footSize = 12; | |
males.add(p1); | |
p1 = new Person(); | |
p1.sex = "M"; | |
p1.height = 6; | |
p1.weight = 180; | |
p1.footSize = 12; | |
males.add(p1); | |
p1 = new Person(); | |
p1.sex = "M"; | |
p1.height = 6; | |
p1.weight = 180; | |
p1.footSize = 12; | |
males.add(p1); | |
p1 = new Person(); | |
p1.sex = "M"; | |
p1.height = 6; | |
p1.weight = 180; | |
p1.footSize = 12; | |
males.add(p1); | |
p1 = new Person(); | |
p1.sex = "M"; | |
p1.height = 6; | |
p1.weight = 180; | |
p1.footSize = 12; | |
males.add(p1); | |
p1 = new Person(); | |
p1.sex = "M"; | |
p1.height = 6; | |
p1.weight = 180; | |
p1.footSize = 12; | |
males.add(p1); | |
Person p2 = new Person(); | |
p2.sex = "M"; | |
p2.height = 5.92; | |
p2.weight = 190; | |
p2.footSize = 11; | |
males.add(p2); | |
Person p3 = new Person(); | |
p3.sex = "M"; | |
p3.height = 5.58; | |
p3.weight = 170; | |
p3.footSize = 12; | |
males.add(p3); | |
Person p4 = new Person(); | |
p4.sex = "M"; | |
p4.height = 5.92; | |
p4.weight = 165; | |
p4.footSize = 10; | |
males.add(p4); | |
Person p5 = new Person(); | |
p5.sex = "M"; | |
p5.height = 5; | |
p5.weight = 100; | |
p5.footSize = 6; | |
females.add(p5); | |
Person p6 = new Person(); | |
p6.sex = "M"; | |
p6.height = 5.5; | |
p6.weight = 150; | |
p6.footSize = 8; | |
females.add(p6); | |
Person p7 = new Person(); | |
p7.sex = "M"; | |
p7.height = 5.42; | |
p7.weight = 130; | |
p7.footSize = 7; | |
females.add(p7); | |
Person p8 = new Person(); | |
p8.sex = "M"; | |
p8.height = 5.75; | |
p8.weight = 120; | |
p8.footSize = 9; | |
females.add(p8); | |
// train with the people | |
OnlineLogisticRegression model = new OnlineLogisticRegression(2,3, new L2(1)); | |
//AdaptiveLogisticRegression model = new AdaptiveLogisticRegression(2,3, new L2(1)); | |
int female = 0; | |
int male = 1; | |
for (int i=0; i<1000; i++) { | |
for(Person p: males) { | |
Vector personVector = new RandomAccessSparseVector(model.numFeatures()); | |
PersonEncoder p1e = new PersonEncoder(); | |
p1e.addToVector(p, personVector); | |
// train males | |
model.train(male, personVector); | |
} | |
} | |
for (int i=0; i<1000; i++) { | |
for(Person p: females) { | |
Vector personVector = new RandomAccessSparseVector(model.numFeatures()); | |
PersonEncoder p1e = new PersonEncoder(); | |
p1e.addToVector(p, personVector); | |
// train males | |
model.train(female, personVector); | |
} | |
} | |
Person dunno = new Person(); | |
dunno.sex = ""; | |
dunno.height = 5; | |
dunno.weight = 110; | |
dunno.footSize = 5; | |
/* | |
dunno.height = 8; | |
dunno.weight = 180; | |
dunno.footSize = 13; | |
*/ | |
PersonEncoder dunnoPersonEncoder = new PersonEncoder(); | |
Vector dunnoPersonVector = new RandomAccessSparseVector(model.numFeatures()); | |
dunnoPersonEncoder.addToVector(dunno, dunnoPersonVector); | |
Vector result = model.classifyFull(dunnoPersonVector); | |
System.out.print(result.toString()); | |
} catch (Throwable e) { | |
e.printStackTrace(); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment