Created
July 21, 2019 00:40
-
-
Save allanj/505ec8ee873570cc53df6127daec6fda to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st
# NOTE(review): numpy and pandas are not referenced anywhere in the visible
# script; they appear to be leftovers from the Streamlit tutorial.
import numpy
import pandas

# Page title and a one-line description of the demo.
st.title("Demo Test")
st.markdown("**AllenNLP Reading Comprehension demo**")

# Editable passage that the reading-comprehension model will read,
# pre-filled with a short example sentence.
st.markdown('**Passage**:')
passage_pretext = 'The Matrix is a 1999 science fiction action film.'
passage = st.text_area('Passage Input:', passage_pretext)
import allennlp
from allennlp.predictors.predictor import Predictor

# Pretrained BiDAF reading-comprehension model archive used by this demo.
BIDAF_MODEL_URL = (
    "https://s3-us-west-2.amazonaws.com/allennlp/models/"
    "bidaf-model-2017.09.15-charpad.tar.gz"
)


@st.cache
def load_predictor(model_url=BIDAF_MODEL_URL):
    """Build an AllenNLP predictor from the model archive at *model_url*.

    Wrapped in ``st.cache`` so Streamlit memoizes the result and the model
    is not rebuilt on every rerun of the script. *model_url* defaults to the
    BiDAF archive the demo was written for; pass another archive path or URL
    to try a different model.
    """
    return Predictor.from_path(model_url)


predictor = load_predictor()
# Question to ask about the passage; the default matches the sample passage.
q = "When is the Matrix written?"
question = st.text_input('question:', q)

# Run the model, then extract the predicted answer both as text and as a
# (start, end) token-index span into the passage tokens.
answer = predictor.predict(passage=passage, question=question)
best_span, best_span_idx = answer["best_span_str"], answer["best_span"]

# Show the full raw prediction, followed by just the answer text.
st.write(answer)
st.write(best_span)
def in_span(pos: int, idxs):
    """Return whether token position *pos* falls inside the span *idxs*.

    *idxs* is indexed as ``idxs[0]`` (start) and ``idxs[1]`` (end); both
    ends are inclusive.
    """
    return bool(idxs[0] <= pos <= idxs[1])
# Re-render the passage with the tokens of the predicted answer span in bold.
# The list gets a descriptive name: binding it to ``str`` (as before) would
# shadow the built-in ``str`` for the rest of the script.
highlighted_tokens = [
    "**" + token + "**" if in_span(pos, best_span_idx) else token
    for pos, token in enumerate(answer["passage_tokens"])
]
st.markdown(" ".join(highlighted_tokens))
import matplotlib.pyplot as plt
import numpy as np

# Heat map of the model's passage-question attention weights.
# NOTE(review): presumably rows index passage tokens and columns index
# question tokens; confirm against the AllenNLP BiDAF output spec.
attention_array = np.asarray(answer["passage_question_attention"])
fig, ax = plt.subplots()
im = ax.imshow(attention_array)
# A colorbar makes the heat-map intensities readable.
fig.colorbar(im, ax=ax)
# Pass the figure explicitly; argument-less st.pyplot() relies on pyplot's
# implicit global figure.
st.pyplot(fig)
# Streamlit reruns the whole script on every interaction. Close the figure
# so pyplot does not keep one more open figure alive per rerun.
plt.close(fig)
# Second demo section: address parsing.
st.header("Address Parsing Demo")
# Default sample input: a Chinese address ("Guangzhou, Guangdong Province").
default_address = '广东省广州市'
address = st.text_input('address:', value=default_address)
# NOTE(review): `address` is not used by the visible code; the parsing step
# appears to be unfinished.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment