Last active
January 1, 2016 10:19
-
-
Save treper/f407d90c621fa790beb6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
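# Convert a Mahout data set (CSV file + attribute descriptor) into numpy arrays
# for scikit-learn, then cross-validate several tree classifiers on it.
# Example Mahout descriptor: -d N 3 C 2 N C 4 N C 8 N 2 C 19 N L
# scikit-learn has no descriptor; the columns are selected by index (usecols).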
import sys | |
import argparse | |
import numpy | |
from sklearn.cross_validation import cross_val_score | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.ensemble import ExtraTreesClassifier | |
from sklearn.tree import DecisionTreeClassifier | |
from scipy import sparse, array | |
def MyOneHotEncoder(data, keymap=None):
    """One-hot encode every column of a categorical data matrix.

    Each input column with ``k`` known categories is expanded into ``k``
    0/1 indicator columns. The per-column blocks are concatenated left to
    right in input-column order.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_columns) or (n_samples,)
        Categorical values. A 1-D input is treated as a single column
        (``numpy.genfromtxt`` returns 1-D arrays when one column is read).
    keymap : list of dict, optional
        One ``{category: index}`` dict per input column. When supplied it is
        used as-is, and any value missing from it leaves that row's block all
        zero. When omitted, a keymap is built from ``data`` with the
        categories sorted, so the mapping is reproducible between runs.

    Returns
    -------
    outdat : numpy.ndarray of float, shape (n_samples, total_categories)
        Dense 0/1 indicator matrix (not a scipy sparse matrix).
    keymap : list of dict
        The keymap that was used.
    """
    data = numpy.asarray(data)
    if data.ndim == 1:
        # A lone column arrives 1-D; iterating data.T would then walk the
        # individual values instead of the column.
        data = data.reshape(-1, 1)
    if keymap is None:
        # sorted() rather than raw set order: string set order can differ
        # between interpreter runs (hash randomisation, default in Python 3),
        # which would silently permute the output columns.
        keymap = [{key: i for i, key in enumerate(sorted(set(col)))}
                  for col in data.T]
    total_pts = data.shape[0]
    outdat = []
    for i, col in enumerate(data.T):
        km = keymap[i]
        spmat = numpy.zeros([total_pts, len(km)])
        for j, val in enumerate(col):
            idx = km.get(val)
            if idx is not None:  # unknown categories stay all-zero
                spmat[j, idx] = 1
        outdat.append(spmat)
    if not outdat:
        # numpy.hstack([]) raises; zero input columns -> zero output columns.
        return numpy.zeros([total_pts, 0]), keymap
    return numpy.hstack(outdat), keymap
def LabelEncoder(data, keymap=None, unknown_value=0):
    """Integer-encode every column of a categorical data matrix.

    Unlike one-hot encoding, this keeps one output column per input column
    and replaces each value with its category index.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_columns) or (n_samples,)
        Categorical values. A 1-D input is treated as a single column.
    keymap : list of dict, optional
        One ``{category: index}`` dict per input column. When supplied it is
        used as-is. When omitted, it is built from ``data`` with the
        categories sorted, so codes are reproducible between runs.
    unknown_value : number, optional
        Code written for values missing from a supplied ``keymap``. Defaults
        to 0 for backward compatibility, which makes unknown values
        indistinguishable from the category mapped to index 0. Pass e.g. -1
        to tell them apart.

    Returns
    -------
    outdat : numpy.ndarray of float, shape (n_samples, n_columns)
        Category codes, one column per input column.
    keymap : list of dict
        The keymap that was used.
    """
    data = numpy.asarray(data)
    if data.ndim == 1:
        # A lone column arrives 1-D; iterating data.T would then walk the
        # individual values instead of the column.
        data = data.reshape(-1, 1)
    if keymap is None:
        # sorted() rather than raw set order: otherwise which class receives
        # code 1 can change between runs (e.g. binary F1 scoring treats code 1
        # as the positive class).
        keymap = [{key: i for i, key in enumerate(sorted(set(col)))}
                  for col in data.T]
    total_pts = data.shape[0]
    outdat = []
    for i, col in enumerate(data.T):
        km = keymap[i]
        codes = numpy.zeros([total_pts, 1])
        for j, val in enumerate(col):
            codes[j, 0] = km.get(val, unknown_value)
        outdat.append(codes)
    if not outdat:
        # numpy.hstack([]) raises; zero input columns -> zero output columns.
        return numpy.zeros([total_pts, 0]), keymap
    return numpy.hstack(outdat), keymap
def _parse_description(description):
    """Expand a Mahout descriptor string into one type token per CSV column.

    ``"N 3 C L"`` -> ``['N', 'C', 'C', 'C', 'L']``: an integer token repeats
    the following type token that many times. Every type token, including
    ``I`` (ignored), yields one entry, so list positions match CSV column
    indices.

    Raises ValueError if the descriptor ends with a dangling repeat count.
    """
    tokens = description.split()  # split() tolerates repeated whitespace
    column_types = []
    pos = 0
    while pos < len(tokens):
        try:
            count = int(tokens[pos])
        except ValueError:
            # Plain type token (N, C, L, I, ...): exactly one column.
            column_types.append(tokens[pos])
            pos += 1
            continue
        if pos + 1 >= len(tokens):
            raise ValueError('repeat count %r is not followed by a column type'
                             % tokens[pos])
        column_types.extend([tokens[pos + 1]] * count)
        pos += 2
    return column_types


def main():
    """Cross-validate tree classifiers on a Mahout-format CSV data set.

    Column roles come from the ``-d`` descriptor: ``N`` numerical feature,
    ``C`` categorical feature, ``L`` label; any other type token (e.g. ``I``)
    is skipped but still occupies a column. An integer ``k`` before a type
    repeats it ``k`` times.

    Numerical columns are used as-is, categorical columns are one-hot encoded
    with ``MyOneHotEncoder`` and the label column is integer-encoded with
    ``LabelEncoder``. Extra-trees, decision-tree and random-forest
    classifiers are then scored with 5-fold cross-validated F1, and the
    scores are printed.
    """
    parser = argparse.ArgumentParser(
        description='Cross-validate scikit-learn tree classifiers on a '
                    'Mahout-format CSV data set.')
    parser.add_argument('-i', dest='input', type=str, required=True,
                        help='the mahout input data file path')
    parser.add_argument('-d', '--description', dest='description', type=str,
                        required=True,
                        help='the mahout input data format description')
    parser.add_argument('-n', dest='num_trees', type=int, default=10,
                        help='the number of trees of random forest')
    parser.add_argument('-c', dest='use_cores', type=int, default=1,
                        help='use how many cpu cores')
    args = parser.parse_args()

    try:
        parsed_categories = _parse_description(args.description)
    except ValueError as err:
        parser.error(str(err))
    print('parsed categories: %d %s'
          % (len(parsed_categories), parsed_categories))

    data_usecols = [i for i, t in enumerate(parsed_categories) if t == 'N']
    label_usecols = [i for i, t in enumerate(parsed_categories) if t == 'L']
    category_usecols = [i for i, t in enumerate(parsed_categories) if t == 'C']
    print('data_usecols: %s' % data_usecols)
    print('label_usecols: %s' % label_usecols)
    print('category_usecols: %s' % category_usecols)
    if len(label_usecols) != 1:
        parser.error('the description must contain exactly one label column '
                     '(L), got %d' % len(label_usecols))
    if not (data_usecols or category_usecols):
        parser.error('the description must contain at least one N or C column')

    # Filenames are passed straight to numpy, which opens and closes them.
    feature_blocks = []
    if data_usecols:
        print('loading numerical features')
        # ndmin=2 keeps a single numerical column 2-D so it can be stacked.
        feature_blocks.append(numpy.loadtxt(
            args.input, delimiter=',', usecols=data_usecols, ndmin=2))
    if category_usecols:
        print('loading categorical features')
        cat_features = numpy.genfromtxt(
            args.input, delimiter=',', usecols=category_usecols, dtype='str')
        # genfromtxt drops to 1-D for a single column (or a single row).
        cat_features = cat_features.reshape(-1, len(category_usecols))
        print('converting categorical features to numerical features')
        Xc, _ = MyOneHotEncoder(cat_features)
        feature_blocks.append(Xc)

    print('generate numericat labels')
    label_features = numpy.loadtxt(args.input, delimiter=',',
                                   usecols=label_usecols, dtype='str', ndmin=2)
    labels, _ = LabelEncoder(label_features)
    yf = numpy.ravel(labels)
    # Numerical features first, then the one-hot categorical block.
    Xf = numpy.hstack(feature_blocks)
    print('%s %s' % (Xf.shape, yf.shape))

    # min_samples_split=2: scikit-learn >= 0.18 rejects 1, and older versions
    # treated 1 and 2 identically (a split always needs two samples).
    classifiers = [
        ('ExtraTreesClassifier',
         ExtraTreesClassifier(n_estimators=args.num_trees, max_depth=None,
                              min_samples_split=2, random_state=0)),
        ('DecisionTreeClassifier',
         DecisionTreeClassifier(max_depth=None, min_samples_split=2,
                                random_state=0)),
        ('RandomForestClassifier',
         RandomForestClassifier(n_estimators=args.num_trees, max_depth=None,
                                min_samples_split=2, random_state=0)),
    ]
    for name, clf in classifiers:
        print('training %s' % name)
        scores = cross_val_score(estimator=clf, X=Xf, y=yf, cv=5,
                                 n_jobs=args.use_cores, scoring='f1')
        print('average performance: %s %s' % (scores.mean(), scores))
if __name__ == '__main__':
    # Script entry point: run the evaluation only when this file is executed
    # directly, not when it is imported for its encoder functions.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment