Created
August 7, 2012 14:18
-
-
Save delta2323/3285737 to your computer and use it in GitHub Desktop.
Test Code of Classifier Client (Python)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import unittest | |
from classifier.client import classifier | |
from classifier.types import * | |
# Build directory that contains bin/jubaclassifier: fork_process() launches
# BUILDDIR + "/bin/jubaclassifier". Fill this in for your own environment
# before running the tests; left empty, the path resolves to
# "/bin/jubaclassifier".
BUILDDIR=""
def fork_process(port):
    """Start a jubaclassifier server on *port* in a forked child process.

    The child execs ``BUILDDIR + "/bin/jubaclassifier"`` with the arguments
    ``-p <port> -c 100``. The parent sleeps one second (presumably to give
    the server time to start listening) and returns the child's pid, which
    the caller later passes to kill_process().

    Exits the calling process with status 1 if fork() fails. If exec fails
    in the child, the child prints the error to stderr and exits with
    status 1 without ever returning to the caller.
    """
    import os
    import sys
    import time
    try:
        pid = os.fork()
    except OSError as error:
        # os.fork() signals failure by raising; it never returns a negative pid.
        print('Unable to fork. Error: %d (%s)' % (error.errno, error.strerror))
        sys.exit(1)
    if pid == 0:
        path = BUILDDIR + "/bin/jubaclassifier"
        try:
            os.execv(path, [path, "-p", str(port), "-c", "100"])
        except OSError as error:
            sys.stderr.write('Unable to exec %s. Error: %d (%s)\n'
                             % (path, error.errno, error.strerror))
            sys.stderr.flush()
        finally:
            # Only reached if execv failed. os._exit() ends the child at once,
            # so it never unwinds back into the caller (the test runner) and
            # keeps running tests alongside the parent.
            os._exit(1)
    time.sleep(1.0)
    return pid
def kill_process(pid):
    """Terminate a process started by fork_process() and reap it.

    Sends SIGTERM to *pid*, then waits for it to exit so that no zombie is
    left behind. If the signal cannot be delivered (for example, the process
    no longer exists), prints a 'kill error' message and returns without
    waiting.
    """
    import os
    import signal
    try:
        os.kill(pid, signal.SIGTERM)
    except OSError as error:
        # os.kill() returns None and reports failure only by raising OSError.
        print('kill error: %d (%s)' % (error.errno, error.strerror))
        return
    os.waitpid(pid, 0)
class ClassifierTest(unittest.TestCase):
    """Integration tests for the classifier client against a live server.

    Each test forks a fresh jubaclassifier process with fork_process(),
    configures it under the name "name", and kills it again in tearDown.
    """

    # Port the forked server listens on and the client connects to.
    PORT = 5000
    # Method sent in config_data and expected back from get_config.
    METHOD = "AROW"

    def setUp(self):
        # Converter configuration (JSON text) sent to set_config;
        # test_get_config checks the server returns it unchanged.
        self.converter = '''
{
  "string_filter_types":{},
  "string_filter_rules":[],
  "num_filter_types":{},
  "num_filter_rules":[],
  "string_types":{},
  "string_rules":[
    {
      "key":"*", "type":"space",
      "sample_weight":"bin", "global_weight":"bin"
    }
  ],
  "num_types":{},
  "num_rules":[
    {"key":"*", "type":"num"}
  ]
}
'''
        self.pid = fork_process(self.PORT)
        try:
            self.client = classifier("localhost", self.PORT)
            self.client.set_config("name",
                                   config_data(self.METHOD, self.converter))
        except BaseException:
            # unittest skips tearDown when setUp raises, so stop the forked
            # server here rather than leaking the process (and its port).
            kill_process(self.pid)
            raise

    def tearDown(self):
        kill_process(self.pid)

    def test_get_config(self):
        config = self.client.get_config("name")
        self.assertEqual(config.method, self.METHOD)
        self.assertEqual(config.config, self.converter)

    def test_train_and_classify(self):
        string_values = [["key/str" + str(i), "val/str" + str(i)]
                         for i in range(10)]
        num_values = [["key/str" + str(i), float(i)] for i in range(10)]
        train_data = [("label", datum(string_values, num_values))]
        for _ in range(100):
            # assertEqual (==), not assertIs: any value equal to True passes.
            self.assertEqual(True, self.client.train("name", train_data))

        result = self.client.classify("name",
                                      [datum(string_values, num_values)])
        # One entry per classified datum; the single trained label must be
        # the only candidate, with probability ~1.0.
        self.assertEqual(1, len(result))
        self.assertEqual(1, len(result[0]))
        self.assertEqual("label", result[0][0].label)
        self.assertAlmostEqual(1.0, result[0][0].prob)

    def test_save_and_load(self):
        import datetime
        # Timestamped model id, presumably so repeated runs don't reuse one.
        t = datetime.datetime.today().strftime("%Y%m%d-%H%M%S")
        model_name = "classifier.model." + t
        self.assertEqual(True, self.client.save("name", model_name))
        self.assertEqual(True, self.client.load("name", model_name))

    def test_get_status(self):
        # Smoke test: only checks that the get_status RPC completes.
        self.client.get_status("name")
if __name__ == '__main__':
    import sys
    test_suite = unittest.TestLoader().loadTestsFromTestCase(ClassifierTest)
    result = unittest.TextTestRunner().run(test_suite)
    # Propagate failures through the exit status so callers (CI, make) see them.
    sys.exit(0 if result.wasSuccessful() else 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment