Last active
August 29, 2015 14:02
-
-
Save suma/2dfac4356d8bd1bb14ac to your computer and use it in GitHub Desktop.
linear_mixer test for the nearest_neighbor driver (jubatus_core): exercises the linear_mixable get_diff / mix / put_diff sequence.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/jubatus/core/driver/nearest_neighbor_test.cpp b/jubatus/core/driver/nearest_neighbor_test.cpp
index 554ed51..00f00bf 100644
--- a/jubatus/core/driver/nearest_neighbor_test.cpp
+++ b/jubatus/core/driver/nearest_neighbor_test.cpp
@@ -145,10 +145,15 @@ class nearest_neighbor_test
     : public ::testing::TestWithParam<
       shared_ptr<core::nearest_neighbor::nearest_neighbor_base> > {
   protected:
-  void SetUp() {
-    nn_driver_ = shared_ptr<core::driver::nearest_neighbor>(
+  // Returns a new driver wrapping the parameterized model, GetParam().
+  shared_ptr<core::driver::nearest_neighbor> create_driver() const {
+    return shared_ptr<core::driver::nearest_neighbor>(
       new core::driver::nearest_neighbor(GetParam(), make_fv_converter()));
   }
+
+  void SetUp() {
+    nn_driver_ = create_driver();
+  }
   void TearDown() {
     nn_driver_->clear();
   }
@@ -300,6 +305,56 @@ TEST_P(nearest_neighbor_test, small) {
   nn_driver_->neighbor_row_from_data(create_datum_2d(1.f, 1.f), 2);
 }
 
+// Exercises the linear_mixable get_diff / mix / put_diff sequence between
+// two drivers: nn_driver_'s serialized diff is mixed into other's diff,
+// which is then applied to other with put_diff.
+// TODO(review): only checks that the sequence runs; add assertions on
+// other's rows after put_diff.
+TEST_P(nearest_neighbor_test, small_mix) {
+  framework::linear_mixable* nn_mixable =
+      dynamic_cast<framework::linear_mixable*>(nn_driver_->get_mixable());
+  // NOTE(review): create_driver() wraps the same GetParam() model instance
+  // that nn_driver_ wraps, so `other` presumably shares storage with
+  // nn_driver_ rather than being an independent replica -- confirm.
+  shared_ptr<core::driver::nearest_neighbor> other = create_driver();
+  framework::linear_mixable* other_mixable =
+      dynamic_cast<framework::linear_mixable*>(other->get_mixable());
+  ASSERT_TRUE(nn_mixable);
+  ASSERT_TRUE(other_mixable);
+
+  nn_driver_->set_row("a", single_str_datum("x", "hoge"));
+  nn_driver_->set_row("b", single_str_datum("y", "fuga"));
+
+  // Serialize nn_driver_'s diff into `data`.
+  msgpack::sbuffer data;
+  {
+    core::framework::stream_writer<msgpack::sbuffer> st(data);
+    core::framework::jubatus_packer jp(st);
+    core::framework::packer pk(jp);
+    nn_mixable->get_diff(pk);
+  }
+  {
+    // Serialize other's own diff and convert it into a diff_object.
+    msgpack::sbuffer sbuf;
+    core::framework::stream_writer<msgpack::sbuffer> st(sbuf);
+    core::framework::jubatus_packer jp(st);
+    core::framework::packer pk(jp);
+    other_mixable->get_diff(pk);
+
+    msgpack::unpacked msg;
+    msgpack::unpack(&msg, sbuf.data(), sbuf.size());
+    framework::diff_object diff = other_mixable->convert_diff_object(msg.get());
+
+    // Mix nn_driver_'s serialized diff into `diff`, then apply the result.
+    msgpack::unpacked data_msg;
+    msgpack::unpack(&data_msg, data.data(), data.size());
+
+    other_mixable->mix(data_msg.get(), diff);
+    other_mixable->put_diff(diff);
+  }
+}
+
+
 INSTANTIATE_TEST_CASE_P(nearest_neighbor_test_instance,
     nearest_neighbor_test,
     testing::ValuesIn(create_nearest_neighbor_bases()));
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment