Last active
February 12, 2023 14:46
-
-
Save rueycheng/95d5c6b9864779b9395730540d82d39d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h | |
index 12dbe6c..ef058af 100644 | |
--- a/include/LightGBM/dataset.h | |
+++ b/include/LightGBM/dataset.h | |
@@ -88,6 +88,8 @@ class Metadata { | |
void SetLabel(const label_t* label, data_size_t len); | |
+ void SetOrdering(const label_t* ordering, data_size_t len); | |
+ | |
void SetWeights(const label_t* weights, data_size_t len); | |
void SetQuery(const data_size_t* query, data_size_t len); | |
@@ -143,6 +145,18 @@ class Metadata { | |
queries_[idx] = static_cast<data_size_t>(value); | |
} | |
+ /*! | |
+ * \brief Get ordering, if not exists, will return nullptr | |
+ * \return Pointer of ordering | |
+ */ | |
+ inline const label_t* ordering() const { | |
+ if (!ordering_.empty()) { | |
+ return ordering_.data(); | |
+ } else { | |
+ return nullptr; | |
+ } | |
+ } | |
+ | |
/*! | |
* \brief Get weights, if not exists, will return nullptr | |
* \return Pointer of weights | |
@@ -213,6 +227,8 @@ class Metadata { | |
private: | |
/*! \brief Load initial scores from file */ | |
void LoadInitialScore(const char* initscore_file); | |
+ /*! \brief Load ordering from file */ | |
+ void LoadOrdering(); | |
  /*! \brief Load weights from file */
void LoadWeights(); | |
/*! \brief Load query boundaries from file */ | |
@@ -223,10 +239,14 @@ class Metadata { | |
std::string data_filename_; | |
/*! \brief Number of data */ | |
data_size_t num_data_; | |
+  /*! \brief Number of ordering entries, used to check correct ordering file */
+ data_size_t num_ordering_; | |
/*! \brief Number of weights, used to check correct weight file */ | |
data_size_t num_weights_; | |
/*! \brief Label data */ | |
std::vector<label_t> label_; | |
+ /*! \brief Ordering data */ | |
+ std::vector<label_t> ordering_; | |
/*! \brief Weights data */ | |
std::vector<label_t> weights_; | |
/*! \brief Query boundaries */ | |
@@ -243,6 +263,7 @@ class Metadata { | |
std::vector<data_size_t> queries_; | |
/*! \brief mutex for threading safe call */ | |
std::mutex mutex_; | |
+ bool ordering_load_from_file_; | |
bool weight_load_from_file_; | |
bool query_load_from_file_; | |
bool init_score_load_from_file_; | |
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp | |
index ccb72b9..e095997 100644 | |
--- a/src/io/dataset.cpp | |
+++ b/src/io/dataset.cpp | |
@@ -516,6 +516,12 @@ bool Dataset::SetFloatField(const char* field_name, const float* field_data, dat | |
#else | |
metadata_.SetLabel(field_data, num_element); | |
#endif | |
+ } else if (name == std::string("ordering")) { | |
+ #ifdef LABEL_T_USE_DOUBLE | |
+ Log::Fatal("Don't support LABEL_T_USE_DOUBLE"); | |
+ #else | |
+ metadata_.SetOrdering(field_data, num_element); | |
+ #endif | |
} else if (name == std::string("weight") || name == std::string("weights")) { | |
#ifdef LABEL_T_USE_DOUBLE | |
Log::Fatal("Don't support LABEL_T_USE_DOUBLE"); | |
@@ -560,6 +566,13 @@ bool Dataset::GetFloatField(const char* field_name, data_size_t* out_len, const | |
*out_ptr = metadata_.label(); | |
*out_len = num_data_; | |
#endif | |
+ } else if (name == std::string("ordering")) { | |
+ #ifdef LABEL_T_USE_DOUBLE | |
+ Log::Fatal("Don't support LABEL_T_USE_DOUBLE"); | |
+ #else | |
+ *out_ptr = metadata_.ordering(); | |
+ *out_len = num_data_; | |
+ #endif | |
} else if (name == std::string("weight") || name == std::string("weights")) { | |
#ifdef LABEL_T_USE_DOUBLE | |
Log::Fatal("Don't support LABEL_T_USE_DOUBLE"); | |
diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp | |
index a73ec45..b1a8cc8 100644 | |
--- a/src/io/metadata.cpp | |
+++ b/src/io/metadata.cpp | |
@@ -11,10 +11,12 @@ | |
namespace LightGBM { | |
Metadata::Metadata() { | |
+ num_ordering_ = 0; | |
num_weights_ = 0; | |
num_init_score_ = 0; | |
num_data_ = 0; | |
num_queries_ = 0; | |
+ ordering_load_from_file_ = false; | |
weight_load_from_file_ = false; | |
query_load_from_file_ = false; | |
init_score_load_from_file_ = false; | |
@@ -24,6 +26,7 @@ void Metadata::Init(const char * data_filename, const char* initscore_file) { | |
data_filename_ = data_filename; | |
// for lambdarank, it needs query data for partition data in parallel learning | |
LoadQueryBoundaries(); | |
+ LoadOrdering(); | |
LoadWeights(); | |
LoadQueryWeights(); | |
LoadInitialScore(initscore_file); | |
@@ -72,6 +75,17 @@ void Metadata::Init(const Metadata& fullset, const data_size_t* used_indices, da | |
label_[i] = fullset.label_[used_indices[i]]; | |
} | |
+ if (!fullset.ordering_.empty()) { | |
+ ordering_ = std::vector<label_t>(num_used_indices); | |
+ num_ordering_ = num_used_indices; | |
+#pragma omp parallel for schedule(static) | |
+ for (data_size_t i = 0; i < num_used_indices; i++) { | |
+ ordering_[i] = fullset.ordering_[used_indices[i]]; | |
+ } | |
+ } else { | |
+ num_ordering_ = 0; | |
+ } | |
+ | |
if (!fullset.weights_.empty()) { | |
weights_ = std::vector<label_t>(num_used_indices); | |
num_weights_ = num_used_indices; | |
@@ -171,6 +185,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data | |
LoadQueryWeights(); | |
queries_.clear(); | |
} | |
+ // check ordering | |
+ if (!ordering_.empty() && num_ordering_ != num_data_) { | |
+ ordering_.clear(); | |
+ num_ordering_ = 0; | |
+ Log::Fatal("Ordering size doesn't match data size"); | |
+ } | |
+ | |
// check weights | |
if (!weights_.empty() && num_weights_ != num_data_) { | |
weights_.clear(); | |
@@ -196,6 +217,25 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data | |
Log::Fatal("Cannot used query_id for parallel training"); | |
} | |
data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size()); | |
+ // check ordering | |
+ if (ordering_load_from_file_) { | |
+ if (ordering_.size() > 0 && num_ordering_ != num_all_data) { | |
+ ordering_.clear(); | |
+ num_ordering_ = 0; | |
+ Log::Fatal("ordering size doesn't match data size"); | |
+ } | |
+ // get local ordering | |
+ if (!ordering_.empty()) { | |
+ auto old_ordering = ordering_; | |
+ num_ordering_ = num_data_; | |
+ ordering_ = std::vector<label_t>(num_data_); | |
+#pragma omp parallel for schedule(static) | |
+ for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) { | |
+ ordering_[i] = old_ordering[used_data_indices[i]]; | |
+ } | |
+ old_ordering.clear(); | |
+ } | |
+ } | |
// check weights | |
if (weight_load_from_file_) { | |
if (weights_.size() > 0 && num_weights_ != num_all_data) { | |
@@ -320,6 +360,28 @@ void Metadata::SetLabel(const label_t* label, data_size_t len) { | |
} | |
} | |
+void Metadata::SetOrdering(const label_t* ordering, data_size_t len) { | |
+ std::lock_guard<std::mutex> lock(mutex_); | |
+ // save to nullptr | |
+ if (ordering == nullptr || len == 0) { | |
+ ordering_.clear(); | |
+ num_ordering_ = 0; | |
+ return; | |
+ } | |
+ if (num_data_ != len) { | |
+ Log::Fatal("Length of ordering is not same with #data"); | |
+ } | |
+ if (!ordering_.empty()) { ordering_.clear(); } | |
+ num_ordering_ = num_data_; | |
+ ordering_ = std::vector<label_t>(num_ordering_); | |
+#pragma omp parallel for schedule(static) | |
+ for (data_size_t i = 0; i < num_ordering_; ++i) { | |
+ ordering_[i] = ordering[i]; | |
+ } | |
+ LoadOrdering(); | |
+ ordering_load_from_file_ = false; | |
+} | |
+ | |
void Metadata::SetWeights(const label_t* weights, data_size_t len) { | |
std::lock_guard<std::mutex> lock(mutex_); | |
// save to nullptr | |
@@ -369,6 +431,28 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) { | |
query_load_from_file_ = false; | |
} | |
+void Metadata::LoadOrdering() { | |
+ num_ordering_ = 0; | |
+ std::string ordering_filename(data_filename_); | |
+ // default ordering file name | |
+ ordering_filename.append(".ordering"); | |
+ TextReader<size_t> reader(ordering_filename.c_str(), false); | |
+ reader.ReadAllLines(); | |
+ if (reader.Lines().empty()) { | |
+ return; | |
+ } | |
+ Log::Info("Loading ordering..."); | |
+ num_ordering_ = static_cast<data_size_t>(reader.Lines().size()); | |
+ ordering_ = std::vector<label_t>(num_ordering_); | |
+#pragma omp parallel for schedule(static) | |
+ for (data_size_t i = 0; i < num_ordering_; ++i) { | |
+ double tmp_weight = 0.0f; | |
+ Common::Atof(reader.Lines()[i].c_str(), &tmp_weight); | |
+ ordering_[i] = static_cast<label_t>(tmp_weight); | |
+ } | |
+ ordering_load_from_file_ = true; | |
+} | |
+ | |
void Metadata::LoadWeights() { | |
num_weights_ = 0; | |
std::string weight_filename(data_filename_); | |
@@ -485,6 +569,10 @@ void Metadata::LoadFromMemory(const void* memory) { | |
num_queries_ = *(reinterpret_cast<const data_size_t*>(mem_ptr)); | |
mem_ptr += sizeof(num_queries_); | |
+ // TODO: to simplify implementation we do not load ordering from external | |
+ // data at this stage | |
+ num_ordering_ = 0; | |
+ | |
if (!label_.empty()) { label_.clear(); } | |
label_ = std::vector<label_t>(num_data_); | |
std::memcpy(label_.data(), mem_ptr, sizeof(label_t)*num_data_); | |
@@ -512,6 +600,9 @@ void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const { | |
writer->Write(&num_weights_, sizeof(num_weights_)); | |
writer->Write(&num_queries_, sizeof(num_queries_)); | |
writer->Write(label_.data(), sizeof(label_t) * num_data_); | |
+ | |
+ // TODO: to simplify implementation we do not load ordering from external | |
+ // data at this stage | |
if (!weights_.empty()) { | |
writer->Write(weights_.data(), sizeof(label_t) * num_weights_); | |
} | |
diff --git a/src/objective/objective_function.cpp b/src/objective/objective_function.cpp | |
index 9cf030a..77d5bb6 100644 | |
--- a/src/objective/objective_function.cpp | |
+++ b/src/objective/objective_function.cpp | |
@@ -31,6 +31,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& | |
return new BinaryLogloss(config); | |
} else if (type == std::string("lambdarank")) { | |
return new LambdarankNDCG(config); | |
+ } else if (type == std::string("multiobjlambdarank")) { | |
+ return new MultiObjLambdarankNDCG(config); | |
} else if (type == std::string("multiclass") || type == std::string("softmax")) { | |
return new MulticlassSoftmax(config); | |
} else if (type == std::string("multiclassova") || type == std::string("multiclass_ova") || type == std::string("ova") || type == std::string("ovr")) { | |
@@ -70,6 +72,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& | |
return new BinaryLogloss(strs); | |
} else if (type == std::string("lambdarank")) { | |
return new LambdarankNDCG(strs); | |
+ } else if (type == std::string("multiobjlambdarank")) { | |
+ return new MultiObjLambdarankNDCG(strs); | |
} else if (type == std::string("multiclass")) { | |
return new MulticlassSoftmax(strs); | |
} else if (type == std::string("multiclassova")) { | |
diff --git a/src/objective/rank_objective.hpp b/src/objective/rank_objective.hpp | |
index 785ac89..2df7559 100644 | |
--- a/src/objective/rank_objective.hpp | |
+++ b/src/objective/rank_objective.hpp | |
@@ -240,5 +240,233 @@ class LambdarankNDCG: public ObjectiveFunction { | |
double sigmoid_table_idx_factor_; | |
}; | |
+/*! | |
+* \brief Objective function combining LambdaRank NDCG over relevance label | |
+* and RankNet over an ordering (e.g. negative timestamp) | |
+*/ | |
+class MultiObjLambdarankNDCG: public ObjectiveFunction { | |
+ public: | |
+ explicit MultiObjLambdarankNDCG(const Config& config) { | |
+ sigmoid_ = static_cast<double>(config.sigmoid); | |
+ label_gain_ = config.label_gain; | |
+ // initialize DCG calculator | |
+ DCGCalculator::DefaultLabelGain(&label_gain_); | |
+ DCGCalculator::Init(label_gain_); | |
+ // will optimize NDCG@optimize_pos_at_ | |
+ optimize_pos_at_ = config.max_position; | |
+ sigmoid_table_.clear(); | |
+ inverse_max_dcgs_.clear(); | |
+ if (sigmoid_ <= 0.0) { | |
+ Log::Fatal("Sigmoid param %f should be greater than zero", sigmoid_); | |
+ } | |
+ } | |
+ | |
+ explicit MultiObjLambdarankNDCG(const std::vector<std::string>&) { | |
+ } | |
+ | |
+ ~MultiObjLambdarankNDCG() { | |
+ } | |
+ void Init(const Metadata& metadata, data_size_t num_data) override { | |
+ num_data_ = num_data; | |
+ // get label | |
+ label_ = metadata.label(); | |
+ DCGCalculator::CheckLabel(label_, num_data_); | |
+ // get ordering | |
+ ordering_ = metadata.ordering(); | |
+ // get weights | |
+ weights_ = metadata.weights(); | |
+ // get boundries | |
+ query_boundaries_ = metadata.query_boundaries(); | |
+ if (query_boundaries_ == nullptr) { | |
+ Log::Fatal("Lambdarank tasks require query information"); | |
+ } | |
+ num_queries_ = metadata.num_queries(); | |
+ // cache inverse max DCG, avoid computation many times | |
+ inverse_max_dcgs_.resize(num_queries_); | |
+#pragma omp parallel for schedule(static) | |
+ for (data_size_t i = 0; i < num_queries_; ++i) { | |
+ inverse_max_dcgs_[i] = DCGCalculator::CalMaxDCGAtK(optimize_pos_at_, | |
+ label_ + query_boundaries_[i], | |
+ query_boundaries_[i + 1] - query_boundaries_[i]); | |
+ | |
+ if (inverse_max_dcgs_[i] > 0.0) { | |
+ inverse_max_dcgs_[i] = 1.0f / inverse_max_dcgs_[i]; | |
+ } | |
+ } | |
+ // construct sigmoid table to speed up sigmoid transform | |
+ ConstructSigmoidTable(); | |
+ } | |
+ | |
+ void GetGradients(const double* score, score_t* gradients, | |
+ score_t* hessians) const override { | |
+ #pragma omp parallel for schedule(guided) | |
+ for (data_size_t i = 0; i < num_queries_; ++i) { | |
+ GetGradientsForOneQuery(score, gradients, hessians, i); | |
+ } | |
+ } | |
+ | |
+ inline void GetGradientsForOneQuery(const double* score, | |
+ score_t* lambdas, score_t* hessians, data_size_t query_id) const { | |
+ // get doc boundary for current query | |
+ const data_size_t start = query_boundaries_[query_id]; | |
+ const data_size_t cnt = | |
+ query_boundaries_[query_id + 1] - query_boundaries_[query_id]; | |
+ // get max DCG on current query | |
+ const double inverse_max_dcg = inverse_max_dcgs_[query_id]; | |
+ // add pointers with offset | |
+ const label_t* label = label_ + start; | |
+ score += start; | |
+ lambdas += start; | |
+ hessians += start; | |
+ // initialize with zero | |
+ for (data_size_t i = 0; i < cnt; ++i) { | |
+ lambdas[i] = 0.0f; | |
+ hessians[i] = 0.0f; | |
+ } | |
+ // get sorted indices for scores | |
+ std::vector<data_size_t> sorted_idx; | |
+ for (data_size_t i = 0; i < cnt; ++i) { | |
+ sorted_idx.emplace_back(i); | |
+ } | |
+ std::stable_sort(sorted_idx.begin(), sorted_idx.end(), | |
+ [score](data_size_t a, data_size_t b) { return score[a] > score[b]; }); | |
+ // get best and worst score | |
+ const double best_score = score[sorted_idx[0]]; | |
+ data_size_t worst_idx = cnt - 1; | |
+ if (worst_idx > 0 && score[sorted_idx[worst_idx]] == kMinScore) { | |
+ worst_idx -= 1; | |
+ } | |
+ const double wrost_score = score[sorted_idx[worst_idx]]; | |
+ // start accmulate lambdas by pairs | |
+ for (data_size_t i = 0; i < cnt; ++i) { | |
+ const data_size_t high = sorted_idx[i]; | |
+ const int high_label = static_cast<int>(label[high]); | |
+ const double high_score = score[high]; | |
+ if (high_score == kMinScore) { continue; } | |
+ const double high_label_gain = label_gain_[high_label]; | |
+ const double high_discount = DCGCalculator::GetDiscount(i); | |
+ double high_sum_lambda = 0.0; | |
+ double high_sum_hessian = 0.0; | |
+ for (data_size_t j = 0; j < cnt; ++j) { | |
+ // skip same data | |
+ if (i == j) { continue; } | |
+ | |
+ const data_size_t low = sorted_idx[j]; | |
+ const int low_label = static_cast<int>(label[low]); | |
+ const double low_score = score[low]; | |
+ // only consider pair with different label | |
+ if (high_label <= low_label || low_score == kMinScore) { continue; } | |
+ | |
+ const double delta_score = high_score - low_score; | |
+ | |
+ const double low_label_gain = label_gain_[low_label]; | |
+ const double low_discount = DCGCalculator::GetDiscount(j); | |
+ // get dcg gap | |
+ const double dcg_gap = high_label_gain - low_label_gain; | |
+ // get discount of this pair | |
+ const double paired_discount = fabs(high_discount - low_discount); | |
+ // get delta NDCG | |
+ double delta_pair_NDCG = dcg_gap * paired_discount * inverse_max_dcg; | |
+ // regular the delta_pair_NDCG by score distance | |
+ if (high_label != low_label && best_score != wrost_score) { | |
+ delta_pair_NDCG /= (0.01f + fabs(delta_score)); | |
+ } | |
+ // calculate lambda for this pair | |
+ double p_lambda = GetSigmoid(delta_score); | |
+ double p_hessian = p_lambda * (2.0f - p_lambda); | |
+ // update | |
+ p_lambda *= -delta_pair_NDCG; | |
+ p_hessian *= 2 * delta_pair_NDCG; | |
+ high_sum_lambda += p_lambda; | |
+ high_sum_hessian += p_hessian; | |
+ lambdas[low] -= static_cast<score_t>(p_lambda); | |
+ hessians[low] += static_cast<score_t>(p_hessian); | |
+ } | |
+ // update | |
+ lambdas[high] += static_cast<score_t>(high_sum_lambda); | |
+ hessians[high] += static_cast<score_t>(high_sum_hessian); | |
+ } | |
+ // if need weights | |
+ if (weights_ != nullptr) { | |
+ for (data_size_t i = 0; i < cnt; ++i) { | |
+ lambdas[i] = static_cast<score_t>(lambdas[i] * weights_[start + i]); | |
+ hessians[i] = static_cast<score_t>(hessians[i] * weights_[start + i]); | |
+ } | |
+ } | |
+ } | |
+ | |
+ | |
+ inline double GetSigmoid(double score) const { | |
+ if (score <= min_sigmoid_input_) { | |
+ // too small, use lower bound | |
+ return sigmoid_table_[0]; | |
+ } else if (score >= max_sigmoid_input_) { | |
+ // too big, use upper bound | |
+ return sigmoid_table_[_sigmoid_bins - 1]; | |
+ } else { | |
+ return sigmoid_table_[static_cast<size_t>((score - min_sigmoid_input_) * sigmoid_table_idx_factor_)]; | |
+ } | |
+ } | |
+ | |
+ void ConstructSigmoidTable() { | |
+ // get boundary | |
+ min_sigmoid_input_ = min_sigmoid_input_ / sigmoid_ / 2; | |
+ max_sigmoid_input_ = -min_sigmoid_input_; | |
+ sigmoid_table_.resize(_sigmoid_bins); | |
+ // get score to bin factor | |
+ sigmoid_table_idx_factor_ = | |
+ _sigmoid_bins / (max_sigmoid_input_ - min_sigmoid_input_); | |
+ // cache | |
+ for (size_t i = 0; i < _sigmoid_bins; ++i) { | |
+ const double score = i / sigmoid_table_idx_factor_ + min_sigmoid_input_; | |
+ sigmoid_table_[i] = 2.0f / (1.0f + std::exp(2.0f * score * sigmoid_)); | |
+ } | |
+ } | |
+ | |
+ const char* GetName() const override { | |
+ return "multiobjlambdarank"; | |
+ } | |
+ | |
+ std::string ToString() const override { | |
+ std::stringstream str_buf; | |
+ str_buf << GetName(); | |
+ return str_buf.str(); | |
+ } | |
+ | |
+ bool NeedAccuratePrediction() const override { return false; } | |
+ | |
+ private: | |
+ /*! \brief Gains for labels */ | |
+ std::vector<double> label_gain_; | |
+ /*! \brief Cache inverse max DCG, speed up calculation */ | |
+ std::vector<double> inverse_max_dcgs_; | |
+ /*! \brief Simgoid param */ | |
+ double sigmoid_; | |
+ /*! \brief Optimized NDCG@ */ | |
+ int optimize_pos_at_; | |
+ /*! \brief Number of queries */ | |
+ data_size_t num_queries_; | |
+ /*! \brief Number of data */ | |
+ data_size_t num_data_; | |
+ /*! \brief Pointer of label */ | |
+ const label_t* label_; | |
+ /*! \brief Pointer of ordering */ | |
+ const label_t* ordering_; | |
+ /*! \brief Pointer of weights */ | |
+ const label_t* weights_; | |
+ /*! \brief Query boundries */ | |
+ const data_size_t* query_boundaries_; | |
+ /*! \brief Cache result for sigmoid transform to speed up */ | |
+ std::vector<double> sigmoid_table_; | |
+ /*! \brief Number of bins in simoid table */ | |
+ size_t _sigmoid_bins = 1024 * 1024; | |
+ /*! \brief Minimal input of sigmoid table */ | |
+ double min_sigmoid_input_ = -50; | |
+ /*! \brief Maximal input of sigmoid table */ | |
+ double max_sigmoid_input_ = 50; | |
+ /*! \brief Factor that covert score to bin in sigmoid table */ | |
+ double sigmoid_table_idx_factor_; | |
+}; | |
+ | |
} // namespace LightGBM | |
#endif // LightGBM_OBJECTIVE_RANK_OBJECTIVE_HPP_ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment