Streamlit script that caches the loading of a FAISS index (live at https://news-search.veritable.pw)
import os
import sqlite3
import datetime
from typing import List

import faiss
import numpy as np
import pandas as pd
import joblib
import requests
import streamlit as st

# Disable HuggingFace tokenizers parallelism to avoid fork-related warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Endpoint of the external embedding service (see get_embeddings below).
api_uri = os.environ.get("API_URI", "http://localhost:8666/")


# Cached loader: the SQLite connection, document IDs, and FAISS index are read
# once per process. allow_output_mutation=True stops Streamlit from hashing the
# returned objects on every rerun (the connection and index are not hashable).
@st.cache(allow_output_mutation=True)
def load_data():
    conn = sqlite3.connect("data/news.sqlite")
    full_ids = joblib.load("data/ids.jbl")
    index = faiss.read_index("data/index.faiss")
    default_date_range = [datetime.date(2018, 11, 28), datetime.date.today()]
    return conn, full_ids, index, default_date_range


def fetch_entries(conn: sqlite3.Connection, ids: List, date_range, scores):
    # Fetch the matched entries within the date range and attach FAISS scores.
    cur = conn.cursor()
    cur.execute(
        "SELECT id, date, title, desc FROM entries " +
        "WHERE id IN ({seq}) ".format(
            seq=','.join(['?']*len(ids))
        ) +
        " AND date >= ? AND date <= ?;",
        ids + [x.isoformat() for x in date_range]
    )
    results = pd.DataFrame(
        cur.fetchall(), columns=["id", "date", "title", "desc"]
    )
    results["date"] = pd.to_datetime(results["date"])
    # Map each id back to its similarity score and sort by relevance.
    score_dict = {key: score for key, score in zip(ids, scores)}
    results["score"] = results["id"].apply(lambda x: score_dict[x])
    results.sort_values("score", ascending=False, inplace=True)
    return results


def get_latest_date(conn: sqlite3.Connection):
    # Most recent article date in the database (shown in the footer).
    cur = conn.cursor()
    cur.execute("SELECT MAX(date) FROM entries;")
    return cur.fetchone()[0]


def get_embeddings(text: str):
    # Request a sentence embedding from the external API and shape it into a
    # (1, dim) float32 array for FAISS.
    response = requests.post(api_uri, json={"text": text})
    assert response.status_code == 200, response.text
    return np.asarray(response.json()["vector"])[np.newaxis, :].astype("float32")


def main():
    st.title('Veritable News Semantic Search Engine')
    query = st.text_area(
        "Context/主題 (length > 10)", "", max_chars=256
    ).replace("\n", " ")
    conn, full_ids, index, default_date_range = load_data()
    # Quick presets for the start of the date range.
    if st.button("Last 90 days"):
        default_date_range[0] = datetime.date.today() - \
            datetime.timedelta(days=90)
    if st.button("From the start"):
        default_date_range[0] = datetime.date(2018, 11, 28)
    date_range = st.date_input(
        'Date Range/日期範圍',
        value=default_date_range,
        min_value=datetime.date(2018, 11, 28),
        max_value=datetime.date.today() + datetime.timedelta(days=1)
    )
    sort_method = st.selectbox(
        "Sort by:", ("relevance", "date (desc)", "date (asc)"))
    if len(query) > 10 and len(date_range) == 2:
        # Embed the query, L2-normalize it, and retrieve the 100 nearest
        # articles from the FAISS index; only the top 20 are displayed.
        embs = get_embeddings(query)
        faiss.normalize_L2(embs)
        scores, index_matches = index.search(embs, k=100)
        df_entries = fetch_entries(
            conn, [full_ids[i] for i in index_matches[0]],
            date_range, scores[0]
        ).iloc[:20]
        if sort_method != "relevance":
            df_entries.sort_values(
                "date", ascending=sort_method == "date (asc)", inplace=True
            )
        # Each id encodes "<YYYYMMDD>_<section>_<num>", which also forms the article URL.
        for row in df_entries.values:
            date, section, num = row[0].split("_")
            st.write(
                f"{date[:4]}/{date[4:6]}/{date[6:8]} "
                f"[{row[2]}](https://news.veritable.pw/zh/piece/{date}/{section}_{num}) (score: {row[-1]:.4f})"
            )
    latest_date = get_latest_date(conn)
    st.write(
        f"Data updated on _{latest_date}_; Engine updated on _2021-03-07_")
    st.write("_This app uses a TinyBERT-4L model to reduce hardware requirements. If you're interested in bigger and more powerful models, please e-mail **ceshine at veritable.pw**_")


if __name__ == "__main__":
    main()
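The script expects a companion embedding service at API_URI that accepts a POST body of the form {"text": ...} and returns {"vector": [...]} (see get_embeddings above). That service is not part of this gist; the sketch below is one hypothetical way to serve it with FastAPI, where encode() is a placeholder for whatever sentence-embedding model is used behind the live demo.

# Hypothetical embedding service compatible with get_embeddings() above.
# Assumptions: FastAPI is installed; encode() is a stand-in for the real
# sentence-embedding model, which is not included in this gist.
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class Query(BaseModel):
    text: str


def encode(text: str) -> List[float]:
    # Placeholder: replace with the actual model's inference call.
    raise NotImplementedError


@app.post("/")
def embed(query: Query):
    # get_embeddings() expects a JSON body {"vector": [...]} holding a 1-D embedding.
    return {"vector": encode(query.text)}

The embedding dimension must match the FAISS index, and since the app L2-normalizes the query vector before searching, the index is presumably built for inner-product (cosine) similarity. The Streamlit app itself runs with `streamlit run` once data/news.sqlite, data/ids.jbl, and data/index.faiss are in place.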