//! A `MoreLikeThis` query for `tantivy`.
//!
//! Given a target `Document`, its most distinctive terms (scored by
//! `tf * idf`) are combined into a boosted boolean `Should` query, in the
//! spirit of Lucene's `MoreLikeThis`. A usage sketch lives in the `tests`
//! module at the bottom of the file.

use std::collections::{BinaryHeap, HashMap};

use tantivy::{
    query::{BooleanQuery, BoostQuery, Occur, Query, TermQuery, Weight},
    schema::{FieldType, IndexRecordOption, Schema, Term, Value},
    tokenizer::{
        BoxTokenStream, FacetTokenizer, PreTokenizedStream, StopWordFilter, TokenFilter,
        Tokenizer, TokenizerManager,
    },
    Document, Searcher,
};

/// BM25-style inverse document frequency:
/// `ln(1 + (N - n + 0.5) / (n + 0.5))`, where `N` is the total number of
/// documents and `n` is the number of documents containing the term.
fn idf(doc_freq: u64, doc_count: u64) -> f32 {
    let x = ((doc_count - doc_freq) as f32 + 0.5) / (doc_freq as f32 + 0.5);
    (1f32 + x).ln()
}

/// Term frequencies for a single document, keyed by `Term` (a `Term` already
/// encodes its field, so one flat map covers all fields).
#[derive(Debug, Clone)]
pub struct PerFieldTermFrequencies(HashMap<Term, usize>);

impl PerFieldTermFrequencies {
    /// Tokenizes every indexed field of `doc` (mirroring how tantivy's
    /// indexer would) and counts how often each term occurs.
    pub fn build(
        schema: &Schema,
        tokenizer_manager: &TokenizerManager,
        stop_word_filter: Option<&StopWordFilter>,
        doc: &Document,
    ) -> Self {
        let mut map = HashMap::new();
        for (field, field_values) in doc.get_sorted_field_values() {
            let field_options = schema.get_field_entry(field);
            // Terms can only be matched against fields that are indexed.
            if !field_options.is_indexed() {
                continue;
            }
            match field_options.field_type() {
                FieldType::HierarchicalFacet => {
                    let facets: Vec<&str> = field_values
                        .iter()
                        .map(|field_value| match *field_value.value() {
                            Value::Facet(ref facet) => facet.encoded_str(),
                            _ => panic!("Expected hierarchical facet"),
                        })
                        .collect();
                    // The facet tokenizer emits one token per path prefix,
                    // e.g. `/a/b` yields terms for `/a` and `/a/b`.
                    for facet_str in facets {
                        FacetTokenizer.token_stream(facet_str).process(&mut |token| {
                            let term = Term::from_field_text(field, &token.text);
                            *map.entry(term).or_insert(0) += 1;
                        });
                    }
                }
                FieldType::Str(text_options) => {
                    // Collect one token stream per value: pre-tokenized
                    // values are replayed as-is, plain strings go through the
                    // tokenizer registered for this field. (The indexer's
                    // offset bookkeeping is unnecessary for pure counting.)
                    let mut token_streams: Vec<BoxTokenStream> = vec![];
                    for field_value in field_values {
                        match field_value.value() {
                            Value::PreTokStr(tok_str) => {
                                token_streams
                                    .push(PreTokenizedStream::from(tok_str.clone()).into());
                            }
                            Value::Str(ref text) => {
                                if let Some(tokenizer) = text_options
                                    .get_indexing_options()
                                    .map(|text_indexing_options| {
                                        text_indexing_options.tokenizer().to_string()
                                    })
                                    .and_then(|tokenizer_name| {
                                        tokenizer_manager.get(&tokenizer_name)
                                    })
                                {
                                    token_streams.push(tokenizer.token_stream(text));
                                }
                            }
                            _ => (),
                        }
                    }
                    for mut token_stream in token_streams {
                        // Unlike the indexer, optionally strip stop words so
                        // they cannot dominate the generated query.
                        if let Some(stop_word_filter) = stop_word_filter {
                            token_stream = stop_word_filter.transform(token_stream);
                        }
                        token_stream.process(&mut |token| {
                            let term = Term::from_field_text(field, &token.text);
                            *map.entry(term).or_insert(0) += 1;
                        });
                    }
                }
                // Numeric, date, and float fields contribute one exact term
                // per value.
                FieldType::U64(_) => {
                    for field_value in field_values {
                        let term = Term::from_field_u64(
                            field_value.field(),
                            field_value.value().u64_value(),
                        );
                        *map.entry(term).or_insert(0) += 1;
                    }
                }
                FieldType::Date(_) => {
                    for field_value in field_values {
                        let term = Term::from_field_i64(
                            field_value.field(),
                            field_value.value().date_value().timestamp(),
                        );
                        *map.entry(term).or_insert(0) += 1;
                    }
                }
                FieldType::I64(_) => {
                    for field_value in field_values {
                        let term = Term::from_field_i64(
                            field_value.field(),
                            field_value.value().i64_value(),
                        );
                        *map.entry(term).or_insert(0) += 1;
                    }
                }
                FieldType::F64(_) => {
                    for field_value in field_values {
                        let term = Term::from_field_f64(
                            field_value.field(),
                            field_value.value().f64_value(),
                        );
                        *map.entry(term).or_insert(0) += 1;
                    }
                }
                // Bytes fields (and any future field types) are skipped.
                _ => {}
            }
        }
        Self(map)
    }
}

/// A term and its `tf * idf` score. Ordered by score so the terms can be
/// collected into a `BinaryHeap`.
#[derive(PartialEq)]
pub struct ScoredTerm {
    score: f32,
    term: Term,
}

impl Eq for ScoredTerm {}

impl PartialOrd for ScoredTerm {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        self.score.partial_cmp(&other.score)
    }
}

impl Ord for ScoredTerm {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Scores are produced by `idf` and are never NaN, but fall back to
        // `Equal` to keep the ordering total.
        self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
    }
}

/// `MoreLikeThisQuery` derives `Debug`, but `StopWordFilter` does not
/// implement it, so wrap it with a hand-rolled `Debug` impl.
#[derive(Clone)]
struct StopWordFilterWrapper(StopWordFilter);

impl std::fmt::Debug for StopWordFilterWrapper {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StopWordFilter").finish()
    }
}

/// A query for documents similar to the supplied target `Document`, in the
/// spirit of Lucene's `MoreLikeThis`.
#[derive(Debug, Clone)]
pub struct MoreLikeThisQuery {
    min_term_freq: Option<usize>,
    min_doc_freq: Option<u64>,
    max_doc_freq: Option<u64>,
    max_query_terms: Option<usize>,
    min_word_len: Option<usize>,
    max_word_len: Option<usize>,
    boost: bool,
    boost_factor: f32,
    stop_word_filter: Option<StopWordFilterWrapper>,
    doc: Document,
}

impl MoreLikeThisQuery {
    pub fn builder() -> MoreLikeThisQueryBuilder {
        MoreLikeThisQueryBuilder::default()
    }
}

impl Query for MoreLikeThisQuery {
    fn weight(
        &self,
        searcher: &Searcher,
        scoring_enabled: bool,
    ) -> tantivy::Result<Box<dyn Weight>> {
        let per_field_term_frequencies = PerFieldTermFrequencies::build(
            searcher.schema(),
            searcher.index().tokenizers(),
            self.stop_word_filter.as_ref().map(|x| &x.0),
            &self.doc,
        );
        let num_docs = searcher
            .segment_readers()
            .iter()
            .map(|x| x.num_docs() as u64)
            .sum::<u64>();

        // Score each candidate term by `tf * idf`, dropping terms that fail
        // the configured frequency and length filters.
        let mut scored_terms = BinaryHeap::new();
        for (term, term_freq) in per_field_term_frequencies.0.into_iter() {
            if self.min_term_freq.map(|x| term_freq < x).unwrap_or(false) {
                continue;
            }
            let term_value_len = term.value_bytes().len();
            if self.min_word_len.map(|x| term_value_len < x).unwrap_or(false) {
                continue;
            }
            if self.max_word_len.map(|x| term_value_len > x).unwrap_or(false) {
                continue;
            }
            let doc_freq = searcher.doc_freq(&term);
            if self.min_doc_freq.map(|x| doc_freq < x).unwrap_or(false) {
                continue;
            }
            if self.max_doc_freq.map(|x| doc_freq > x).unwrap_or(false) {
                continue;
            }
            // Terms absent from the index cannot contribute matches.
            if doc_freq == 0 {
                continue;
            }
            let idf = idf(doc_freq, num_docs);
            let score = (term_freq as f32) * idf;
            scored_terms.push(ScoredTerm { term, score });
        }
        let top_score = scored_terms.peek().map(|x| x.score);
        // `into_sorted_vec` sorts ascending, so reverse it to put the
        // highest-scored terms first before truncating to `max_query_terms`
        // (draining the ascending vec would keep the *worst* terms).
        let mut scored_terms = scored_terms.into_sorted_vec();
        scored_terms.reverse();
        let scored_terms = if let Some(max_query_terms) = self.max_query_terms {
            let max_query_terms = std::cmp::min(max_query_terms, scored_terms.len());
            scored_terms.drain(..max_query_terms)
        } else {
            scored_terms.drain(..)
        };
        // Each surviving term becomes an optional (`Should`) term query,
        // boosted in proportion to its score relative to the best term.
        let mut sub_queries = Vec::new();
        for ScoredTerm { score, term } in scored_terms {
            let mut query: Box<dyn Query> =
                Box::new(TermQuery::new(term, IndexRecordOption::Basic));
            if self.boost {
                // `top_score` is `Some` whenever this loop runs, so the
                // unwrap cannot fail.
                query = Box::new(BoostQuery::new(
                    query,
                    score * self.boost_factor / top_score.unwrap(),
                ));
            }
            sub_queries.push((Occur::Should, query));
        }
        let query = BooleanQuery::from(sub_queries);
        query.weight(searcher, scoring_enabled)
    }
}

#[derive(Debug, Clone)]
pub struct MoreLikeThisQueryBuilder {
    min_term_freq: Option<usize>,
    min_doc_freq: Option<u64>,
    max_doc_freq: Option<u64>,
    max_query_terms: Option<usize>,
    min_word_len: Option<usize>,
    max_word_len: Option<usize>,
    boost: bool,
    boost_factor: f32,
    stop_word_filter: Option<StopWordFilterWrapper>,
}

impl Default for MoreLikeThisQueryBuilder {
    fn default() -> Self {
        // Defaults roughly mirror Lucene's `MoreLikeThis`: ignore terms that
        // appear in fewer than 5 documents, keep at most 25 query terms, and
        // filter stop words.
        Self {
            min_term_freq: None,
            min_doc_freq: Some(5),
            max_doc_freq: None,
            max_query_terms: Some(25),
            min_word_len: None,
            max_word_len: None,
            boost: true,
            boost_factor: 1.0,
            stop_word_filter: Some(StopWordFilterWrapper(StopWordFilter::default())),
        }
    }
}

impl MoreLikeThisQueryBuilder {
    pub fn with_min_term_freq(mut self, val: Option<usize>) -> Self {
        self.min_term_freq = val;
        self
    }

    pub fn with_min_doc_freq(mut self, val: Option<u64>) -> Self {
        self.min_doc_freq = val;
        self
    }

    pub fn with_max_doc_freq(mut self, val: Option<u64>) -> Self {
        self.max_doc_freq = val;
        self
    }

    pub fn with_max_query_terms(mut self, val: Option<usize>) -> Self {
        self.max_query_terms = val;
        self
    }

    pub fn with_min_word_len(mut self, val: Option<usize>) -> Self {
        self.min_word_len = val;
        self
    }

    pub fn with_max_word_len(mut self, val: Option<usize>) -> Self {
        self.max_word_len = val;
        self
    }

    pub fn with_boost(mut self, val: bool) -> Self {
        self.boost = val;
        self
    }

    pub fn with_boost_factor(mut self, val: f32) -> Self {
        self.boost_factor = val;
        self
    }

    pub fn with_stop_word_filter(mut self, val: Option<StopWordFilter>) -> Self {
        self.stop_word_filter = val.map(StopWordFilterWrapper);
        self
    }

    pub fn with_doc(self, val: Document) -> MoreLikeThisQuery {
        let MoreLikeThisQueryBuilder {
            min_term_freq,
            min_doc_freq,
            max_doc_freq,
            max_query_terms,
            min_word_len,
            max_word_len,
            boost,
            boost_factor,
            stop_word_filter,
        } = self;
        MoreLikeThisQuery {
            min_term_freq,
            min_doc_freq,
            max_doc_freq,
            max_query_terms,
            min_word_len,
            max_word_len,
            boost,
            boost_factor,
            stop_word_filter,
            doc: val,
        }
    }
}
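
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original gist): a minimal smoke test showing
// how the query might be wired up, assuming tantivy 0.12-era APIs
// (`Index::create_in_ram`, the `doc!` macro, `TopDocs`). Field names and
// sample text are illustrative only. `min_doc_freq` and `min_term_freq` are
// lowered to 1 because the default `min_doc_freq` of 5 would reject every
// term in an index this small.
#[cfg(test)]
mod tests {
    use super::MoreLikeThisQuery;
    use tantivy::collector::TopDocs;
    use tantivy::schema::{Schema, STORED, TEXT};
    use tantivy::{doc, Index};

    #[test]
    fn more_like_this_smoke_test() -> tantivy::Result<()> {
        let mut schema_builder = Schema::builder();
        let title = schema_builder.add_text_field("title", TEXT | STORED);
        let body = schema_builder.add_text_field("body", TEXT);
        let schema = schema_builder.build();

        let index = Index::create_in_ram(schema);
        let mut writer = index.writer(50_000_000)?;
        writer.add_document(doc!(
            title => "The Old Man and the Sea",
            body => "He was an old man who fished alone in a skiff"
        ));
        writer.add_document(doc!(
            title => "Of Mice and Men",
            body => "A few miles south of Soledad, the Salinas River drops"
        ));
        writer.commit()?;

        let reader = index.reader()?;
        let searcher = reader.searcher();

        // Terms are extracted from the target document, scored by tf * idf,
        // and the best ones are OR-ed together into a boolean query.
        let query = MoreLikeThisQuery::builder()
            .with_min_doc_freq(Some(1))
            .with_min_term_freq(Some(1))
            .with_doc(doc!(title => "old man", body => "an old man alone"));

        let top_docs = searcher.search(&query, &TopDocs::with_limit(5))?;
        assert!(!top_docs.is_empty());
        Ok(())
    }
}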