Created
February 23, 2017 05:57
-
-
Save Linusp/fc70d8a3ba5344ebeffa5075d6e21347 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import re | |
| import pickle | |
| from math import sqrt | |
| from operator import itemgetter | |
| from collections import defaultdict, Counter | |
| import click | |
# Matches a run of one or more punctuation characters: the listed ASCII
# punctuation, the General Punctuation block (U+2000-U+206F) and the CJK
# Symbols and Punctuation block (U+3000-U+303F).
# The group is non-capturing so that findall()/search() report the whole run
# rather than only the last character captured by a repeated group.
PUNCTS_PAT = re.compile(
    r'(?:[#\$&@.,;:!?\'`"~_\+\-\*\/\\|\\^=<>\[\]\(\)\{\}]|'
    r'[\u2000-\u206f]|'
    r'[\u3000-\u303f])+'
)
def make_ngram(text, gram_level=2):
    """Return the character n-grams of ``text`` that contain no punctuation.

    ``text`` is first wrapped in a leading ``'^'`` and a trailing ``'$'``
    boundary marker, then every window of ``gram_level`` consecutive
    characters is considered in order. A window is dropped when, after
    stripping ``'^'``/``'$'`` from both of its ends, ``PUNCTS_PAT`` matches
    anywhere in what is left.

    Args:
        text: The string to split into n-grams.
        gram_level: Number of characters per n-gram.

    Returns:
        A list of the surviving n-grams, in order of appearance.
    """
    padded = '^' + text + '$'
    window_count = len(padded) - gram_level + 1
    windows = (
        padded[start:start + gram_level] for start in range(window_count)
    )
    return [
        window for window in windows
        if not PUNCTS_PAT.findall(window.strip('^$'))
    ]
def clean(text):
    """Return ``text`` with leading and trailing whitespace removed."""
    stripped = text.strip()
    return stripped
def cos_similarity(term_freq_a, term_freq_b):
    """Return the cosine similarity of two term-frequency mappings.

    Each mapping is treated as a sparse vector keyed by term. Terms missing
    from a mapping count as frequency 0.

    Args:
        term_freq_a: Mapping of term -> frequency.
        term_freq_b: Mapping of term -> frequency.

    Returns:
        The inner product divided by the product of the two Euclidean norms,
        or ``0`` when either vector has zero norm (e.g. an empty mapping).
    """
    norm_a = sum(freq ** 2 for freq in term_freq_a.values())
    norm_b = sum(freq ** 2 for freq in term_freq_b.values())
    if norm_a == 0 or norm_b == 0:
        return 0

    # .get() instead of indexing: a plain dict must not raise KeyError for a
    # term it lacks, and a defaultdict must not grow as a side effect.
    inner_product = sum(
        freq * term_freq_b.get(term, 0)
        for term, freq in term_freq_a.items()
    )
    return inner_product / sqrt(norm_a * norm_b)
class NgramInvIndex(object):
    """Inverted index over the character n-grams of questions.

    Questions are mapped to their sets of answers. Every distinct question is
    split into n-grams with ``make_ngram`` and each n-gram points back to the
    ids of the questions containing it. ``retrieve`` uses that index to find
    candidate questions and ranks them by cosine similarity to the query.
    """

    def __init__(self):
        """Create an empty index."""
        # (question, n-gram Counter) pairs; the list position is the id
        # stored in ``_index``.
        self._id2q = []
        # question -> set of answers
        self._qa = defaultdict(set)
        # n-gram -> set of ids of the questions containing it
        self._index = defaultdict(set)
        # True once an index has been built or loaded
        self._inited = False

    def build_from_corpus(self, corpus):
        """Add question/answer pairs and (re)build the n-gram index.

        Args:
            corpus: Iterable of ``(question, answer)`` pairs. Both strings are
                passed through ``clean``; a pair is skipped when either
                string is empty afterwards.
        """
        for ques, ans in corpus:
            ques = clean(ques)
            ans = clean(ans)
            if not ques or not ans:
                continue
            self._qa[ques].add(ans)

        # Ids are positions in ``_id2q``. Rebuild both structures from all
        # known questions so a repeated call cannot append duplicates whose
        # ids disagree with their list positions.
        self._id2q = []
        self._index = defaultdict(set)
        for idx, ques in enumerate(self._qa):
            term_freq = Counter(make_ngram(ques))
            self._id2q.append((ques, term_freq))
            for term in term_freq:
                self._index[term].add(idx)
        self._inited = True

    def build_from_file(self, fname, sep='\t'):
        """Build the index from a text file with one pair per line.

        Args:
            fname: Path of the file to read.
            sep: Separator between question and answer on each line. The line
                is split on its first occurrence only.

        Lines that do not contain ``sep`` (including blank lines) are skipped.
        """
        corpus = []
        with open(fname) as infile:
            for line in infile:
                # Strip only the line terminator here so a leading/trailing
                # separator survives; build_from_corpus() trims the fields.
                fields = line.rstrip('\r\n').split(sep, 1)
                if len(fields) != 2:
                    continue
                corpus.append(fields)
        self.build_from_corpus(corpus)

    def retrieve(self, query, k=10):
        """Return the ``k`` indexed questions most similar to ``query``.

        Args:
            query: Free-text query string.
            k: Maximum number of results.

        Returns:
            A list of ``(question, answers, score)`` tuples sorted by
            descending cosine similarity, where ``answers`` is the set of
            answers stored for that question. Only questions sharing at least
            one n-gram with the query are considered.
        """
        term_freq = Counter(make_ngram(clean(query)))

        related = set()
        for term in term_freq:
            # .get() so that unknown n-grams do not create index entries
            related.update(self._index.get(term, ()))

        res = []
        for qid in related:
            ques, ques_tf = self._id2q[qid]
            score = cos_similarity(ques_tf, term_freq)
            res.append((ques, self._qa[ques], score))
        return sorted(res, key=itemgetter(2), reverse=True)[:k]

    def dump(self, fname):
        """Serialize the index to ``fname`` with pickle."""
        with open(fname, 'wb') as outfile:
            pickle.dump((self._id2q, self._qa, self._index), outfile)

    def load(self, fname):
        """Replace the current index with one written by ``dump``.

        WARNING: unpickling can execute arbitrary code; only load index files
        from a trusted source.
        """
        with open(fname, 'rb') as infile:
            self._id2q, self._qa, self._index = pickle.load(infile)
        self._inited = True
@click.group()
def cli():
    # Root click command group; the subcommands below register themselves on
    # it through @cli.command(). The group itself does nothing.
    pass
@cli.command()
@click.option("-i", "--infile", required=True)
@click.option("-o", "--output", required=True)
def build(infile, output):
    # Build an n-gram inverted index from the corpus file `infile` and
    # serialize it to the path `output`.
    index = NgramInvIndex()
    index.build_from_file(infile)
    index.dump(output)
@cli.command()
@click.option("-i", "--index", required=True)
@click.option("-l", "--limit", type=int, default=3)
def qa(index, limit):
    # Interactive query loop: load the index stored at `index`, then read
    # queries from stdin and print up to `limit` matches for each one as
    # "question answers score". Typing "exit" ends the session.
    invindex = NgramInvIndex()
    invindex.load(index)
    while True:
        try:
            query = input(">> ").strip()
        except EOFError:
            # input() raises EOFError once stdin is exhausted (Ctrl-D or the
            # end of piped input); treat that as a normal end of session.
            print()
            break
        if query == 'exit':
            break
        if not query:
            # Nothing to look up; prompt again.
            continue
        for ques, ans, score in invindex.retrieve(query, k=limit):
            print(ques, ans, score)
# Run the click command group only when executed as a script, not on import.
if __name__ == '__main__':
    cli()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment