Last active
January 3, 2016 07:49
-
-
Save irwenqiang/8432129 to your computer and use it in GitHub Desktop.
als source code reading
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* iver: | |
 * In ALS, we factorize the matrix into two lower-rank matrices such that A ~= U*V'.
 * It is not a symmetric matrix factorization of the form A ~= Q' Q, as you wrote.
* | |
* The non zero entries of the matrix A are the edges between user and item nodes. | |
* The edge direction is always between user -> item. | |
* | |
* The user latent vectors are stored in the user vertices (vertex.num_out_edges() > 0). | |
* The item latent vectors are stored in the item vertices (vertex.num_out_edges() == 0) | |
* | |
*/ | |
/** | |
* Copyright (c) 2009 Carnegie Mellon University. | |
* All rights reserved. | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, | |
* software distributed under the License is distributed on an "AS | |
* IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | |
* express or implied. See the License for the specific language | |
* governing permissions and limitations under the License. | |
* | |
* For more about this software visit: | |
* | |
* http://www.graphlab.ml.cmu.edu | |
* | |
*/ | |
/** | |
* \file | |
* | |
* \brief The main file for the ALS matrix factorization algorithm. | |
* | |
* This file contains the main body of the ALS matrix factorization | |
* algorithm. | |
*/ | |
#include <Eigen/Dense> | |
#include <boost/config/warning_disable.hpp> | |
#include <boost/spirit/include/qi.hpp> | |
#include <boost/spirit/include/phoenix_core.hpp> | |
#include <boost/spirit/include/phoenix_operator.hpp> | |
#include <boost/spirit/include/phoenix_stl.hpp> | |
// This file defines the serialization code for the eigen types. | |
#include "eigen_serialization.hpp" | |
#include <graphlab.hpp> | |
#include <graphlab/util/stl_util.hpp> | |
#include "stats.hpp" | |
#include <graphlab/macros_def.hpp> | |
// Offset added to an item id before it is negated into the item id space.
// Adding 2 keeps the mapped ids away from -0 and -1, which are not allowed.
const int SAFE_NEG_OFFSET = 2;
/** | |
* \brief We use the eigen library's vector type to represent | |
* mathematical vectors. | |
*/ | |
typedef Eigen::VectorXd vec_type; | |
/** | |
* \brief We use the eigen library's matrix type to represent | |
* matrices. | |
*/ | |
typedef Eigen::MatrixXd mat_type; | |
/** | |
* \ingroup toolkit_matrix_factorization | |
* | |
* \brief the vertex data type which contains the latent factor. | |
* | |
* Each row and each column in the matrix corresponds to a different | |
* vertex in the ALS graph. Associated with each vertex is a factor | |
* (vector) of latent parameters that represent that vertex. The goal | |
* of the ALS algorithm is to find the values for these latent | |
* parameters such that the non-zero entries in the matrix can be | |
* predicted by taking the dot product of the row and column factors. | |
*/ | |
struct vertex_data { | |
/** | |
* \brief A shared "constant" that specifies the number of latent | |
* values to use. | |
*/ | |
static size_t NLATENT; | |
/** \brief The number of times this vertex has been updated. */ | |
uint32_t nupdates; | |
/** \brief The most recent L1 change in the factor value */ | |
float residual; //! how much the latent value has changed | |
/** \brief The latent factor for this vertex */ | |
vec_type factor; | |
/** | |
* \brief Simple default constructor which randomizes the vertex | |
* data | |
*/ | |
vertex_data() : nupdates(0), residual(1) { randomize(); } | |
/** \brief Randomizes the latent factor */ | |
void randomize() { factor.resize(NLATENT); factor.setRandom(); } | |
/** \brief Save the vertex data to a binary archive */ | |
void save(graphlab::oarchive& arc) const { | |
arc << nupdates << residual << factor; | |
} | |
/** \brief Load the vertex data from a binary archive */ | |
void load(graphlab::iarchive& arc) { | |
arc >> nupdates >> residual >> factor; | |
} | |
}; // end of vertex data | |
size_t vertex_data::NLATENT = 20; | |
/** | |
* \brief The edge data stores the entry in the matrix. | |
* | |
* In addition the edge data also stores the most recent error estimate. | |
*/ | |
struct edge_data : public graphlab::IS_POD_TYPE { | |
/** | |
* \brief The type of data on the edge; | |
* | |
* \li *Train:* the observed value is correct and used in training | |
* \li *Validate:* the observed value is correct but not used in training | |
* \li *Predict:* The observed value is not correct and should not be | |
* used in training. | |
*/ | |
enum data_role_type { TRAIN, VALIDATE, PREDICT }; | |
/** \brief the observed value for the edge */ | |
float obs; | |
/** \brief The train/validation/test designation of the edge */ | |
data_role_type role; | |
/** \brief basic initialization */ | |
edge_data(float obs = 0, data_role_type role = TRAIN) : | |
obs(obs), role(role) { } | |
}; // end of edge data | |
/** | |
* \brief The graph type is defined in terms of the vertex and edge | |
* data. | |
*/ | |
typedef graphlab::distributed_graph<vertex_data, edge_data> graph_type; | |
#include "implicit.hpp" | |
/**
 * \brief Maps one edge to a stats_info record.
 *
 * A training edge contributes one training edge and a validation edge
 * one validation edge; prediction edges contribute to neither count.
 * The record also carries the source (user) id and the item id
 * recovered by undoing the negative id mapping.  How records are
 * combined is defined by stats_info (stats.hpp, not shown here).
 */
stats_info count_edges(const graph_type::edge_type & edge) {
  stats_info ret;
  switch (edge.data().role) {
    case edge_data::TRAIN:
      ret.training_edges = 1;
      break;
    case edge_data::VALIDATE:
      ret.validation_edges = 1;
      break;
    default:
      break;
  }
  ret.max_user = static_cast<size_t>(edge.source().id());
  ret.max_item = (-edge.target().id() - SAFE_NEG_OFFSET);
  return ret;
}
/** | |
* \brief Given a vertex and an edge return the other vertex in the | |
* edge. | |
*/ | |
inline graph_type::vertex_type | |
get_other_vertex(graph_type::edge_type& edge, | |
const graph_type::vertex_type& vertex) { | |
return vertex.id() == edge.source().id()? edge.target() : edge.source(); | |
}; // end of get_other_vertex | |
/** | |
* \brief The gather type used to construct XtX and Xty needed for the ALS | |
* update | |
* | |
* To compute the ALS update we need to compute the sum of | |
* \code | |
* sum: XtX = nbr.factor.transpose() * nbr.factor | |
* sum: Xy = nbr.factor * edge.obs | |
* \endcode | |
* For each of the neighbors of a vertex. | |
* | |
* To do this in the Gather-Apply-Scatter model the gather function | |
* computes and returns a pair consisting of XtX and Xy which are then | |
* added. The gather type represents that tuple and provides the | |
* necessary gather_type::operator+= operation. | |
* | |
*/ | |
class gather_type { | |
public: | |
/** | |
* \brief Stores the current sum of nbr.factor.transpose() * | |
* nbr.factor | |
*/ | |
mat_type XtX; | |
/** | |
* \brief Stores the current sum of nbr.factor * edge.obs | |
*/ | |
vec_type Xy; | |
/** \brief basic default constructor */ | |
gather_type() { } | |
/** | |
* \brief This constructor computes XtX and Xy and stores the result | |
* in XtX and Xy | |
*/ | |
gather_type(const vec_type& X, const double y) : | |
XtX(X.size(), X.size()), Xy(X.size()) { | |
XtX.triangularView<Eigen::Upper>() = X * X.transpose(); | |
Xy = X * y; | |
} // end of constructor for gather type | |
/** \brief Save the values to a binary archive */ | |
void save(graphlab::oarchive& arc) const { arc << XtX << Xy; } | |
/** \brief Read the values from a binary archive */ | |
void load(graphlab::iarchive& arc) { arc >> XtX >> Xy; } | |
/** | |
* \brief Computes XtX += other.XtX and Xy += other.Xy updating this | |
* tuples value | |
*/ | |
gather_type& operator+=(const gather_type& other) { | |
if(other.Xy.size() == 0) { | |
ASSERT_EQ(other.XtX.rows(), 0); | |
ASSERT_EQ(other.XtX.cols(), 0); | |
} else { | |
if(Xy.size() == 0) { | |
ASSERT_EQ(XtX.rows(), 0); | |
ASSERT_EQ(XtX.cols(), 0); | |
XtX = other.XtX; Xy = other.Xy; | |
} else { | |
XtX.triangularView<Eigen::Upper>() += other.XtX; | |
Xy += other.Xy; | |
} | |
} | |
return *this; | |
} // end of operator+= | |
}; // end of gather type | |
/** | |
* \brief ALS vertex program implements the alternating least squares | |
* algorithm in the Gather-Apply-Scatter abstraction. | |
* | |
* The ALS update treats adjacent vertices (rows or columns) as "X" | |
* (independent) values and the edges (matrix entries) as observed "y" | |
* (dependent) values and then updates the current vertex value as a | |
* weight "w" such that: | |
* | |
* y = X * w + noise | |
* | |
* This is accomplished using the following equation: | |
* | |
* w = inv(X' * X) * (X' * y) | |
* | |
* We implement this in the Gather-Apply-Scatter model by: | |
* | |
* 1) Gather: returns the tuple (X' * X, X' * y) | |
* Sum: (aX' * aX, aX * ay) + (bX' * bX, bX * by) = | |
* (aX' * aX + bX' * bX, aX * ay + bX * by) | |
* | |
* 2) Apply: Solves inv(X' * X) * (X' * y) | |
* | |
* 3) Scatter: schedules the update of adjacent vertices if this | |
* vertex has changed sufficiently and the edge is not well | |
* predicted. | |
* | |
* | |
*/ | |
class als_vertex_program : | |
public graphlab::ivertex_program<graph_type, gather_type, | |
graphlab::messages::sum_priority>, | |
public graphlab::IS_POD_TYPE { | |
public: | |
/** The convergence tolerance */ | |
static double TOLERANCE; | |
static double LAMBDA; | |
static size_t MAX_UPDATES; | |
static double MAXVAL; | |
static double MINVAL; | |
static int REGNORMAL; //regularization type | |
/** The set of edges to gather along */ | |
edge_dir_type gather_edges(icontext_type& context, | |
const vertex_type& vertex) const { | |
return graphlab::ALL_EDGES; | |
}; // end of gather_edges | |
/** The gather function computes XtX and Xy */ | |
gather_type gather(icontext_type& context, const vertex_type& vertex, | |
edge_type& edge) const { | |
if(edge.data().role == edge_data::TRAIN) { | |
const vertex_type other_vertex = get_other_vertex(edge, vertex); | |
return gather_type(other_vertex.data().factor, edge.data().obs); | |
} else return gather_type(); | |
} // end of gather function | |
/** apply collects the sum of XtX and Xy */ | |
void apply(icontext_type& context, vertex_type& vertex, | |
const gather_type& sum) { | |
// Get and reset the vertex data | |
vertex_data& vdata = vertex.data(); | |
// Determine the number of neighbors. Each vertex has only in or | |
// out edges depending on which side of the graph it is located | |
if(sum.Xy.size() == 0) { vdata.residual = 0; ++vdata.nupdates; return; } | |
mat_type XtX = sum.XtX; | |
vec_type Xy = sum.Xy; | |
// Add regularization | |
double regularization = LAMBDA; | |
if (REGNORMAL) | |
regularization = LAMBDA*vertex.num_out_edges(); | |
for(int i = 0; i < XtX.rows(); ++i) | |
XtX(i,i) += regularization; | |
// Solve the least squares problem using eigen ---------------------------- | |
const vec_type old_factor = vdata.factor; | |
vdata.factor = XtX.selfadjointView<Eigen::Upper>().ldlt().solve(Xy); | |
// Compute the residual change in the factor factor ----------------------- | |
vdata.residual = (vdata.factor - old_factor).cwiseAbs().sum() / XtX.rows(); | |
++vdata.nupdates; | |
} // end of apply | |
/** The edges to scatter along */ | |
edge_dir_type scatter_edges(icontext_type& context, | |
const vertex_type& vertex) const { | |
return graphlab::ALL_EDGES; | |
}; // end of scatter edges | |
/** Scatter reschedules neighbors */ | |
void scatter(icontext_type& context, const vertex_type& vertex, | |
edge_type& edge) const { | |
edge_data& edata = edge.data(); | |
if(edata.role == edge_data::TRAIN) { | |
const vertex_type other_vertex = get_other_vertex(edge, vertex); | |
const vertex_data& vdata = vertex.data(); | |
const vertex_data& other_vdata = other_vertex.data(); | |
const double pred = vdata.factor.dot(other_vdata.factor); | |
const float error = std::fabs(edata.obs - pred); | |
const double priority = (error * vdata.residual); | |
// Reschedule neighbors ------------------------------------------------ | |
if( priority > TOLERANCE && other_vdata.nupdates < MAX_UPDATES) | |
/** | |
* iver: the alternative optimization for U with V fixed, and vice-versa. | |
*/ | |
context.signal(other_vertex, priority); | |
} | |
} // end of scatter function | |
/** | |
* \brief Signal all vertices on one side of the bipartite graph | |
*/ | |
static graphlab::empty signal_left(icontext_type& context, | |
const vertex_type& vertex) { | |
if(vertex.num_out_edges() > 0) context.signal(vertex); | |
return graphlab::empty(); | |
} // end of signal_left | |
}; // end of als vertex program | |
/** | |
* \brief The graph loader function is a line parser used for | |
* distributed graph construction. | |
*/ | |
inline bool graph_loader(graph_type& graph, | |
const std::string& filename, | |
const std::string& line) { | |
ASSERT_FALSE(line.empty()); | |
namespace qi = boost::spirit::qi; | |
namespace ascii = boost::spirit::ascii; | |
namespace phoenix = boost::phoenix; | |
// Determine the role of the data | |
edge_data::data_role_type role = edge_data::TRAIN; | |
if(boost::ends_with(filename,".validate")) role = edge_data::VALIDATE; | |
else if(boost::ends_with(filename, ".predict")) role = edge_data::PREDICT; | |
// Parse the line | |
graph_type::vertex_id_type source_id(-1), target_id(-1); | |
float obs(0); | |
const bool success = qi::phrase_parse | |
(line.begin(), line.end(), | |
// Begin grammar | |
( | |
qi::ulong_[phoenix::ref(source_id) = qi::_1] >> -qi::char_(',') >> | |
qi::ulong_[phoenix::ref(target_id) = qi::_1] >> | |
-(-qi::char_(',') >> qi::float_[phoenix::ref(obs) = qi::_1]) | |
) | |
, | |
// End grammar | |
ascii::space); | |
if(!success) return false; | |
if(role == edge_data::TRAIN || role == edge_data::VALIDATE){ | |
if (obs < als_vertex_program::MINVAL || obs > als_vertex_program::MAXVAL) | |
logstream(LOG_FATAL)<<"Rating values should be between " << als_vertex_program::MINVAL << " and " << als_vertex_program::MAXVAL << ". Got value: " << obs << " [ user: " << source_id << " to item: " <<target_id << " ] " << std::endl; | |
} | |
// map target id into a separate number space | |
target_id = -(graphlab::vertex_id_type(target_id + SAFE_NEG_OFFSET)); | |
// Create an edge and add it to the graph | |
graph.add_edge(source_id, target_id, edge_data(obs, role)); | |
return true; // successful load | |
} // end of graph_loader | |
/** | |
* \brief Given an edge compute the error associated with that edge | |
*/ | |
double extract_l2_error(const graph_type::edge_type & edge) { | |
double pred = | |
edge.source().data().factor.dot(edge.target().data().factor); | |
pred = std::min(als_vertex_program::MAXVAL, pred); | |
pred = std::max(als_vertex_program::MINVAL, pred); | |
return (edge.data().obs - pred) * (edge.data().obs - pred); | |
} // end of extract_l2_error | |
// Default values of the ALS parameters; main() exposes each of them as
// a command line option.
double als_vertex_program::TOLERANCE   = 1e-3;
double als_vertex_program::LAMBDA      = 0.01;
// -1 converted to size_t is the largest representable value, i.e. no limit.
size_t als_vertex_program::MAX_UPDATES = static_cast<size_t>(-1);
double als_vertex_program::MAXVAL      = 1e+100;
double als_vertex_program::MINVAL      = -1e+100;
int    als_vertex_program::REGNORMAL   = 1;
/** | |
* \brief The error aggregator is used to accumulate the overal | |
* prediction error. | |
* | |
* The error aggregator is itself a "reduction type" and contains the | |
* two static methods "map" and "finalize" which operate on | |
* error_aggregators and are used by the engine.add_edge_aggregator | |
* api. | |
*/ | |
struct error_aggregator : public graphlab::IS_POD_TYPE { | |
typedef als_vertex_program::icontext_type icontext_type; | |
typedef graph_type::edge_type edge_type; | |
double train_error, validation_error; | |
error_aggregator() : | |
train_error(0), validation_error(0) { } | |
error_aggregator& operator+=(const error_aggregator& other) { | |
train_error += other.train_error; | |
validation_error += other.validation_error; | |
return *this; | |
} | |
static error_aggregator map(icontext_type& context, const graph_type::edge_type& edge) { | |
error_aggregator agg; | |
if(edge.data().role == edge_data::TRAIN) { | |
agg.train_error = extract_l2_error(edge); | |
} else if(edge.data().role == edge_data::VALIDATE) { | |
agg.validation_error = extract_l2_error(edge); | |
} | |
return agg; | |
} | |
static void finalize(icontext_type& context, const error_aggregator& agg) { | |
const double train_error = std::sqrt(agg.train_error / info.training_edges); | |
context.cout() << "Time in seconds: " << context.elapsed_seconds() << "\tiTraining RMSE: " << train_error; | |
if(info.validation_edges > 0) { | |
const double validation_error = | |
std::sqrt(agg.validation_error / info.validation_edges); | |
context.cout() << "\tValidation RMSE: " << validation_error; | |
} | |
context.cout() << std::endl; | |
} | |
}; // end of error aggregator | |
/** | |
* \brief The prediction saver is used by the graph.save routine to | |
* output the final predictions back to the filesystem. | |
*/ | |
struct prediction_saver { | |
typedef graph_type::vertex_type vertex_type; | |
typedef graph_type::edge_type edge_type; | |
std::string save_vertex(const vertex_type& vertex) const { | |
return ""; //nop | |
} | |
std::string save_edge(const edge_type& edge) const { | |
if(edge.data().role == edge_data::PREDICT) { | |
std::stringstream strm; | |
const double prediction = | |
edge.source().data().factor.dot(edge.target().data().factor); | |
strm << edge.source().id() << '\t'; | |
strm << (-edge.target().id() - SAFE_NEG_OFFSET) << '\t'; | |
strm << prediction << '\n'; | |
return strm.str(); | |
} else return ""; | |
} | |
}; // end of prediction_saver | |
struct linear_model_saver_U { | |
typedef graph_type::vertex_type vertex_type; | |
typedef graph_type::edge_type edge_type; | |
/* save the linear model, using the format: | |
nodeid factor1 factor2 ... factorNLATENT \n | |
*/ | |
std::string save_vertex(const vertex_type& vertex) const { | |
if (vertex.num_out_edges() > 0){ | |
std::string ret = boost::lexical_cast<std::string>(vertex.id()) + " "; | |
for (uint i=0; i< vertex_data::NLATENT; i++) | |
ret += boost::lexical_cast<std::string>(vertex.data().factor[i]) + " "; | |
ret += "\n"; | |
return ret; | |
} | |
else return ""; | |
} | |
std::string save_edge(const edge_type& edge) const { | |
return ""; | |
} | |
}; | |
struct linear_model_saver_V { | |
typedef graph_type::vertex_type vertex_type; | |
typedef graph_type::edge_type edge_type; | |
/* save the linear model, using the format: | |
nodeid factor1 factor2 ... factorNLATENT \n | |
*/ | |
std::string save_vertex(const vertex_type& vertex) const { | |
if (vertex.num_out_edges() == 0){ | |
std::string ret = boost::lexical_cast<std::string>(-vertex.id()-SAFE_NEG_OFFSET) + " "; | |
for (uint i=0; i< vertex_data::NLATENT; i++) | |
ret += boost::lexical_cast<std::string>(vertex.data().factor[i]) + " "; | |
ret += "\n"; | |
return ret; | |
} | |
else return ""; | |
} | |
std::string save_edge(const edge_type& edge) const { | |
return ""; | |
} | |
}; | |
/** | |
* \brief The engine type used by the ALS matrix factorization | |
* algorithm. | |
* | |
* The ALS matrix factorization algorithm currently uses the | |
* synchronous engine. However we plan to add support for alternative | |
* engines in the future. | |
*/ | |
typedef graphlab::omni_engine<als_vertex_program> engine_type; | |
int main(int argc, char** argv) { | |
global_logger().set_log_level(LOG_INFO); | |
global_logger().set_log_to_console(true); | |
// Parse command line options ----------------------------------------------- | |
const std::string description = | |
"Compute the ALS factorization of a matrix."; | |
graphlab::command_line_options clopts(description); | |
std::string input_dir; | |
std::string predictions; | |
size_t interval = 10; | |
std::string exec_type = "synchronous"; | |
clopts.attach_option("matrix", input_dir, | |
"The directory containing the matrix file"); | |
clopts.add_positional("matrix"); | |
clopts.attach_option("D", vertex_data::NLATENT, | |
"Number of latent parameters to use."); | |
clopts.attach_option("max_iter", als_vertex_program::MAX_UPDATES, | |
"The maxumum number of udpates allowed for a vertex"); | |
clopts.attach_option("lambda", als_vertex_program::LAMBDA, | |
"ALS regularization weight"); | |
clopts.attach_option("tol", als_vertex_program::TOLERANCE, | |
"residual termination threshold"); | |
clopts.attach_option("maxval", als_vertex_program::MAXVAL, "max allowed value"); | |
clopts.attach_option("minval", als_vertex_program::MINVAL, "min allowed value"); | |
clopts.attach_option("interval", interval, | |
"The time in seconds between error reports"); | |
clopts.attach_option("predictions", predictions, | |
"The prefix (folder and filename) to save predictions."); | |
clopts.attach_option("engine", exec_type, | |
"The engine type synchronous or asynchronous"); | |
clopts.attach_option("regnormal", als_vertex_program::REGNORMAL, | |
"regularization type. 1 = weighted according to neighbors num. 0 = no weighting - just lambda"); | |
parse_implicit_command_line(clopts); | |
if(!clopts.parse(argc, argv) || input_dir == "") { | |
std::cout << "Error in parsing command line arguments." << std::endl; | |
clopts.print_description(); | |
return EXIT_FAILURE; | |
} | |
///! Initialize control plain using mpi | |
graphlab::mpi_tools::init(argc, argv); | |
graphlab::distributed_control dc; | |
dc.cout() << "Loading graph." << std::endl; | |
graphlab::timer timer; | |
graph_type graph(dc, clopts); | |
graph.load(input_dir, graph_loader); | |
dc.cout() << "Loading graph. Finished in " | |
<< timer.current_time() << std::endl; | |
if (dc.procid() == 0) | |
add_implicit_edges<edge_data>(implicitratingtype, graph, dc); | |
dc.cout() << "Finalizing graph." << std::endl; | |
timer.start(); | |
graph.finalize(); | |
dc.cout() << "Finalizing graph. Finished in " | |
<< timer.current_time() << std::endl; | |
dc.cout() | |
<< "========== Graph statistics on proc " << dc.procid() | |
<< " ===============" | |
<< "\n Num vertices: " << graph.num_vertices() | |
<< "\n Num edges: " << graph.num_edges() | |
<< "\n Num replica: " << graph.num_replicas() | |
<< "\n Replica to vertex ratio: " | |
<< float(graph.num_replicas())/graph.num_vertices() | |
<< "\n --------------------------------------------" | |
<< "\n Num local own vertices: " << graph.num_local_own_vertices() | |
<< "\n Num local vertices: " << graph.num_local_vertices() | |
<< "\n Replica to own ratio: " | |
<< (float)graph.num_local_vertices()/graph.num_local_own_vertices() | |
<< "\n Num local edges: " << graph.num_local_edges() | |
//<< "\n Begin edge id: " << graph.global_eid(0) | |
<< "\n Edge balance ratio: " | |
<< float(graph.num_local_edges())/graph.num_edges() | |
<< std::endl; | |
dc.cout() << "Creating engine" << std::endl; | |
engine_type engine(dc, graph, exec_type, clopts); | |
// Add error reporting to the engine | |
const bool success = engine.add_edge_aggregator<error_aggregator> | |
("error", error_aggregator::map, error_aggregator::finalize) && | |
engine.aggregate_periodic("error", interval); | |
ASSERT_TRUE(success); | |
// Signal all vertices on the vertices on the left (liberals) | |
engine.map_reduce_vertices<graphlab::empty>(als_vertex_program::signal_left); | |
info = graph.map_reduce_edges<stats_info>(count_edges); | |
dc.cout()<<"Training edges: " << info.training_edges << " validation edges: " << info.validation_edges << std::endl; | |
// Run ALS --------------------------------------------------------- | |
dc.cout() << "Running ALS" << std::endl; | |
timer.start(); | |
engine.start(); | |
const double runtime = timer.current_time(); | |
dc.cout() << "----------------------------------------------------------" | |
<< std::endl | |
<< "Final Runtime (seconds): " << runtime | |
<< std::endl | |
<< "Updates executed: " << engine.num_updates() << std::endl | |
<< "Update Rate (updates/second): " | |
<< engine.num_updates() / runtime << std::endl; | |
// Compute the final training error ----------------------------------------- | |
dc.cout() << "Final error: " << std::endl; | |
engine.aggregate_now("error"); | |
// Make predictions --------------------------------------------------------- | |
if(!predictions.empty()) { | |
std::cout << "Saving predictions" << std::endl; | |
const bool gzip_output = false; | |
const bool save_vertices = false; | |
const bool save_edges = true; | |
const size_t threads_per_machine = 2; | |
//save the predictions | |
graph.save(predictions, prediction_saver(), | |
gzip_output, save_vertices, | |
save_edges, threads_per_machine); | |
//save the linear model | |
graph.save(predictions + ".U", linear_model_saver_U(), | |
gzip_output, save_edges, save_vertices, threads_per_machine); | |
graph.save(predictions + ".V", linear_model_saver_V(), | |
gzip_output, save_edges, save_vertices, threads_per_machine); | |
} | |
graphlab::mpi_tools::finalize(); | |
return EXIT_SUCCESS; | |
} // end of main |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment