Last active
May 22, 2020 18:35
-
-
Save kardeiz/d6bee7f84d0a015ed9a1bbd014b7b682 to your computer and use it in GitHub Desktop.
`MoreLikeThis` query for `tantivy`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::{BinaryHeap, HashMap}; | |
use tantivy::{ | |
query::{BooleanQuery, BoostQuery, Occur, Query, TermQuery, Weight}, | |
schema::{FieldType, IndexRecordOption, Schema, Term, Value}, | |
tokenizer::{ | |
BoxTokenStream, FacetTokenizer, PreTokenizedStream, StopWordFilter, TokenFilter, Tokenizer, | |
TokenizerManager, | |
}, | |
Document, Searcher, | |
}; | |
/// BM25-style inverse document frequency.
///
/// `doc_freq` is the number of documents containing the term; `doc_count` is
/// the total number of documents considered. The `+ 0.5` terms smooth the
/// ratio and the outer `1 +` / `ln` keeps the result strictly positive.
fn idf(doc_freq: u64, doc_count: u64) -> f32 {
    // `saturating_sub` guards against u64 underflow (a panic in debug builds):
    // the caller sums live docs per segment for `doc_count`, while `doc_freq`
    // from the inverted index may still count deleted documents, so
    // `doc_freq > doc_count` is reachable.
    let x = (doc_count.saturating_sub(doc_freq) as f32 + 0.5) / (doc_freq as f32 + 0.5);
    (1f32 + x).ln()
}
/// Occurrence counts for every `Term` extracted from a single document.
///
/// Keyed by `Term` (which already encodes its field), mapping to how many
/// times that term appears in the document. Built by [`PerFieldTermFrequencies::build`].
#[derive(Debug, Clone)]
pub struct PerFieldTermFrequencies(HashMap<Term, usize>);
impl PerFieldTermFrequencies {
    /// Extracts per-term occurrence counts from `doc`.
    ///
    /// Walks every indexed field value and tokenizes it the same way indexing
    /// would: the facet tokenizer for hierarchical facets, the field's
    /// configured tokenizer (or the pre-tokenized stream) for text, and the
    /// raw value as a single term for u64/i64/f64/date fields. Each resulting
    /// `Term`'s count is accumulated in the returned map.
    ///
    /// Non-indexed fields and field types not matched below are skipped.
    /// `stop_word_filter`, when given, is applied to text token streams only.
    pub fn build(
        schema: &Schema,
        tokenizer_manager: &TokenizerManager,
        stop_word_filter: Option<&StopWordFilter>,
        doc: &Document,
    ) -> Self {
        let mut map = HashMap::new();
        for (field, field_values) in doc.get_sorted_field_values() {
            let field_options = schema.get_field_entry(field);
            // Only indexed fields contribute terms.
            if !field_options.is_indexed() {
                continue;
            }
            match field_options.field_type() {
                FieldType::HierarchicalFacet => {
                    // A non-facet value in a facet field is a schema
                    // violation, hence the panic.
                    let facets: Vec<&str> = field_values
                        .iter()
                        .map(|field_value| match *field_value.value() {
                            Value::Facet(ref facet) => facet.encoded_str(),
                            _ => {
                                panic!("Expected hierarchical facet");
                            }
                        })
                        .collect();
                    for fake_str in facets {
                        // The facet tokenizer emits one token per path prefix
                        // (e.g. /a, /a/b, ...), mirroring how facets are indexed.
                        FacetTokenizer.token_stream(fake_str).process(&mut |token| {
                            let term = Term::from_field_text(field, &token.text);
                            *map.entry(term).or_insert(0) += 1;
                        });
                    }
                }
                FieldType::Str(text_options) => {
                    let mut token_streams: Vec<BoxTokenStream> = vec![];
                    // NOTE(review): `offsets`/`total_offset` are accumulated
                    // but never read afterwards — apparently leftover from the
                    // indexing code this was adapted from.
                    let mut offsets = vec![];
                    let mut total_offset = 0;
                    for field_value in field_values {
                        match field_value.value() {
                            Value::PreTokStr(tok_str) => {
                                offsets.push(total_offset);
                                if let Some(last_token) = tok_str.tokens.last() {
                                    total_offset += last_token.offset_to;
                                }
                                token_streams
                                    .push(PreTokenizedStream::from(tok_str.clone()).into());
                            }
                            Value::Str(ref text) => {
                                // Resolve the tokenizer configured for this
                                // field; values on fields without indexing
                                // options, or with an unknown tokenizer name,
                                // are silently skipped.
                                if let Some(tokenizer) = text_options
                                    .get_indexing_options()
                                    .map(|text_indexing_options| {
                                        text_indexing_options.tokenizer().to_string()
                                    })
                                    .and_then(|tokenizer_name| {
                                        tokenizer_manager.get(&tokenizer_name)
                                    })
                                {
                                    offsets.push(total_offset);
                                    total_offset += text.len();
                                    token_streams.push(tokenizer.token_stream(text));
                                }
                            }
                            _ => (),
                        }
                    }
                    for mut token_stream in token_streams {
                        // Drop stop words before counting, when a filter was given.
                        if let Some(stop_word_filter) = stop_word_filter {
                            token_stream = stop_word_filter.transform(token_stream);
                        }
                        token_stream.process(&mut |token| {
                            let term = Term::from_field_text(field, &token.text);
                            *map.entry(term).or_insert(0) += 1;
                        });
                    }
                }
                // Numeric and date fields contribute their raw value as a term.
                FieldType::U64(_) => {
                    for field_value in field_values {
                        let term = Term::from_field_u64(
                            field_value.field(),
                            field_value.value().u64_value(),
                        );
                        *map.entry(term).or_insert(0) += 1;
                    }
                }
                FieldType::Date(_) => {
                    for field_value in field_values {
                        // Dates are indexed by their i64 UNIX timestamp.
                        let term = Term::from_field_i64(
                            field_value.field(),
                            field_value.value().date_value().timestamp(),
                        );
                        *map.entry(term).or_insert(0) += 1;
                    }
                }
                FieldType::I64(_) => {
                    for field_value in field_values {
                        let term = Term::from_field_i64(
                            field_value.field(),
                            field_value.value().i64_value(),
                        );
                        *map.entry(term).or_insert(0) += 1;
                    }
                }
                FieldType::F64(_) => {
                    for field_value in field_values {
                        let term = Term::from_field_f64(
                            field_value.field(),
                            field_value.value().f64_value(),
                        );
                        *map.entry(term).or_insert(0) += 1;
                    }
                }
                _ => {}
            }
        }
        Self(map)
    }
}
#[derive(PartialEq)] | |
pub struct ScoredTerm { | |
score: f32, | |
term: Term, | |
} | |
impl Eq for ScoredTerm {} | |
impl PartialOrd for ScoredTerm { | |
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { | |
self.score.partial_cmp(&other.score) | |
} | |
} | |
impl Ord for ScoredTerm { | |
fn cmp(&self, other: &Self) -> std::cmp::Ordering { | |
self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal) | |
} | |
} | |
#[derive(Clone)] | |
struct StopWordFilterWrapper(StopWordFilter); | |
impl std::fmt::Debug for StopWordFilterWrapper { | |
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
f.debug_struct("StopWordFilter").finish() | |
} | |
} | |
/// A "more like this" query: extracts salient terms from a seed document and
/// matches other documents containing those terms, each optionally boosted by
/// its relative tf-idf score.
#[derive(Debug, Clone)]
pub struct MoreLikeThisQuery {
    // Skip terms occurring fewer than this many times in the seed document.
    min_term_freq: Option<usize>,
    // Skip terms appearing in fewer than this many documents of the index.
    min_doc_freq: Option<u64>,
    // Skip terms appearing in more than this many documents of the index.
    max_doc_freq: Option<u64>,
    // Keep at most this many of the selected terms.
    max_query_terms: Option<usize>,
    // Skip terms whose encoded value is shorter than this many bytes.
    min_word_len: Option<usize>,
    // Skip terms whose encoded value is longer than this many bytes.
    max_word_len: Option<usize>,
    // When true, wrap each term query in a `BoostQuery`.
    boost: bool,
    // Multiplier applied to each term's relative boost.
    boost_factor: f32,
    // Optional stop-word filter applied during term extraction (text fields).
    stop_word_filter: Option<StopWordFilterWrapper>,
    // The seed document terms are extracted from.
    doc: Document,
}
impl MoreLikeThisQuery { | |
pub fn builder() -> MoreLikeThisQueryBuilder { | |
MoreLikeThisQueryBuilder::default() | |
} | |
} | |
impl Query for MoreLikeThisQuery { | |
fn weight( | |
&self, | |
searcher: &Searcher, | |
scoring_enabled: bool, | |
) -> tantivy::Result<Box<dyn Weight>> { | |
let per_field_term_frequencies = PerFieldTermFrequencies::build( | |
searcher.schema(), | |
searcher.index().tokenizers(), | |
self.stop_word_filter.as_ref().map(|x| &x.0), | |
&self.doc, | |
); | |
let num_docs = searcher.segment_readers().iter().map(|x| x.num_docs() as u64).sum::<u64>(); | |
let mut scored_terms = BinaryHeap::new(); | |
for (term, term_freq) in per_field_term_frequencies.0.into_iter() { | |
if self.min_term_freq.map(|x| term_freq < x).unwrap_or(false) { | |
continue; | |
} | |
let term_value_len = term.value_bytes().len(); | |
if self.min_word_len.map(|x| term_value_len < x).unwrap_or(false) { | |
continue; | |
} | |
if self.max_word_len.map(|x| term_value_len > x).unwrap_or(false) { | |
continue; | |
} | |
let doc_freq = searcher.doc_freq(&term); | |
if self.min_doc_freq.map(|x| doc_freq < x).unwrap_or(false) { | |
continue; | |
} | |
if self.max_doc_freq.map(|x| doc_freq > x).unwrap_or(false) { | |
continue; | |
} | |
if doc_freq == 0 { | |
continue; | |
} | |
let idf = idf(doc_freq, num_docs); | |
let score = (term_freq as f32) * idf; | |
scored_terms.push(ScoredTerm { term, score }); | |
} | |
let top_score = scored_terms.peek().map(|x| x.score); | |
let mut scored_terms = scored_terms.into_sorted_vec(); | |
let scored_terms = if let Some(max_query_terms) = self.max_query_terms { | |
let max_query_terms = std::cmp::min(max_query_terms, scored_terms.len()); | |
scored_terms.drain(..max_query_terms) | |
} else { | |
scored_terms.drain(..) | |
}; | |
let mut sub_queries = Vec::new(); | |
for ScoredTerm { score, term } in scored_terms { | |
let mut query: Box<dyn Query> = | |
Box::new(TermQuery::new(term, IndexRecordOption::Basic)); | |
if self.boost { | |
query = Box::new(BoostQuery::new( | |
query, | |
score * self.boost_factor / top_score.unwrap(), | |
)); | |
} | |
sub_queries.push((Occur::Should, query)); | |
} | |
let query = BooleanQuery::from(sub_queries); | |
query.weight(searcher, scoring_enabled) | |
} | |
} | |
/// Builder for `MoreLikeThisQuery`; every parameter has a `with_*` setter and
/// the build is finalized by `with_doc`, which supplies the seed document.
/// Fields mirror `MoreLikeThisQuery`.
#[derive(Debug, Clone)]
pub struct MoreLikeThisQueryBuilder {
    min_term_freq: Option<usize>,
    min_doc_freq: Option<u64>,
    max_doc_freq: Option<u64>,
    max_query_terms: Option<usize>,
    min_word_len: Option<usize>,
    max_word_len: Option<usize>,
    boost: bool,
    boost_factor: f32,
    stop_word_filter: Option<StopWordFilterWrapper>,
}
impl Default for MoreLikeThisQueryBuilder { | |
fn default() -> Self { | |
Self { | |
min_term_freq: None, | |
min_doc_freq: Some(5), | |
max_doc_freq: None, | |
max_query_terms: Some(25), | |
min_word_len: None, | |
max_word_len: None, | |
boost: true, | |
boost_factor: 1.0, | |
stop_word_filter: Some(StopWordFilterWrapper(StopWordFilter::default())), | |
} | |
} | |
} | |
impl MoreLikeThisQueryBuilder { | |
pub fn with_min_term_freq(mut self, val: Option<usize>) -> Self { | |
self.min_term_freq = val; | |
self | |
} | |
pub fn with_min_doc_freq(mut self, val: Option<u64>) -> Self { | |
self.min_doc_freq = val; | |
self | |
} | |
pub fn with_max_doc_freq(mut self, val: Option<u64>) -> Self { | |
self.max_doc_freq = val; | |
self | |
} | |
pub fn with_max_query_terms(mut self, val: Option<usize>) -> Self { | |
self.max_query_terms = val; | |
self | |
} | |
pub fn with_min_word_len(mut self, val: Option<usize>) -> Self { | |
self.min_word_len = val; | |
self | |
} | |
pub fn with_max_word_len(mut self, val: Option<usize>) -> Self { | |
self.max_word_len = val; | |
self | |
} | |
pub fn with_boost(mut self, val: bool) -> Self { | |
self.boost = val; | |
self | |
} | |
pub fn with_boost_factor(mut self, val: f32) -> Self { | |
self.boost_factor = val; | |
self | |
} | |
pub fn with_stop_word_filter(mut self, val: Option<StopWordFilter>) -> Self { | |
self.stop_word_filter = val.map(StopWordFilterWrapper); | |
self | |
} | |
pub fn with_doc(self, val: Document) -> MoreLikeThisQuery { | |
let MoreLikeThisQueryBuilder { | |
min_term_freq, | |
min_doc_freq, | |
max_doc_freq, | |
max_query_terms, | |
min_word_len, | |
max_word_len, | |
boost, | |
boost_factor, | |
stop_word_filter, | |
} = self; | |
MoreLikeThisQuery { | |
min_term_freq, | |
min_doc_freq, | |
max_doc_freq, | |
max_query_terms, | |
min_word_len, | |
max_word_len, | |
boost, | |
boost_factor, | |
stop_word_filter, | |
doc: val, | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment