# Rasa Technologies GmbH
# Copyright 2021 Rasa Technologies GmbH
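"""Collect data for the Rasa enterprise scorecard.

Gathers scores from the local assistant repo (training data totals, test
coverage, fallback configuration) and, optionally, from the Rasa Enterprise
API (conversation counts, NLU insight warnings), then writes them to a YAML
scores file and prints a shareable scorecard link.
"""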
import argparse
import os
from typing import Text, Any, Dict, List, Optional, Tuple, Union
import shutil
from subprocess import call, STDOUT
import requests
import logging
import questionary
from urllib.parse import urljoin
import datetime
from base64 import b64encode
from pathlib import Path
from itertools import groupby
from rasa.shared.utils.cli import print_success, print_error_and_exit, print_info
from rasa.shared.utils.io import read_yaml, read_yaml_file, write_yaml
from rasa.shared.data import get_data_files
# from rasa.nlu.config import load as load_nlu_config
from rasa.shared.core.constants import DEFAULT_INTENTS
import rasa.cli.utils
from rasa.shared.importers.importer import TrainingDataImporter
from rasa.shared.constants import (
    CONFIG_SCHEMA_FILE,
    DEFAULT_E2E_TESTS_PATH,
    DEFAULT_CONFIG_PATH,
    DEFAULT_DOMAIN_PATH,
    DEFAULT_MODELS_PATH,
    DEFAULT_DATA_PATH,
    DEFAULT_RESULTS_PATH,
)
from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
    YAMLStoryReader,
)
from rasa.shared.core.training_data.structures import RuleStep, StoryStep

# ===========================================#
# These are internal settings that probably  #
# shouldn't be changed                       #
# -------------------------------------------#
RASA_SCORECARD_URL = "https://www.enterprise-scorecard.rasa.com/v1/"
# -------------------------------------------#
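
# The default scorecard used to seed a new scores file. The `manual` section is
# filled in by the interactive wizard below; the `auto` section is computed from
# the local repo and, when available, the Rasa Enterprise API.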
DEFAULT_SCORECARD = """
version: 1.0
scorecard:
  manual:
    CI:
      ci_runs_data_validation: false
      ci_trains_model: false
      ci_runs_rasa_test: false
      test_dir: tests
      ci_builds_action_server: false
    CD:
      ci_deploys_action_server: false
      infrastructure_as_code: false
      has_test_environment: false
      ci_runs_vulnerability_scans: false
    training_data_health:
      connected_channels: []
    success_kpis:
      automated_conversation_tags: []
  auto:
    CI:
      code_and_training_data_in_git: false
      rasa_test_coverage: 0.0
    training_data_health:
      num_min_example_warnings: null
      num_confused_intent_warnings: null
      num_confidence_warnings: null
      num_wrong_annotation_warnings: null
      num_precision_recall_warnings: null
      num_real_conversations: 0
      num_annotated_messages: 0
    success_kpis:
      has_fallback: false
      num_reviewed_conversations: 0
      num_tagged_conversations: 0
    training_data:
      num_intents: 0
      num_entities: 0
      num_actions: 0
      num_slots: 0
      num_responses: 0
      num_forms: 0
      num_retrieval_intents: 0
      num_stories: 0
      num_rules: 0
      avg_story_len: 0
      avg_rule_len: 0
      avg_entity_examples: 0
      avg_intent_examples: 0
"""

DataType = Dict[Text, Any]


def _create_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Collect data for scorecard.")
    parser.add_argument(
        "--username",
        "--user",
        help=(
            "Username for Rasa Enterprise. Can be set with the environment variable"
            " `RASA_ENTERPRISE_USER`"
        ),
        default=os.environ.get("RASA_ENTERPRISE_USER"),
    )
    parser.add_argument(
        "--password",
        "--pass",
        help=(
            "Password for Rasa Enterprise. Can be set with the environment variable"
            " `RASA_ENTERPRISE_PASSWORD`"
        ),
        default=os.environ.get("RASA_ENTERPRISE_PASSWORD"),
    )
    parser.add_argument(
        "--url",
        help=(
            "Host URL for Rasa Enterprise. Can be set with the environment variable"
            " `RASA_ENTERPRISE_URL`"
        ),
        default=os.environ.get("RASA_ENTERPRISE_URL"),
    )
    parser.add_argument(
        "--file",
        required=False,
        help="File containing your scores. Will be overwritten!",
    )
    parser.add_argument(
        "--debug", required=False, action="store_true", help="Turn on verbose logging."
    )
    parser.add_argument(
        "--skip-wizard",
        required=False,
        action="store_true",
        help="Skip running the wizard. User will not be prompted.",
    )
    parser.add_argument(
        "--skip-api",
        required=False,
        action="store_true",
        help=(
            "Skip fetching data from the Rasa Enterprise API. Useful if you haven't"
            " deployed Rasa Enterprise yet."
        ),
    )
    parser.add_argument(
        "--on-the-fly-insights",
        required=False,
        nargs="?",
        const="results",
        help=(
            "Calculate NLU insights on-the-fly based on cross-validation results."
            " Will skip fetching NLU Insights from the Rasa Enterprise API."
            " Accepts an optional path to results, defaults to ./results."
        ),
    )
    parser.add_argument(
        "view",
        nargs="?",
        help=(
            "Load an existing scores file to generate a shareable link"
            " and display scorecard in the browser."
            " Skips fetching data or calculating scores."
        ),
    )
    return parser
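

# Example invocations (hypothetical script name and values):
#   python scorecard.py --url https://rasa.example.com --user admin --pass secret
#   python scorecard.py --skip-api --skip-wizard
#   python scorecard.py view --file scores.yml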


def read_existing_scores_from_file(filename: str) -> DataType:
    try:
        data = read_yaml_file(filename)
    except:  # noqa: E722
        print_error_and_exit(f"Could not read file {filename} as yaml.")
    return data


def _create_backup_of_file(filename: str) -> str:
    date = datetime.date.today().strftime("%Y%m%d")
    backup_name = f"backup_{date}_{filename}"
    shutil.copy(filename, backup_name)
    return backup_name


def _is_cwd_git_dir() -> bool:
    logging.debug("Checking there is a git repo in this directory.")
    try:
        return call(["git", "branch"], stderr=STDOUT, stdout=open(os.devnull, "w")) == 0
    except:  # noqa: E722
        git_path = Path(__file__).parent.resolve() / "./.git"
        logging.debug(f"git command not available. Checking existence of: {git_path}")
        return git_path.exists()


def _has_fallback() -> bool:
    logging.debug("Checking if there is a fallback classifier in NLU pipeline.")
    try:
        config = str(
            rasa.cli.utils.get_validated_path(
                "config.yml", "config", DEFAULT_CONFIG_PATH
            )
        )
        config_importer = TrainingDataImporter.load_from_dict(config_path=config)
        config_dict = config_importer.get_config()
        # config = load_nlu_config("./config.yml")
        logging.debug(f"NLU config has components: {config_dict['pipeline']}.")
        # return "FallbackClassifier" in config_dict.component_names
        # The built-in component is named "FallbackClassifier", so match on the
        # "Fallback" substring rather than requiring the exact name "Fallback".
        contains_fallback = any(
            "Fallback" in (d.get("name") or "") for d in config_dict["pipeline"]
        )
        return contains_fallback
    except:  # noqa: E722
        logging.error("could not load config from file: config.yml")
        return False


def update_scores(
    data: DataType,
    token: Optional[str],
    url: Optional[str],
    crossval_results: Optional[str],
) -> DataType:
    data = update_data_from_repo(data)
    if token and url:
        data = update_data_from_rasa_enterprise(data, token, url, crossval_results)
    elif crossval_results:
        warnings = count_warnings(url, token, crossval_results)
        data["scorecard"]["auto"]["training_data_health"].update(warnings)
    return data


def get_auth_token(
    user: Optional[str], password: Optional[str], url: Optional[str]
) -> Optional[str]:
    if not user or not password or not url:
        print_error_and_exit(
            "Rasa Enterprise login credentials not set. Run the script with `--help`"
            " for more information."
        )
    url = urljoin(url, "/api/auth")
    payload = {"username": user, "password": password}
    response = requests.post(url, json=payload)
    try:
        token = response.json()["access_token"]  # type: str
        return token
    except:  # noqa: E722
        logging.error(
            "Unable to authenticate against Rasa Enterprise API. Error"
            f" {response.status_code}"
        )
        logging.error(response.text)
        return None


def count_tagged_conversations(
    url: str, token: str, tags: Optional[List[str]] = None
) -> int:
    logging.debug("Counting conversations with tags.")
    url = urljoin(url, "/api/data_tags")
    headers = {"Authorization": f"Bearer {token}"}
    response = requests.get(url, headers=headers)
    total = 0
    try:
        data_tags = response.json()
    except:  # noqa: E722
        logging.error(
            "Error calling Rasa Enterprise API to count tagged conversations."
        )
        logging.error(response)
        return total
    for tag in data_tags:
        if tag["value"] in (tags or []):
            num_conversations = len(tag["conversations"])
            logging.debug(
                f"Found {num_conversations} conversations tagged with {tag['value']}."
            )
            total += num_conversations
    return total


def count_conversations(
    url: str, token: str, unread: bool = False, rasa_channel: bool = False
) -> int:
    params = {"limit": 1, "exclude_self": True}
    if unread:
        params["review_status"] = "unread"
    if rasa_channel:
        params["input_channels"] = "rasa"
    logging.debug(f"Counting conversations with params: {params}.")
    url = urljoin(url, "/api/conversations")
    headers = {"Authorization": f"Bearer {token}"}
    response = None
    try:
        response = requests.get(url, headers=headers, params=params)
        num_conversations = int(response.headers.get("X-Total-Count", 0))
        return num_conversations
    except:  # noqa: E722
        logging.error("Error calling Rasa Enterprise API to count conversations.")
        logging.error(response)
        return 0


def count_messages(url: str, token: str, exclude_training_data: bool = False) -> int:
    params = {"limit": 1, "exclude_training_data": exclude_training_data}
    url = urljoin(url, "/api/projects/default/logs")
    headers = {"Authorization": f"Bearer {token}"}
    response = None
    try:
        response = requests.get(url, headers=headers, params=params)
        num_messages = int(response.headers.get("X-Total-Count", 0))
        return num_messages
    except:  # noqa: E722
        logging.error("Error calling Rasa Enterprise API to count messages.")
        logging.error(response)
        return 0
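

# Maps Rasa X/Enterprise insight calculator class names to the scorecard's
# warning counters in `training_data_health`.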
insight_key_map = {
    "MinimumExampleInsightCalculator": "num_min_example_warnings",
    "ConfusionInsightCalculator": "num_confused_intent_warnings",
    "NLUInboxConfidenceInsightCalculator": "num_confidence_warnings",
    "WrongAnnotationInsightCalculator": "num_wrong_annotation_warnings",
    "ClassBiasInsightCalculator": "num_precision_recall_warnings",
}


def calculate_insights_for_ci(
    linter_config: Optional[DataType] = None, test_results: str = str(Path("results"))
) -> DataType:
    """Calculates NLU insights based on local cross-validation results."""
    import asyncio
    import contextlib
    from rasax.community.services.insights.insight_calculator import (
        _get_insights_for_cli,
    )

    context_manager = contextlib.nullcontext(None)
    with context_manager as session:
        print_info(
            "Calculating intent insights based on local cross-validation results ... 🔎"
        )
        loop = asyncio.get_event_loop()
        warnings = loop.run_until_complete(
            _get_insights_for_cli(linter_config, test_results, session)
        )
        if warnings:
            logging.debug("Suggestions to improve training data were found 🚨")
        else:
            logging.debug("No suggestions to improve training data were found")
        return warnings
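

# Note: this relies on the private `_get_insights_for_cli` helper from the
# `rasax` package, so on-the-fly insights require a local Rasa X/Enterprise
# installation.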


def count_warnings(url: str, token: str, crossval_results: Optional[str]) -> DataType:
    result = {
        "num_min_example_warnings": None,
        "num_confused_intent_warnings": None,
        "num_confidence_warnings": None,
        "num_wrong_annotation_warnings": None,
        "num_precision_recall_warnings": None,
    }
    headers = {"Authorization": f"Bearer {token}"}
    url = urljoin(url, "/api/insights/nlu")
    version_url = urljoin(url, "/api/version")
    try:
        if crossval_results:
            calculated_warnings = calculate_insights_for_ci(
                test_results=crossval_results
            )
            warnings = [
                {
                    "source": insight.source,
                    "details": insight.details,
                    "description": insight.description,
                }
                for intent in calculated_warnings.values()
                for insight in intent
            ]
            if warnings:
                logging.debug(
                    "NLU insight reports present, replacing default 'null' warning"
                    " counts with 0"
                )
                result = dict.fromkeys(result, 0)
        else:
            version_response = requests.get(version_url, headers=headers)
            version = version_response.json().get("rasa-x")
            # Compare version components numerically; a plain string comparison
            # would order e.g. "0.4.0" after "0.36.0".
            if tuple(int(p) for p in version.split(".")[:2]) < (0, 36):
                logging.warning(
                    "Rasa X version is < 0.36 so cannot retrieve NLU Insights"
                )
                return dict.fromkeys(result, None)
            response = requests.get(url, headers=headers)
            if response.status_code == 404:
                logging.warning("Failed to retrieve NLU Insights from Rasa X API")
                return dict.fromkeys(result, None)
            insight_reports = response.json()
            latest_report = next(
                (
                    report
                    for report in reversed(insight_reports)
                    if report["status"] == "success"
                ),
                None,
            )
            if latest_report is None:
                logging.error(
                    "No successful NLU insight reports found on Rasa Enterprise API"
                )
                return dict.fromkeys(result, None)
            else:
                logging.debug(
                    "NLU insight reports present, replacing default 'null' warning"
                    " counts with 0"
                )
                result = dict.fromkeys(result, 0)
            url = urljoin(url, f"nlu/{latest_report['id']}")
            response = requests.get(url, headers=headers)
            report = response.json()
            warnings = [
                warnings
                for intent in report["intent_evaluation_results"]
                for warnings in intent["intent_insights"]
            ]
        # Count warnings per source, regardless of where they came from.
        warnings.sort(key=lambda w: w["source"])
        for warning_type, group in groupby(warnings, lambda w: w["source"]):
            key = insight_key_map.get(warning_type)
            if key:
                result[key] = len(list(group))
        return result
    except:  # noqa: E722
        logging.error("Error processing NLU insight report")
        return dict.fromkeys(result, None)


def load_domain(domain_path: Path = Path(".")):
    from rasa.shared.core.domain import Domain

    domain_roots = []
    if domain_path.is_file():
        domain_roots = [domain_path]
    else:
        # Ignore the data folder because responses are considered valid domain files.
        # Really we need to fix `rasa` to better detect these files so we can just
        # throw a directory to Domain.from_path() to walk.
        domain_roots = []
        for p in domain_path.iterdir():
            if not p.name.startswith(".") and p.is_file() and p.name != "data":
                try:
                    if Domain.is_domain_file(p):
                        domain_roots.append(p)
                except UnicodeDecodeError:
                    # Rasa sometimes throws a UnicodeDecodeError;
                    # just skip these yml files as candidate domain files.
                    # Possible causes:
                    # (-) issue 102: there is a `\` in the yml file
                    pass
    domain = Domain.empty()
    for d in domain_roots:
        domain = domain.merge(Domain.from_path(d))
    return domain


def calculate_training_data(data: DataType) -> DataType:
    logging.debug("Calculating training data totals.")
    result = {
        # domain
        "num_intents": 0,
        "num_entities": 0,
        "num_actions": 0,
        "num_slots": 0,
        "num_responses": 0,
        "num_forms": 0,
        "num_retrieval_intents": 0,
        "num_stories": 0,
        "num_rules": 0,
        "avg_story_len": 0,
        "avg_rule_len": 0,
        "avg_entity_examples": 0,
        "avg_intent_examples": 0,
    }
    config = str(
        rasa.cli.utils.get_validated_path("config.yml", "config", DEFAULT_CONFIG_PATH)
    )
    domain = rasa.cli.utils.get_validated_path(
        "domain.yml", "domain", DEFAULT_DOMAIN_PATH, none_is_valid=False
    )
    data_path = rasa.cli.utils.get_validated_path("data", "nlu", DEFAULT_DATA_PATH)
    # data_path = rasa.shared.data.get_nlu_directory(data_path)
    training_files = rasa.shared.data.get_data_files(
        data_path, YAMLStoryReader.is_stories_file
    )
    training_files += rasa.shared.data.get_data_files(
        data_path, rasa.shared.data.is_nlu_file
    )
    # config_importer = TrainingDataImporter.load_from_dict(config_path=config)
    # Args:
    #   config: Path to the config file.
    #   domain: Path to the domain file.
    #   training_files: List of paths to training data files.
    file_importer = TrainingDataImporter.load_from_config(
        domain_path=domain, training_data_paths=training_files, config_path=config
    )
    # Process domain
    domain = file_importer.get_domain()
    intents = {i: 0 for i in domain.intents if i not in DEFAULT_INTENTS}
    result["num_intents"] = len(intents.keys())
    entities = {i: 0 for i in domain.entities}
    result["num_entities"] = len(entities.keys())
    actions = {i: 0 for i in domain.user_actions}
    result["num_actions"] = len(actions.keys())
    # domain.slots includes one built-in default slot; don't count it
    result["num_slots"] = len(domain.slots) - 1
    result["num_responses"] = len(domain.responses.keys())
    result["num_forms"] = len(domain.forms.keys())
    result["num_retrieval_intents"] = len(domain.retrieval_intents)
    # Get stories & rules
    stories = file_importer.get_stories()
    num_rule_steps = 0
    num_story_steps = 0
    for story in stories.story_steps:
        if isinstance(story, RuleStep):
            result["num_rules"] += 1
            events = story.get_rules_events()
            for evt in events:
                if (
                    hasattr(evt, "intent")
                    or (hasattr(evt, "action_name") and evt.action_name != "...")
                    or (hasattr(evt, "type_name") and evt.type_name == "active_loop")
                    or evt.type_name == "slot"
                ):
                    num_rule_steps += 1
        elif isinstance(story, StoryStep):
            result["num_stories"] += 1
            for evt in story.events:
                if hasattr(evt, "intent") or (
                    hasattr(evt, "action_name") and evt.action_name != "..."
                ):
                    num_story_steps += 1
    if result["num_rules"]:
        result["avg_rule_len"] = round(num_rule_steps / result["num_rules"], 1)
    if result["num_stories"]:
        result["avg_story_len"] = round(num_story_steps / result["num_stories"], 1)
    # NLU training data
    nlu = file_importer.get_nlu_data()
    if nlu.number_of_examples_per_intent:
        result["avg_intent_examples"] = round(
            sum(nlu.number_of_examples_per_intent.values())
            / len(nlu.number_of_examples_per_intent.values()),
            1,
        )
    if nlu.number_of_examples_per_entity:
        result["avg_entity_examples"] = round(
            sum(nlu.number_of_examples_per_entity.values())
            / len(nlu.number_of_examples_per_entity.values()),
            1,
        )
    return result


def calculate_test_coverage(data: DataType) -> float:
    logging.debug(
        "Calculating test coverage by inspecting domain and test conversations."
    )
    result = {
        "num_intents": 0,
        "num_entities": 0,
        "num_actions": 0,
        "num_slots": 0,
    }
    domain = load_domain()
    intents = {i: 0 for i in domain.intents if i not in DEFAULT_INTENTS}
    result["num_intents"] = len(intents.keys())
    logging.debug(f"Found {len(intents)} intents in domain.")
    entities = {i: 0 for i in domain.entities}
    result["num_entities"] = len(entities.keys())
    logging.debug(f"Found {len(entities)} entities in domain.")
    actions = {i: 0 for i in domain.user_actions}
    result["num_actions"] = len(actions.keys())
    logging.debug(f"Found {len(actions)} actions in domain.")
    result["num_slots"] = len(domain.slots) - 1
    logging.debug(f"Found {result['num_slots']} slots in domain.")
    data["scorecard"]["auto"]["training_data"].update(result)
    try:
        user_slots = domain._user_slots
    except AttributeError:
        user_slots = domain.slots
    slots = {i.name: 0 for i in user_slots}
    logging.debug(f"Found {len(slots)} slots in domain.")
    tests_path = data["scorecard"]["manual"]["CI"]["test_dir"]
    tmp_dir = get_data_files(tests_path, YAMLStoryReader.is_test_stories_file)
    if tmp_dir:
        for test_file in tmp_dir:
            # for test_file in os.listdir(tests_path):
            with open(test_file, "r") as tempfile:
                # with open(os.path.join(tmp_dir, test_file), "r") as tempfile:
                for line in tempfile:
                    for intent in intents.keys():
                        if intent in line:
                            intents[intent] += 1
                    for entity in entities.keys():
                        if entity in line:
                            entities[entity] += 1
                    for action in actions.keys():
                        if action in line:
                            actions[action] += 1
                    for slot in slots.keys():
                        if slot in line:
                            slots[slot] += 1
    logging.debug(f"Appearances per intent in test conversations: {intents}.")
    # An item counts as covered when it appears in at least 2 test conversations.
    covered_intents = [i for i, tot in intents.items() if tot > 1]
    covered_entities = [i for i, tot in entities.items() if tot > 1]
    covered_actions = [i for i, tot in actions.items() if tot > 1]
    covered_slots = [i for i, tot in slots.items() if tot > 1]
    logging.debug(
        f"{len(covered_intents)}/{len(intents)} intents appear in at least 2 test"
        " conversations."
    )
    logging.debug(
        f"{len(covered_entities)}/{len(entities)} entities appear in at least 2 test"
        " conversations."
    )
    logging.debug(
        f"{len(covered_actions)}/{len(actions)} actions appear in at least 2 test"
        " conversations."
    )
    logging.debug(
        f"{len(covered_slots)}/{len(slots)} slots appear in at least 2 test"
        " conversations."
    )
    # Coverage = covered items / all items, as a percentage.
    numerator = (
        len(covered_intents)
        + len(covered_entities)
        + len(covered_actions)
        + len(covered_slots)
    )
    denominator = len(intents) + len(entities) + len(actions) + len(slots)
    if denominator == 0:
        return 0
    return round(100 * numerator / denominator, 1)
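

# Worked example (hypothetical numbers): with 8 of 10 intents, 2 of 4 entities,
# 5 of 5 actions, and 1 of 1 slots covered, coverage = 100 * 16 / 20 = 80.0.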


def update_data_from_repo(data: DataType) -> DataType:
    """Inspect repo and determine if necessary steps are in place."""
    logging.debug("Inspecting contents of current directory.")
    data["scorecard"]["auto"]["CI"]["code_and_training_data_in_git"] = _is_cwd_git_dir()
    data["scorecard"]["auto"]["success_kpis"]["has_fallback"] = _has_fallback()
    data["scorecard"]["auto"]["CI"]["rasa_test_coverage"] = calculate_test_coverage(
        data
    )
    data["scorecard"]["auto"]["training_data"] = calculate_training_data(data)
    return data


def update_data_from_rasa_enterprise(
    data: DataType, token: str, url: str, crossval_results: Optional[str]
) -> DataType:
    logging.debug("Fetching data from Rasa Enterprise API.")
    if url is None or token is None:
        return data
    logging.debug("Successfully retrieved an auth token.")
    total_conversations = count_conversations(url, token)
    real_conversations = total_conversations - count_conversations(
        url, token, rasa_channel=True
    )
    reviewed_conversations = total_conversations - count_conversations(
        url, token, unread=True
    )
    automated_conversation_tags = data["scorecard"]["manual"]["success_kpis"][
        "automated_conversation_tags"
    ]
    tagged_conversations = 0
    if automated_conversation_tags:
        logging.debug(
            f"Counting conversations tagged with one of {automated_conversation_tags}."
        )
        tagged_conversations = count_tagged_conversations(
            url, token, tags=automated_conversation_tags
        )
    data["scorecard"]["auto"]["training_data_health"][
        "num_real_conversations"
    ] = real_conversations
    data["scorecard"]["auto"]["success_kpis"][
        "num_reviewed_conversations"
    ] = reviewed_conversations
    data["scorecard"]["auto"]["success_kpis"][
        "num_tagged_conversations"
    ] = tagged_conversations
    raw_num_messages = count_messages(url, token, exclude_training_data=False)
    messages_not_in_training_data = count_messages(
        url, token, exclude_training_data=True
    )
    num_annotated_messages = raw_num_messages - messages_not_in_training_data
    data["scorecard"]["auto"]["training_data_health"][
        "num_annotated_messages"
    ] = num_annotated_messages
    warnings = count_warnings(url, token, crossval_results)
    data["scorecard"]["auto"]["training_data_health"].update(warnings)
    return data


def fill_boolean_question(
    data: DataType, section: str, key: str, question: str
) -> DataType:
    answer = questionary.confirm(question).ask()
    data["scorecard"]["manual"][section][key] = answer
    return data


def run_auth_wizard(
    user: Optional[str], password: Optional[str], url: Optional[str]
) -> Tuple[Union[Any, str], ...]:
    user = (
        questionary.text("What is your Rasa Enterprise username?").ask()
        if user is None
        else user
    )
    password = (
        questionary.password("What is your Rasa Enterprise password?").ask()
        if password is None
        else password
    )
    url = (
        questionary.text("What is the URL of your Rasa Enterprise instance?").ask()
        if url is None
        else url
    )
    return user, password, url


def run_wizard() -> DataType:
    should_create_new_file = questionary.confirm(
        "Do you want to start a new scorecard from scratch?"
    ).ask()
    if not should_create_new_file:
        print_success(
            "OK. Stopping the script. Please provide an existing score yaml file"
            " through the `--file` argument"
        )
        exit(0)
    print_success(
        "Great! I'm going to ask a few questions to fill in your initial scorecard. The"
        " scores evaluate the health of your training data, how well you are measuring"
        " the success of your assistant, and the maturity of your CI/CD pipeline. This"
        " is partly automated, but we'll have to ask some questions about your CI/CD"
        " setup. You can always go back and update your answers by editing your"
        " scorecard file."
    )
    data = read_yaml(DEFAULT_SCORECARD)
    channels = ["website", "messenger", "twilio", "whatsapp", "custom"]
    messaging_channels = questionary.checkbox(
        "Please use the SPACE KEY to select the channels where your assistant is"
        " available.",
        choices=channels,
    ).ask()
    data["scorecard"]["manual"]["training_data_health"][
        "connected_channels"
    ] = messaging_channels
    questions = {
        "ci_runs_data_validation": (
            "Does your CI server validate your data using `rasa data validate`?"
        ),
        "ci_trains_model": (
            "Does your CI server train a model for testing? E.g. with `rasa train`?"
        ),
        "ci_runs_rasa_test": (
            "Does your CI server run through test conversations? E.g. with `rasa test`?"
        ),
        "ci_builds_action_server": (
            "Does your CI server build an image for your custom action server?"
        ),
    }
    for key, question in questions.items():
        data = fill_boolean_question(data, "CI", key, question)
    test_dir = questionary.text(
        "Please pick the directory where your test stories are saved.", default="tests"
    ).ask()
    if os.path.isdir(test_dir):
        data["scorecard"]["manual"]["CI"]["test_dir"] = test_dir
    else:
        logging.error(f"{test_dir} does not seem to be a valid directory. Skipping.")
    questions = {
        "ci_deploys_action_server": (
            "Does your CI server automatically deploy your custom action server?"
        ),
        "infrastructure_as_code": (
            "Do you have all of your infrastructure configured as code? E.g. with"
            " terraform."
        ),
        "has_test_environment": (
            "Do you have a test or staging environment to try changes before they go to"
            " production?"
        ),
        "ci_runs_vulnerability_scans": (
            "Does your CI server run vulnerability scans on all images?"
        ),
    }
    for key, question in questions.items():
        data = fill_boolean_question(data, "CD", key, question)
    auto_tags = ""
    try:
        auto_tags = questionary.text(
            "If you are using the Rasa Enterprise API to tag conversations"
            " programmatically, please enter the names of the relevant tags here (comma"
            " separated)"
        ).ask()
    except:  # noqa: E722
        print_error_and_exit(
            "Your questionary version is not up to date;"
            " please update it by running `pip install -U questionary`"
        )
    if auto_tags:
        auto_tags = [s.strip() for s in auto_tags.split(",")]
        data["scorecard"]["manual"]["success_kpis"][
            "automated_conversation_tags"
        ] = auto_tags
    return data


def prompt_web_ui(url: str) -> None:
    should_open_url = questionary.confirm(
        "Would you like to open up your results in the web-based interface?"
    ).ask()
    if should_open_url:
        import webbrowser

        webbrowser.open(url)


def get_shareable_link(yaml: str) -> str:
    encoded = b64encode(yaml.encode("utf-8"))
    string = encoded.decode("utf-8").replace("+", "-").replace("/", "_").rstrip("=")
    url = urljoin(RASA_SCORECARD_URL, f"/?scorecard={string}")
    return url
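

# Note: the replace/rstrip above yields the URL-safe base64 alphabet (the same
# output as base64.urlsafe_b64encode() with the "=" padding stripped); the
# scorecard web app presumably reverses this to recover the YAML from the
# query string.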


def generate_yaml(data) -> str:
    from io import StringIO

    schema_url = urljoin(RASA_SCORECARD_URL, "schema.json")
    buffer = StringIO("")
    write_yaml(data, buffer, should_preserve_key_order=True)
    yaml = f"# yaml-language-server: $schema={schema_url}\n{buffer.getvalue()}"
    return yaml
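

# The `# yaml-language-server: $schema=...` line is a yaml-language-server
# modeline: editors running that language server will validate the generated
# scores file against the published JSON schema.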


def load_scorecard(filename: Optional[str] = None) -> str:
    if not filename:
        data = read_yaml(DEFAULT_SCORECARD)
        logging.info("No scores file provided, using default values")
    else:
        data = read_existing_scores_from_file(filename)
    yaml = generate_yaml(data)
    return yaml


def update_scorecard(
    skip_api: bool,
    skip_wizard: bool,
    filename: str,
    user: str,
    password: str,
    url: str,
    crossval_results: str,
):
    token = None
    if not skip_api:
        user, password, url = run_auth_wizard(user=user, password=password, url=url)
        token = get_auth_token(user, password, url)
    if filename:
        backup_filename = _create_backup_of_file(filename)
        logging.info(f"backed up {filename} to {backup_filename}")
        data = read_existing_scores_from_file(filename)
    elif not skip_wizard:
        data = run_wizard()
        filename = "scores.yml"
        logging.info(
            f"Wizard complete. Creating a new scorecard file called {filename}"
        )
    else:
        data = read_yaml(DEFAULT_SCORECARD)
        filename = "scores.yml"
        logging.info("Skipping wizard and taking default values.")
    logging.info(
        "updating scores based on current directory and data from the Rasa"
        " Enterprise API"
    )
    data = update_scores(data, token, url, crossval_results)
    yaml = generate_yaml(data)
    with open(filename, "w") as fd:
        fd.write(yaml)
    print_success(f"Your scores have been updated and saved to {filename}.")
    return yaml


def main() -> None:
    arg_parser = _create_argument_parser()
    cmdline_arguments = arg_parser.parse_args()
    view = cmdline_arguments.view
    filename = cmdline_arguments.file
    user = cmdline_arguments.username
    password = cmdline_arguments.password
    url = cmdline_arguments.url
    skip_api = cmdline_arguments.skip_api
    skip_wizard = cmdline_arguments.skip_wizard
    crossval_results = cmdline_arguments.on_the_fly_insights
    log_level = logging.DEBUG if cmdline_arguments.debug else logging.INFO
    logging.basicConfig(level=log_level)
    if view:
        yaml = load_scorecard(filename)
    else:
        yaml = update_scorecard(
            skip_api, skip_wizard, filename, user, password, url, crossval_results
        )
    url = get_shareable_link(yaml)
    print_info(
        f"You can view your scorecard at:\n{url}"
        "\nCopy the link above to share a view of your scorecard with others."
    )
    if view or not skip_wizard:
        prompt_web_ui(url)


if __name__ == "__main__":
    main()