Created January 29, 2024 03:54
Example of the IA-Select algorithm in Rust
use std::collections::HashMap;
use std::time::Instant;

/**
The IA-Select algorithm, translated into Rust from
[Diversifying Search Results](https://www.microsoft.com/en-us/research/wp-content/uploads/2009/02/diversifying-wsdm09.pdf),
which gives a "(1 − 1/`e`)-approximation algorithm for `Diversify(k)`". `Diversify(k)` is the NP-hard problem of
choosing a subset of `k` documents that maximizes the probability that the "average user" is satisfied by at least
one of them, given a distribution over the query's intents. IA stands for "intent aware"; here intents are mapped
to categories.
*/
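// A rough mapping from this code's names to the paper's notation (my reading of the paper,
// not quoted from it):
//   probability_category_given_query[c]  ~ P(c | q), the distribution over intents for the query
//   conditional_probability[c]           ~ U(c | q, S), re-estimated as documents are selected
//   doc_id_category_score[(d, c)]        ~ V(d | q, c), the per-category "quality value" of a document
// Each greedy step picks the document d maximizing sum_c U(c | q, S) * V(d | q, c), then
// discounts U(c | q, S) by (1 - V(d | q, c)) for the selected d.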
fn ia_select<'a>(
    num_documents_to_select: usize,
    categories: &Vec<&str>,
    doc_ids: &'a Vec<&'a str>,
    probability_category_given_query: &HashMap<&str, f64>,
    doc_id_category_score: &HashMap<(&str, &str), f64>,
) -> Vec<&'a str> {
    let mut selected_doc_ids = Vec::new();
    let mut doc_ids = doc_ids.clone();
    let mut conditional_probability = probability_category_given_query.clone();
    while selected_doc_ids.len() < num_documents_to_select && !doc_ids.is_empty() {
        let mut highest_marginal_utility = f64::MIN;
        let mut candidate_doc_id = "";
        for &doc_id in &doc_ids {
            let marginal_utility: f64 = categories
                .iter()
                .map(|&category| {
                    conditional_probability.get(category).unwrap_or(&0.0)
                        * doc_id_category_score
                            .get(&(doc_id, category))
                            .unwrap_or(&0.0)
                })
                .sum();
            if marginal_utility > highest_marginal_utility {
                highest_marginal_utility = marginal_utility;
                candidate_doc_id = doc_id;
            }
        }
        // once we select a document, we assume it may have satisfied the query for the
        // categories it covers; further documents in those categories have diminishing
        // returns, so we scale the conditional probability down accordingly
        if !candidate_doc_id.is_empty() {
            selected_doc_ids.push(candidate_doc_id);
            for &category in categories {
                let quality = doc_id_category_score
                    .get(&(candidate_doc_id, category))
                    .unwrap_or(&0.0);
                let current_probability = conditional_probability.entry(category).or_insert(0.0);
                *current_probability *= 1.0 - quality;
            }
            doc_ids.retain(|&d| d != candidate_doc_id);
        }
    }
    selected_doc_ids
}
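
// A minimal sketch of a test exercising the diminishing-returns behaviour above. The fixture
// (categories "a"/"b", docs "d1".."d3" and their scores) is invented for illustration; it is
// not from the paper or from main() below.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn picks_across_categories_once_one_is_covered() {
        let categories = vec!["a", "b"];
        let doc_ids = vec!["d1", "d2", "d3"];
        // both intents are equally likely for this query
        let probability_category_given_query: HashMap<&str, f64> =
            [("a", 0.5), ("b", 0.5)].iter().cloned().collect();
        // d1 and d2 only score in "a"; d3 only scores in "b"
        let doc_id_category_score: HashMap<(&str, &str), f64> = [
            (("d1", "a"), 0.9),
            (("d2", "a"), 0.8),
            (("d3", "b"), 0.7),
        ]
        .iter()
        .cloned()
        .collect();
        let selected = ia_select(
            2,
            &categories,
            &doc_ids,
            &probability_category_given_query,
            &doc_id_category_score,
        );
        // d1 wins the first round (0.5 * 0.9); the update then discounts "a" to 0.05,
        // so d3 (0.5 * 0.7) beats d2 (0.05 * 0.8) in the second round.
        assert_eq!(selected, vec!["d1", "d3"]);
    }
}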
fn main() {
    // we usually call this k
    let num_documents_to_select = 5;
    // these are often called verticals, or just different retrievers, depending on how you
    // look at it; the paper calls them categories, so that is the term used here
    let categories = vec!["recent", "all"];
    let doc_ids = vec![
        "doc1", "doc2", "doc3", "doc4", "doc5", "doc6", "doc7", "doc8", "doc9", "doc10",
    ];
    // this is the signal we depend on
    let probability_category_given_query: HashMap<&str, f64> =
        [("recent", 0.9), ("all", 0.1)].iter().cloned().collect();
    // the scores given by the different rankers; these need to be normalized somehow
    // (see the sketch after this listing), otherwise a ranker with a higher overall score
    // distribution will dominate. in the paper this is the "quality value" per category
    let doc_id_category_score: HashMap<(&str, &str), f64> = [
        // only in recent
        (("doc1", "recent"), 0.5),
        (("doc1", "all"), 0.0),
        // in both, but more relevant in all
        (("doc2", "recent"), 0.2),
        (("doc2", "all"), 0.4),
        // only in recent, the highest recent score
        (("doc3", "recent"), 0.8),
        (("doc3", "all"), 0.0),
        // in both, but more relevant in all (different margin)
        (("doc4", "recent"), 0.1),
        (("doc4", "all"), 0.4),
        // in both, equally
        (("doc5", "recent"), 0.3),
        (("doc5", "all"), 0.3),
        // only in all, low relevance
        (("doc6", "recent"), 0.0),
        (("doc6", "all"), 0.3),
        // only in all, very highly scored
        (("doc7", "recent"), 0.0),
        (("doc7", "all"), 0.9),
        // in both, but more relevant in recent (inverse of doc2)
        (("doc8", "recent"), 0.4),
        (("doc8", "all"), 0.2),
        // in both, but low (expect doc5 > doc9)
        (("doc9", "recent"), 0.1),
        (("doc9", "all"), 0.1),
        // doc10 has no scores at all, but is handled anyway
    ]
    .iter()
    .cloned()
    .collect();
    let now = Instant::now();
    let selected_doc_ids = ia_select(
        num_documents_to_select,
        &categories,
        &doc_ids,
        &probability_category_given_query,
        &doc_id_category_score,
    );
    println!("Microseconds elapsed: {}", now.elapsed().as_micros());
    println!(
        "Selected doc_ids: {:?} for {:?}",
        selected_doc_ids, probability_category_given_query
    );
    let probability_category_given_query: HashMap<&str, f64> =
        [("recent", 0.1), ("all", 0.9)].iter().cloned().collect();
    let now = Instant::now();
    let selected_doc_ids = ia_select(
        num_documents_to_select,
        &categories,
        &doc_ids,
        &probability_category_given_query,
        &doc_id_category_score,
    );
    println!("Microseconds elapsed: {}", now.elapsed().as_micros());
    println!(
        "Selected doc_ids: {:?} for {:?}",
        selected_doc_ids, probability_category_given_query
    );
}
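
One note on the normalization comment in main: ia_select treats doc_id_category_score as the paper's quality value V(d | q, c), which lives in [0, 1], so raw scores from different rankers have to be squashed into a comparable range per category before categories can be weighed fairly. The paper does not prescribe how to do this; the sketch below is one minimal option (dividing by each category's maximum, assuming non-negative raw scores), and the helper name normalize_per_category is mine rather than anything from the gist or the paper. Softmax or rank-based scaling would be equally reasonable choices.

use std::collections::HashMap;

/// Scale each category's scores by that category's maximum so everything lands in [0, 1]
/// and no single ranker's score range dominates the marginal-utility sum in ia_select.
/// Assumes raw scores are non-negative.
fn normalize_per_category<'a>(
    scores: &HashMap<(&'a str, &'a str), f64>,
) -> HashMap<(&'a str, &'a str), f64> {
    // find each category's maximum score
    let mut max_per_category: HashMap<&str, f64> = HashMap::new();
    for (&(_doc_id, category), &score) in scores {
        let max = max_per_category.entry(category).or_insert(0.0);
        *max = max.max(score);
    }
    // divide every score by its category's maximum
    // (a category that only ever scored 0.0 is left at 0.0)
    scores
        .iter()
        .map(|(&(doc_id, category), &score)| {
            let max = max_per_category[category];
            let scaled = if max > 0.0 { score / max } else { 0.0 };
            ((doc_id, category), scaled)
        })
        .collect()
}

The hand-picked scores in main are already in [0, 1], so it skips this step; with real ranker output you would run the score map through something like this before calling ia_select.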