Created
February 27, 2023 19:58
-
-
Save wesslen/4a34ab1c25dbe1215c0f82ca8ab5c0ce to your computer and use it in GitHub Desktop.
Python script for Prodigy NER dataset viewer using Streamlit
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Example of a Streamlit app for an interactive Prodigy NER dataset viewer. | |
| Requires the Prodigy annotation tool to be installed: https://prodi.gy | |
| See here for details on Streamlit: https://streamlit.io. | |
| """ | |
| import streamlit as st | |
| from prodigy.components.db import connect | |
| from prodigy.models.ner import merge_spans | |
| import pandas as pd | |
| import spacy | |
| import datetime | |
| from spacy import displacy | |
| from spacy.util import filter_spans | |
# spaCy pipelines offered in the sidebar's "spaCy model" selector.
SPACY_MODEL_NAMES = ["en_core_web_sm"]

# Task fields named as exclusions (not referenced by the code visible here).
EXC_FIELDS = ["meta", "priority", "score"]

# Scrollable, bordered container; ``{}`` is filled with the displaCy markup.
HTML_WRAPPER = (
    '<div style="overflow-x: auto; border: 1px solid #e6e9ef; '
    'border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">'
    '{}</div>'
)

# Entity highlight colors for accepted and rejected spans respectively.
COLOR_ACCEPT = "#93eaa1"
COLOR_REJECT = "#ff8f8e"
def keep_answers(examples):
    """Return the examples that carry an annotation decision.

    An example is kept when its ``"answer"`` value is ``"accept"``,
    ``"reject"`` or ``"ignore"``. Examples with a missing or different
    answer are dropped. Input order is preserved.

    Args:
        examples: Iterable of dict-like annotation tasks.

    Returns:
        A new list with the matching examples.
    """
    # The first entry used to be the literal "answer", which silently
    # discarded every accepted example.
    valid_answers = ("accept", "reject", "ignore")
    return [eg for eg in examples if eg.get("answer") in valid_answers]
def get_session_ids(examples):
    """Collect the distinct ``_session_id`` values of answered examples.

    Only examples whose ``"answer"`` is ``"accept"``, ``"reject"`` or
    ``"ignore"`` contribute. An answered example without a
    ``"_session_id"`` key contributes ``None``.

    Args:
        examples: Iterable of dict-like annotation tasks.

    Returns:
        A set of session ids (possibly containing ``None``).
    """
    answered = ["accept", "reject", "ignore"]
    return {
        task.get("_session_id")
        for task in examples
        if task.get("answer") in answered
    }
def get_answer_counts(examples):
    """Count how many examples carry each ``"answer"`` value.

    The result always contains the keys ``"accept"``, ``"reject"`` and
    ``"ignore"`` (starting at 0, in that order). Any other truthy answer
    value found in the data is added as an extra key instead of raising
    ``KeyError``. Examples with a missing or falsy answer are not counted.

    Args:
        examples: Iterable of dict-like annotation tasks.

    Returns:
        A dict mapping answer value to its number of occurrences.
    """
    result = {"accept": 0, "reject": 0, "ignore": 0}
    for eg in examples:
        answer = eg.get("answer")
        if answer:
            # .get() keeps unexpected answer values from crashing the count.
            result[answer] = result.get(answer, 0) + 1
    return result
def format_label(label, answer="accept"):
    """Return the display label for an entity, marking rejected ones.

    Args:
        label: The entity label.
        answer: The annotation decision for the span. Only the value
            ``"reject"`` alters the label.

    Returns:
        ``label`` unchanged, or ``label`` followed by a zero-width space
        when ``answer`` is ``"reject"``.
    """
    if answer != "reject":
        return label
    # Hack: the invisible suffix makes the rejected label a distinct key,
    # so it can be mapped to a different color.
    return f"{label}\u200B"
def get_time_utc(timestamp_str):
    """Convert a Unix timestamp in seconds to a UTC ``datetime``.

    Args:
        timestamp_str: Seconds since the epoch, as an int or anything
            ``int()`` accepts (e.g. a digit string).

    Returns:
        A timezone-aware ``datetime.datetime`` in UTC.

    Raises:
        ValueError: If ``timestamp_str`` cannot be converted by ``int()``.
    """
    # Without an explicit tz, fromtimestamp() returns the machine's *local*
    # time, which contradicts this function's name.
    return datetime.datetime.fromtimestamp(
        int(timestamp_str), tz=datetime.timezone.utc
    )
# --- App setup: connect to the Prodigy database and pick a dataset ---------
st.sidebar.title("Prodigy Data Explorer")
db = connect()
db_sets = db.datasets
placeholder = "Select dataset..."
dataset = st.sidebar.selectbox(f"Datasets ({len(db_sets)})", [placeholder] + db_sets)
st.title("Prodigy's annotated datasets")

# --- Overview: one summary row per dataset while none is selected ----------
if dataset == placeholder:
    db_counts = []
    for dbs in db_sets:
        # `or []` tolerates a falsy result so the helpers below still work.
        examples = db.get_dataset(dbs) or []
        # An empty dataset has no first example to read a timestamp from.
        first_timestamp = examples[0].get("_timestamp") if examples else None
        db_counts.append(
            {
                "db": dbs,
                "size": db.count_dataset(dbs),
                "tasks": len(set(db.get_task_hashes(dbs))),
                "inputs": len(set(db.get_input_hashes(dbs))),
                "session_ids": len(get_session_ids(examples)),
                "timestamp": (
                    "No datetime"
                    if first_timestamp is None
                    else get_time_utc(first_timestamp)
                ),
            }
        )
    # Naming the columns keeps sort_values() valid when there are no datasets.
    overview_columns = ["db", "size", "tasks", "inputs", "session_ids", "timestamp"]
    df = pd.DataFrame(db_counts, columns=overview_columns)
    st.dataframe(df.sort_values(by=["size"], ascending=False), height=500)
# --- Detail view: table of examples plus an interactive NER viewer ---------
if dataset != placeholder:
    # `or []` tolerates a falsy result so len() below cannot fail.
    examples = db.get_dataset(dataset) or []
    st.header(f"{dataset} ({len(examples)})")
    if not len(examples):
        st.markdown("_Empty dataset._")
    else:
        counts = get_answer_counts(examples)
        st.markdown(", ".join(f"**{c}** {a}" for a, c in counts.items()))

        # Table of raw examples, restricted to the fields chosen in the sidebar.
        fields = list(examples[0].keys())
        task_fields = st.sidebar.multiselect("Visible fields", fields, fields)
        df_eg = pd.DataFrame(examples).filter(task_fields)
        # The column is absent when the user deselects it or the data lacks it.
        if "_timestamp" in df_eg.columns:
            df_eg["_timestamp"] = pd.to_datetime(df_eg["_timestamp"], unit="s")
        st.dataframe(df_eg, height=500)

        spacy_model = st.sidebar.selectbox("spaCy model", SPACY_MODEL_NAMES)

        # Compute the session ids once. `session_id` stays None when there are
        # none, so the code below never hits an unbound name.
        session_ids = get_session_ids(examples)
        session_id = None
        if session_ids:
            # Sorted list gives a stable option order; key=str copes with None.
            session_id = st.sidebar.selectbox(
                "Select Session ID", sorted(session_ids, key=str)
            )

        st.subheader("Named entity viewer")
        nlp = spacy.load(spacy_model)
        merged_examples = merge_spans(list(examples))

        # Labels come from the selected session, or from every example when
        # the dataset has no session ids at all.
        all_labels = set()
        for eg in merged_examples:
            if not session_ids or eg.get("_session_id") == session_id:
                for span in eg.get("spans", []):
                    all_labels.add(span["label"])

        # Each label gets an "accept" color and, under its zero-width-space
        # variant from format_label(), a "reject" color.
        colors = {}
        for label in all_labels:
            colors[label] = COLOR_ACCEPT
            colors[format_label(label, "reject")] = COLOR_REJECT

        title = f"Merged examples ({len(merged_examples)})"
        if session_id is not None:
            title = title + " Session id = " + str(session_id)
        ner_example_i = st.selectbox(
            title,
            range(len(merged_examples)),
            # Keep only the first 400 characters, just in case.
            format_func=lambda i: merged_examples[int(i)]["text"][:400],
        )
        ner_example = merged_examples[int(ner_example_i)]

        doc = nlp.make_doc(ner_example["text"])
        ents = []
        for span in ner_example.get("spans", []):
            # Spans without their own answer are rendered as accepted.
            answer = span.get("answer")
            label = format_label(
                span["label"], answer if answer is not None else "accept"
            )
            ent = doc.char_span(span["start"], span["end"], label=label)
            # char_span() returns None for offsets that do not align with
            # token boundaries; filter_spans() cannot handle None entries.
            if ent is not None:
                ents.append(ent)
        doc.ents = filter_spans(ents)

        html = displacy.render(doc, style="ent", options={"colors": colors})
        html = html.replace("\n", " ")  # Newlines seem to mess with the rendering
        st.write(HTML_WRAPPER.format(html), unsafe_allow_html=True)

        show_ner_example_json = st.checkbox("Show JSON example")
        if show_ner_example_json:
            st.json(ner_example)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment