
@koaning
Last active November 1, 2022 11:39
DIET Benchmarks

readme

This gist contains the code to repeat the steps in the DIET benchmarking YouTube video. You can download all the files by cloning this gist:

git clone git@gist.github.com:81fc9433182ccfb9dece4bb4dbde1f7a.git

You'll also need to clone the rasa-demo repository to get the dataset. You can clone it via:

git clone git@github.com:RasaHQ/rasa-demo.git

You will also need to install the BERT dependencies if you want to run the heavy model (configs/diet-heavy.yml).

pip install "rasa[transformers]"

Once that is done, you can repeat everything we've done here by running:

mkdir results
rasa test nlu --config configs/config-orig.yml --cross-validation --runs 1 --folds 2 --out results/config-orig
rasa test nlu --config configs/config-init.yml --cross-validation --runs 1 --folds 2 --out results/config-init
rasa test nlu --config configs/diet-replace.yml --cross-validation --runs 1 --folds 2 --out results/diet-replace
rasa test nlu --config configs/diet-minimum.yml --cross-validation --runs 1 --folds 2 --out results/diet-minimum
rasa test nlu --config configs/diet-heavy.yml --cross-validation --runs 1 --folds 2 --out results/diet-heavy
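The five runs above differ only in the config name, so they can also be scripted. The following is a minimal sketch, not part of the original gist, that assumes `rasa` is on your `PATH` and that you start it from the same directory as the commands above; `build_command`, `run_benchmarks` and `dry_run` are hypothetical names.

```python
import subprocess
from pathlib import Path

# Config names from the commands above; each one maps to configs/<name>.yml
CONFIGS = ("config-orig", "config-init", "diet-replace", "diet-minimum", "diet-heavy")


def build_command(name, runs=1, folds=2):
    """Return the `rasa test nlu` cross-validation command for one config as an argument list."""
    return [
        "rasa", "test", "nlu",
        "--config", f"configs/{name}.yml",
        "--cross-validation",
        "--runs", str(runs),
        "--folds", str(folds),
        "--out", f"results/{name}",
    ]


def run_benchmarks(configs=CONFIGS, dry_run=True):
    """Print every command and, unless dry_run is True, execute them in order."""
    commands = [build_command(name) for name in configs]
    if not dry_run:
        Path("results").mkdir(exist_ok=True)  # equivalent of `mkdir results`
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError if a run fails
            subprocess.run(cmd, check=True)
    return commands
```

Calling `run_benchmarks()` only prints the commands; `run_benchmarks(dry_run=False)` creates `results/` and runs them one after another, stopping at the first failure because of `check=True`.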

Once done, you can use Streamlit to see a dashboard of the results.

pip install streamlit
streamlit run viewresults.py

language: en
pipeline:
- name: WhitespaceTokenizer
- name: CountVectorsFeaturizer
- name: EmbeddingIntentClassifier
policies:
- name: EmbeddingPolicy
  max_history: 10
  epochs: 100
  batch_size:
  - 32
  - 64
- max_history: 6
  name: AugmentedMemoizationPolicy
- core_threshold: 0.3
  name: TwoStageFallbackPolicy
  nlu_threshold: 0.8
- name: FormPolicy
- name: MappingPolicy

language: en
pipeline:
- name: WhitespaceTokenizer
- name: CRFEntityExtractor
- name: CountVectorsFeaturizer
  OOV_token: oov
  token_pattern: (?u)\b\w+\b
- name: CountVectorsFeaturizer
  analyzer: char_wb
  min_ngram: 1
  max_ngram: 4
- name: EmbeddingIntentClassifier
  epochs: 100
  ranking_length: 5
- name: DucklingHTTPExtractor
  url: http://localhost:8000
  dimensions:
  - email
  - number
  - amount-of-money
- name: EntitySynonymMapper
policies:
- name: EmbeddingPolicy
  max_history: 10
  epochs: 20
  batch_size:
  - 32
  - 64
- max_history: 6
  name: AugmentedMemoizationPolicy
- core_threshold: 0.3
  name: TwoStageFallbackPolicy
  nlu_threshold: 0.8
- name: FormPolicy
- name: MappingPolicy

language: en
pipeline:
- name: HFTransformersNLP
  model_weights: "bert-base-uncased"
  model_name: "bert"
- name: LanguageModelTokenizer
- name: LanguageModelFeaturizer
- name: CountVectorsFeaturizer
  analyzer: char_wb
  min_ngram: 1
  max_ngram: 4
- name: CountVectorsFeaturizer
- name: DIETClassifier
  epochs: 100
  num_transformer_layers: 4
  transformer_size: 256
  use_masked_language_model: True
  drop_rate: 0.25
  weight_sparsity: 0.7
  batch_size: [64, 256]
  embedding_dimension: 30
  hidden_layer_sizes:
    text: [512, 128]
policies:
- name: EmbeddingPolicy
  max_history: 10
  epochs: 20
  batch_size:
  - 32
  - 64
- max_history: 6
  name: AugmentedMemoizationPolicy
- core_threshold: 0.3
  name: TwoStageFallbackPolicy
  nlu_threshold: 0.8
- name: FormPolicy
- name: MappingPolicy

language: en
pipeline:
- name: WhitespaceTokenizer
- name: CountVectorsFeaturizer
- name: CountVectorsFeaturizer
  analyzer: char_wb
  min_ngram: 1
  max_ngram: 4
- name: DIETClassifier
  epochs: 20
  learning_rate: 0.005
  num_transformer_layers: 0
  embedding_dimension: 10
  weight_sparsity: 0.90
  hidden_layer_sizes:
    text: [256, 128]
policies:
- name: EmbeddingPolicy
  max_history: 10
  epochs: 100
  batch_size:
  - 32
  - 64
- max_history: 6
  name: AugmentedMemoizationPolicy
- core_threshold: 0.3
  name: TwoStageFallbackPolicy
  nlu_threshold: 0.8
- name: FormPolicy
- name: MappingPolicy

language: en
pipeline:
- name: WhitespaceTokenizer
- name: LexicalSyntacticFeaturizer
- name: CountVectorsFeaturizer
  OOV_token: oov
  token_pattern: (?u)\b\w+\b
- name: CountVectorsFeaturizer
  analyzer: char_wb
  min_ngram: 1
  max_ngram: 4
- name: DIETClassifier
  epochs: 100
  ranking_length: 5
  use_masked_language_model: True
- name: DucklingHTTPExtractor
  url: http://localhost:8000
  dimensions:
  - email
  - number
  - amount-of-money
- name: EntitySynonymMapper
policies:
- name: EmbeddingPolicy
  max_history: 10
  epochs: 20
  batch_size:
  - 32
  - 64
- max_history: 6
  name: AugmentedMemoizationPolicy
- core_threshold: 0.3
  name: TwoStageFallbackPolicy
  nlu_threshold: 0.8
- name: FormPolicy
- name: MappingPolicy

language: en
pipeline:
- name: WhitespaceTokenizer
- name: LexicalSyntacticFeaturizer
- name: CountVectorsFeaturizer
  OOV_token: oov
  token_pattern: (?u)\b\w+\b
- name: CountVectorsFeaturizer
  analyzer: char_wb
  min_ngram: 1
  max_ngram: 4
- name: DIETClassifier
  epochs: 100
  ranking_length: 5
- name: DucklingHTTPExtractor
  url: http://localhost:8000
  dimensions:
  - email
  - number
  - amount-of-money
- name: EntitySynonymMapper
policies:
- name: EmbeddingPolicy
  max_history: 10
  epochs: 20
  batch_size:
  - 32
  - 64
- max_history: 6
  name: AugmentedMemoizationPolicy
- core_threshold: 0.3
  name: TwoStageFallbackPolicy
  nlu_threshold: 0.8
- name: FormPolicy
- name: MappingPolicy

# viewresults.py
# To run this, make sure you've got the dependencies:
# pip install streamlit altair pandas
import json
import pathlib

import altair as alt
import pandas as pd
import streamlit as st


def read_report(path):
    """Return the 'weighted avg' row of a Rasa report, tagged with its results folder name."""
    blob = json.loads(path.read_text())
    # path looks like results/<config>/<name>_report.json, so parts[1] is the config name
    jsonl = [{**v, 'config': path.parts[1]} for k, v in blob.items() if 'weighted avg' in k]
    return pd.DataFrame(jsonl).drop(columns=['support'])


def add_zeros(dataf, all_configs):
    """Add an all-zero row for every config that has no entry in dataf."""
    for cfg in all_configs:
        if cfg not in list(dataf['config']):
            dataf = pd.concat([dataf, pd.DataFrame({'precision': [0],
                                                    'recall': [0],
                                                    'f1-score': [0],
                                                    'config': [cfg]})],
                              ignore_index=True)
    return dataf


@st.cache_data  # on Streamlit versions older than 1.18, use @st.cache instead
def read_pandas():
    results = pathlib.Path("results")
    paths = list(results.glob("*/*_report.json"))
    configurations = {p.parts[1] for p in paths}
    intent_df = pd.concat([read_report(p) for p in paths if 'intent_report' in str(p)],
                          ignore_index=True)
    # Configs without an entity extractor have no entity report; add_zeros fills them in
    entity_paths = list(results.glob("*/CRFEntityExtractor_report.json"))
    entity_paths += list(results.glob("*/DIETClassifier_report.json"))
    entity_df = (pd.concat([read_report(p) for p in entity_paths], ignore_index=True)
                 .pipe(add_zeros, all_configs=configurations))
    return intent_df, entity_df


intent_df, entity_df = read_pandas()
possible_configs = sorted(intent_df['config'].unique())

st.markdown("# Rasa Grid Results Summary")
st.markdown("Quick Overview of Cross-Validated Runs")
st.sidebar.markdown("### Configure Overview")
st.sidebar.markdown("Select what you care about.")
selected_config = st.sidebar.multiselect("Select Result Folders",
                                         possible_configs,
                                         default=possible_configs)
show_raw_data = st.sidebar.checkbox("Show Raw Data")

# Long format (config, variable, value) so Altair draws one row of bars per metric
subset_df = intent_df.loc[lambda d: d['config'].isin(selected_config)].melt('config')
st.markdown("## Intent Summary Overview")
c = alt.Chart(subset_df).mark_bar().encode(
    y='config:N',
    x='value:Q',
    color='config:N',
    row='variable:N'
)
st.altair_chart(c)
if show_raw_data:
    st.write(intent_df.loc[lambda d: d['config'].isin(selected_config)])

subset_df = entity_df.loc[lambda d: d['config'].isin(selected_config)].melt('config')
st.markdown("## Entity Summary Overview")
c = alt.Chart(subset_df).mark_bar().encode(
    y='config:N',
    x='value:Q',
    color='config:N',
    row='variable:N'
)
st.altair_chart(c)
if show_raw_data:
    st.write(entity_df.loc[lambda d: d['config'].isin(selected_config)])
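The dashboard reduces each report to its `weighted avg` row, tags it with the name of its results folder, and gives configs without an entity report a row of zeros. The stdlib-only sketch below mirrors that logic without pandas so it can be checked in isolation. The function names are hypothetical, and it assumes the `results/<config>/<name>_report.json` layout produced by the commands in the readme. It uses `path.parent.name` instead of `path.parts[1]`, so it also works when the results folder is given as an absolute path, and it zero-fills any requested report kind (the original does this only for entity reports).

```python
import json
from pathlib import Path

METRICS = ("precision", "recall", "f1-score")


def weighted_avg_rows(path):
    """Return the 'weighted avg' metrics of one report, tagged with its config folder."""
    path = Path(path)
    blob = json.loads(path.read_text())
    return [
        {**{m: v[m] for m in METRICS}, "config": path.parent.name}
        for k, v in blob.items()
        if "weighted avg" in k  # same filter as viewresults.py
    ]


def collect(results_dir, report_names):
    """Gather weighted-avg rows for the given report files; zero-fill configs that lack them."""
    results_dir = Path(results_dir)
    # Every folder with any *_report.json counts as a config, as in read_pandas()
    all_configs = sorted({p.parent.name for p in results_dir.glob("*/*_report.json")})
    rows = []
    for name in report_names:
        for p in sorted(results_dir.glob(f"*/{name}")):
            rows.extend(weighted_avg_rows(p))
    seen = {r["config"] for r in rows}
    rows += [{**{m: 0 for m in METRICS}, "config": c} for c in all_configs if c not in seen]
    return rows


def melt(rows):
    """Long (config, variable, value) records, like DataFrame.melt('config') up to row order."""
    return [{"config": r["config"], "variable": m, "value": r[m]} for r in rows for m in METRICS]
```

For example, `collect("results", ["CRFEntityExtractor_report.json", "DIETClassifier_report.json"])` plays the role of `entity_df`, and `collect("results", ["intent_report.json"])` that of `intent_df`.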

koaning commented Jan 4, 2022

You can adapt the Streamlit file to pick up other files that have been generated by Rasa. Note that this code assumes Rasa 1.x.
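Along the lines of this suggestion, a useful first step is to see which report files Rasa actually wrote for each run. The sketch below is a hypothetical helper (`report_kinds` is not part of the gist) that assumes the `results/<config>/<name>_report.json` layout used by `viewresults.py`.

```python
from collections import defaultdict
from pathlib import Path


def report_kinds(results_dir="results"):
    """Map each report file name to the sorted list of config folders that contain it."""
    kinds = defaultdict(list)
    # Sorting the full paths keeps the config folders in order within each report kind
    for p in sorted(Path(results_dir).glob("*/*_report.json")):
        kinds[p.name].append(p.parent.name)
    return dict(kinds)
```

Running `report_kinds()` from the directory that holds `results/` shows, for instance, whether a folder contains a `DIETClassifier_report.json` at all; any of the returned names can then replace the hard-coded file names in `read_pandas`, or be offered in the dashboard through `st.sidebar.selectbox`.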

@reynoldms

I changed "DIETClassifier_report" to "intent_report" in the viewresults.py file and it works. Thanks @koaning.
