Skip to content

Instantly share code, notes, and snippets.

@Drunkar
Created February 28, 2016 05:56
Show Gist options
  • Save Drunkar/c4d3d7328899020f4662 to your computer and use it in GitHub Desktop.
Save Drunkar/c4d3d7328899020f4662 to your computer and use it in GitHub Desktop.
# coding: utf-8
import argparse
import numpy as np
from sklearn import svm
from sklearn.multiclass import OneVsRestClassifier
from sklearn.externals import joblib
# Command-line interface. Each option keeps its original single-dash spelling
# (backward compatible) and also accepts the conventional double-dash form;
# argparse derives the attribute name from the "--" spelling, which matches
# the original attribute names (pretrained_model, train, test).
parser = argparse.ArgumentParser(description="svm")
parser.add_argument(
    "-pretrained_model", "--pretrained_model",
    default=None,
    help="path to a model previously saved with joblib (e.g. svm.pkl)",
)
parser.add_argument("-train", "--train", default=None, help="data for train")
# NOTE(review): -test is parsed but main() never reads it yet.
parser.add_argument("-test", "--test", default="test.csv", help="data for test")
def _parse_training_rows(lines):
    """Parse training CSV lines into a float feature matrix and a label list.

    Each data row is ``label,feature_1,...,feature_k``. The first line is a
    header and is skipped, as are blank lines.

    Args:
        lines: An iterable of text lines (e.g. an open file).

    Returns:
        A tuple ``(X, y)`` where ``X`` is a 2-D float ``numpy.ndarray`` with
        one row per data line and ``y`` is the list of label strings.

    Raises:
        ValueError: If there are no data rows, or a feature is not numeric.
    """
    features = []
    labels = []
    for index, line in enumerate(lines):
        text = line.rstrip()
        # Line 0 is the header; blank lines (e.g. a trailing newline) would
        # otherwise become empty feature rows and break fit().
        if index == 0 or not text:
            continue
        row = text.split(",")
        labels.append(row[0])
        # Convert explicitly: passing strings to sklearn relies on implicit
        # conversion that newer releases reject.
        features.append([float(value) for value in row[1:]])
    if not labels:
        raise ValueError("training data contains no rows after the header")
    return np.asarray(features, dtype=float), labels


def main(args):
    """Build or load a one-vs-rest linear SVM, optionally train it, and save it.

    Args:
        args: Parsed command-line namespace. Reads ``args.pretrained_model``
            (path to a joblib file, or None) and ``args.train`` (path to a
            training CSV, or None).

    When ``args.train`` is given, the classifier is fitted on that CSV and
    written to ``svm.pkl`` in the current working directory. Without
    ``args.train`` nothing is fitted or saved.
    """
    if args.pretrained_model is None:
        # "auto" was deprecated in scikit-learn 0.17 and rejected from 0.19;
        # "balanced" is its documented replacement.
        estimator = svm.SVC(
            kernel="linear", probability=True, class_weight="balanced"
        )
        classifier = OneVsRestClassifier(estimator)
    else:
        # SECURITY: joblib.load unpickles arbitrary objects; only load model
        # files from a trusted source.
        loaded = joblib.load(args.pretrained_model)
        # Models saved by this script are already OneVsRestClassifier
        # instances; wrapping them again would nest OvR inside OvR.
        if isinstance(loaded, OneVsRestClassifier):
            classifier = loaded
        else:
            classifier = OneVsRestClassifier(loaded)

    # NOTE(review): args.test is parsed by the CLI but never used here; there
    # is no prediction/evaluation step yet.

    if args.train is not None:
        with open(args.train) as fi:
            train_x, train_y = _parse_training_rows(fi)
        classifier.fit(train_x, train_y)
        # Persist the fitted model so it can be reloaded via -pretrained_model.
        joblib.dump(classifier, "svm.pkl")
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run training/saving.
    # (The stray "P" after main(args) in the original was a syntax error.)
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment