// shogunboard - a.k.a. monitoring Shogun models with TensorBoard
//
// Source: GitHub gist vigsterkr/1e2eb7452dea67bbf683e03ccf67dfc2,
// created June 24, 2017 05:56.
#include <cstdio>
#include <string>
#include <vector>

#include <shogun/base/init.h>
#include <shogun/base/some.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/lib/SGVector.h>
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/machine/gp/SoftMaxLikelihood.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/classifier/GaussianProcessClassification.h>
#include <shogun/mathematics/Math.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/io/CSVFile.h>
#include <shogun/evaluation/MulticlassAccuracy.h>
#include <shogun/lib/WrappedObjectArray.h>
#include <shogun/machine/gp/ConstMean.h>
#include <shogun/machine/gp/MultiLaplaceInferenceMethod.h>

using namespace shogun;
int main(int, char*[]) | |
{ | |
init_shogun_with_defaults(); | |
auto f_feats_train = some<CCSVFile>("../../data/classifier_4class_2d_linear_features_train.dat"); | |
auto f_feats_test = some<CCSVFile>("../../data/classifier_4class_2d_linear_features_test.dat"); | |
auto f_labels_train = some<CCSVFile>("../../data/classifier_4class_2d_linear_labels_train.dat"); | |
auto f_labels_test = some<CCSVFile>("../../data/classifier_4class_2d_linear_labels_test.dat"); | |
CMath::init_random(1); | |
auto features_train = some<CDenseFeatures<float64_t>>(f_feats_train); | |
auto features_test = some<CDenseFeatures<float64_t>>(f_feats_test); | |
auto labels_train = some<CMulticlassLabels>(f_labels_train); | |
auto labels_test = some<CMulticlassLabels>(f_labels_test); | |
auto kernel = some<CGaussianKernel>(2.0); | |
auto mean_function = some<CConstMean>(); | |
auto gauss_likelihood = some<CSoftMaxLikelihood>(); | |
auto inference_method = some<CMultiLaplaceInferenceMethod>(kernel, features_train, mean_function, labels_train, gauss_likelihood); | |
auto gp_classifier = some<CGaussianProcessClassification>(inference_method); | |
// option a) | |
// simply create an interface which takes a TFEventWriter | |
// and would dump the subscribed variables into that | |
// | |
// pro: easy to add this feature to SWIG interface | |
// con: puts all the monitoring/even handling into CMachine | |
auto event_writer = some<TFEventWriter>("example-run-001"); | |
gp_classifier->monitor_parameters(event_writer, {"param1", "param2"}); | |
// option b) | |
// CMachine just provides an interface where one could subscribe | |
// to the parameters observable, for example with a class | |
// that implements ParameterObserver (this could write to file or wherever) | |
// pro: | |
// - easy SWIGability | |
// - the CMachine just holds an observable of parameters, but the logic | |
// of what to do with that is outside of machine | |
// con: | |
// - see option c) | |
auto parameter_observer = some<ParameterObserver>({"param1", "param2"}); | |
gp_classifier->subscribe_parameters(parameter_observer); | |
// option c) | |
// CMachine has an Observable just as in case of option b) that emits | |
// the changes of the machine's parameters | |
// con: if we use directly RxCpp's observer, it's almost impossible to use | |
// that outside of the C++ interface, i.e. expose it to SWIG interfaces | |
// pro: you do what you want with the observer :) | |
// | |
// of course option b and c is not mutually exclusive, we could support both. | |
auto parameter_observable = gp_classifier->get_parameters_observable(); | |
// use parameter_observer just as any Observable in the code. | |
gp_classifier->train(); | |
auto labels_predict = gp_classifier->apply_multiclass(features_test); | |
auto evals = some<CMulticlassAccuracy>(); | |
auto accuracy = evals->evaluate(labels_predict, labels_test); | |
exit_shogun(); | |
return 0; | |
} |
// GitHub gist page footer: "Sign up for free to join this conversation on
// GitHub. Already have an account? Sign in to comment."