Skip to content

Instantly share code, notes, and snippets.

@ehzawad
Created April 24, 2025 12:04
Show Gist options
  • Save ehzawad/ccb9ec669a2cb7d9f12c3978b88d4882 to your computer and use it in GitHub Desktop.
Save ehzawad/ccb9ec669a2cb7d9f12c3978b88d4882 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import asyncio
import errno
import json
import logging
import os
import sys
import time
from typing import Dict, Any, Optional, List, Text

import aiohttp
# Root logging configuration for the CLI: timestamped records at INFO and above.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
class RasaCLI:
    """Interactive command-line client for a local Rasa server.

    Sends each user message through the Rasa HTTP API on ``localhost``
    (NLU parse, REST webhook, tracker fetch, next-action prediction) and
    prints the bot reply followed by a summary of the tracker state.
    """

    def __init__(
        self,
        rasa_port: Optional[int] = None,
        action_port: Optional[int] = None
    ):
        """Initialize Rasa client for interactive CLI.

        Args:
            rasa_port: Rasa server port; any falsy value falls back to 5005.
            action_port: Action server port; any falsy value falls back to 5054.
        """
        # `or` (not `is None`) means 0 also falls back to the default port.
        self.rasa_port = rasa_port or 5005
        self.action_port = action_port or 5054
        self.rasa_base_url = f"http://localhost:{self.rasa_port}"
        self.action_base_url = f"http://localhost:{self.action_port}"
        self.via_number = "8809611888444"  # Fixed via number embedded in every sender_id
        self.sender_id = None  # Set by initialize_session()
        # Phone numbers that are replaced by a fixed masking number in sender_id
        self.test_numbers = [
            '09696387582', '09638372914', '01924560627', '01518472623',
            '01580582654', '01833626976', '01571321136', '01764655648',
            '09638317055', '09638080760', '09696173224', '09611888444',
            '01911310316', '19723182900', '01558666739', '01714007806',
            '01714020387'
        ]
        self.session = None  # aiohttp.ClientSession, created by connect()

    async def connect(self):
        """Create a fresh aiohttp session, closing any existing one first.

        The connector uses ``force_close=True`` so keep-alive connections are
        never reused, presumably to avoid writing to a socket the server has
        already closed (broken pipe).
        """
        if self.session is not None:
            await self.close()
        timeout = aiohttp.ClientTimeout(sock_connect=10, sock_read=120, total=120)
        connector = aiohttp.TCPConnector(
            limit=20,                    # Sane concurrency
            force_close=True,            # Never reuse a keep-alive connection
            enable_cleanup_closed=True,  # Reap closed transports quickly
        )
        self.session = aiohttp.ClientSession(
            timeout=timeout,
            connector=connector,
        )

    async def close(self):
        """Close the aiohttp session, if any, and forget it."""
        if self.session:
            await self.session.close()
            self.session = None

    async def _request_with_retry(self, method: str, url: str, **kwargs):
        """Send an HTTP request and return the decoded JSON body.

        If the first attempt fails with a broken pipe (EPIPE), the session is
        recreated and the request is retried exactly once. HTTP error statuses
        (>= 400) are only logged; their JSON body is still returned.

        Raises:
            aiohttp.ClientOSError: on a second broken pipe, or any other OS
                error on the first attempt.
            Other aiohttp exceptions from the request or JSON decoding
            propagate unchanged.
        """
        if self.session is None:
            await self.connect()
        for attempt in (1, 2):  # At most one retry
            try:
                async with self.session.request(method, url, **kwargs) as resp:
                    if resp.status >= 400:
                        logger.warning(f"HTTP {resp.status} from {method} {url}")
                    return await resp.json()
            except aiohttp.ClientOSError as e:
                if e.errno == errno.EPIPE and attempt == 1:
                    logger.warning(f"Broken pipe → refreshing session & retrying {method} {url}...")
                    await self.connect()
                    continue
                raise

    @staticmethod
    def _empty_response() -> Dict[str, Any]:
        """Return the placeholder result used when talking to Rasa fails."""
        return {
            "response": [],
            "intent": {"name": "unknown", "confidence": 0.0},
            "entities": [],
            "tracker_data": {},
            "next_action": {"name": "None", "confidence": 0.0}
        }

    @staticmethod
    def _top_prediction(prediction: Any) -> tuple:
        """Return ``(action, score)`` from the first entry of ``prediction["scores"]``.

        Falls back to ``("None", 0.0)`` when the payload is not a dict, has no
        ``scores``, or ``scores`` is null or an empty list, instead of raising.
        """
        scores = prediction.get("scores") if isinstance(prediction, dict) else None
        if not scores:
            return "None", 0.0
        top = scores[0] or {}
        return top.get("action", "None"), top.get("score", 0.0)

    async def send_message(self, message_text: str) -> Dict[Any, Any]:
        """Send one user message through Rasa and collect debug information.

        In order, this runs an NLU parse, delivers the message via the REST
        webhook as ``self.sender_id``, fetches the tracker, and requests the
        next-action prediction.

        Returns:
            Dict with keys ``response`` (webhook messages), ``intent``
            (``name``/``confidence``), ``entities``, ``tracker_data`` and
            ``next_action`` (``name``/``confidence``). On any error the error
            is logged and a placeholder dict with the same keys is returned.
        """
        try:
            if self.session is None:
                await self.connect()

            # 1) NLU parse, used only for intent/entity display.
            parse_result = await self._request_with_retry(
                "POST",
                f"{self.rasa_base_url}/model/parse",
                json={"text": message_text}
            )
            # `or` fallbacks also cover keys that are present but null.
            intent_info = parse_result.get('intent') or {}
            intent_name = intent_info.get('name') or 'unknown'
            intent_confidence = intent_info.get('confidence') or 0.0
            entities = parse_result.get('entities') or []

            # 2) Deliver the message to the conversation via the REST channel.
            bot_response = await self._request_with_retry(
                "POST",
                f"{self.rasa_base_url}/webhooks/rest/webhook",
                json={"sender": self.sender_id, "message": message_text}
            )

            # 3) Tracker state, fetched after the webhook call has returned.
            tracker_data = await self._request_with_retry(
                "GET",
                f"{self.rasa_base_url}/conversations/{self.sender_id}/tracker"
            )

            # 4) Next-action prediction for the current tracker.
            prediction = await self._request_with_retry(
                "POST",
                f"{self.rasa_base_url}/conversations/{self.sender_id}/predict"
            )
            next_action, confidence = self._top_prediction(prediction)

            return {
                "response": bot_response,
                "intent": {"name": intent_name, "confidence": intent_confidence},
                "entities": entities,
                "tracker_data": tracker_data,
                "next_action": {
                    "name": next_action,
                    "confidence": confidence
                }
            }
        except aiohttp.ClientError as e:
            logger.error(f"Error communicating with Rasa server: {str(e)}")
            return self._empty_response()
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            return self._empty_response()

    def get_bot_response_text(self, response: Dict[Any, Any]) -> str:
        """Return the first text message of a ``send_message`` result.

        The ``[service_response]`` marker is removed and surrounding whitespace
        stripped. Messages without a string ``text`` (e.g. images/buttons only)
        are skipped.
        """
        messages = response.get("response")
        if not messages:
            return "No response from bot"
        for message in messages:
            text = message.get("text")
            if isinstance(text, str):
                return text.replace('[service_response]', '').strip()
        return "No text response from bot"

    def print_tracker_state(self, response_data: Dict[str, Any]):
        """Print detailed tracker state information.

        Shows the active loop (form), the parsed intent, non-empty slots, the
        latest and next predicted actions, and a grouped view of the last 10
        tracker events.

        Args:
            response_data: A dict shaped like ``send_message``'s return value;
                ``intent`` with numeric ``confidence`` is required.
        """
        print("\n" + "="*40)
        print("TRACKER STATE".center(40))
        # `or {}` / `or []` also cover keys that are present but null.
        tracker_data = response_data.get("tracker_data") or {}

        active_form = (tracker_data.get("active_loop") or {}).get("name")
        print(f"Active Form: {active_form or 'None'}")

        intent_name = response_data["intent"]["name"]
        intent_confidence = response_data["intent"]["confidence"]
        print(f"Intent Name: {intent_name}")
        print(f"Intent Confidence: {intent_confidence:.4f}")

        print("\nCurrent Slots:")
        slots = tracker_data.get("slots") or {}
        filled_slots = False
        for slot, value in slots.items():
            if value:  # Falsy values (None, '', 0, False, []) count as empty
                print(f" - {slot}: {repr(value)}")
                filled_slots = True
        if not filled_slots:
            print(" (No filled slots)")

        latest_action = tracker_data.get("latest_action_name", "None")
        next_action = response_data.get("next_action", {})
        print(f"\nLatest Action: {latest_action}")
        print(f"Next Predicted Action: {next_action.get('name')} (confidence: {next_action.get('confidence', 0.0):.4f})")

        print("\nRecent Events (Chronological):")
        events = (tracker_data.get("events") or [])[-10:]  # Last 10 events
        # The current action and bot utterances are buffered and flushed when
        # the next user turn or action_listen arrives, so each exchange prints
        # as a group.
        current_action = None
        bot_responses = []
        for event in events:
            event_type = event.get("event")
            if event_type == "user":
                if current_action:
                    print(f" Action Executed: {current_action}")
                    current_action = None
                for resp in bot_responses:
                    print(f" Bot: {repr(resp)}")
                bot_responses = []
                intent = ((event.get("parse_data") or {}).get("intent") or {}).get("name", "None")
                print(f" User: {event.get('text')} → Intent: {intent}")
            elif event_type == "bot":
                bot_text = event.get('text')
                if bot_text:  # Only buffer non-empty messages
                    bot_responses.append(bot_text)
            elif event_type == "action":
                if current_action and current_action != event.get('name'):
                    print(f" Action Executed: {current_action}")
                current_action = event.get('name')
                if current_action == "action_listen" and bot_responses:
                    for resp in bot_responses:
                        print(f" Bot: {repr(resp)}")
                    bot_responses = []
            elif event_type == "slot":
                print(f" Slot Set: {event.get('name')} = {repr(event.get('value'))}")
            elif event_type == "active_loop":
                status = "Started" if event.get("name") else "Stopped"
                print(f" Active Loop: {event.get('name', 'None')} ({status})")

        # Flush whatever is still buffered after the last event.
        if current_action:
            print(f" Action Executed: {current_action}")
        for resp in bot_responses:
            print(f" Bot: {repr(resp)}")
        print("="*40 + "\n")

    @staticmethod
    async def _probe(session, url: str, name: str) -> Optional[bool]:
        """GET ``url``; return True for HTTP 200, False otherwise, None if the request raised.

        Failures are logged with ``name`` as the server label.
        """
        try:
            async with session.get(url) as response:
                ok = response.status == 200
                if not ok:
                    logger.error(f"{name} server health check failed: {response.status}")
                return ok
        except Exception as e:
            logger.error(f"{name} server check failed: {e}")
            return None

    async def check_server_health(self):
        """Return True only if both the Rasa and the action server answer HTTP 200.

        Uses a short-lived session with a 10 s total timeout, separate from the
        conversation session. If the Rasa probe raises, the action server is
        not probed.
        """
        timeout = aiohttp.ClientTimeout(total=10)
        connector = aiohttp.TCPConnector(force_close=True)
        async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
            rasa_ok = await self._probe(session, f"{self.rasa_base_url}/", "Rasa")
            if rasa_ok is None:
                return False
            action_ok = await self._probe(session, f"{self.action_base_url}/health", "Action")
            if action_ok is None:
                return False
            return rasa_ok and action_ok

    async def initialize_session(self):
        """Prompt for session details and derive ``self.sender_id`` from them.

        Expects ``(session_id, phone_number, alternate_number)``; the
        parentheses are optional. Only the phone number is used. If its last
        11 characters are in ``self.test_numbers``, it is replaced by a fixed
        masking number. The resulting sender_id is
        ``<epoch-ms>_<via_number>_<phone or mask>``.

        Returns:
            True on success, after the first "User: " prompt has been printed.
            False if the input is malformed, the phone number is empty, or
            stdin is closed.
        """
        print("\n" + "="*50)
        print("Welcome to Rasa Interactive CLI!".center(50))
        print("="*50)
        print("\nPlease enter your session information in this format:")
        print("(session_id, phone_number, alternate_number)")
        try:
            session_input = input("\n> ")
        except EOFError:
            # stdin is closed (Ctrl-D or piped input ran out).
            logger.error("No session information provided (end of input)")
            print("\nError: No session information provided.")
            return False

        session_input = session_input.strip()
        if session_input.startswith("(") and session_input.endswith(")"):
            session_input = session_input[1:-1]
        parts = [part.strip() for part in session_input.split(",")]

        # session_id and alternate_number are required by the format but unused.
        if len(parts) < 3 or not parts[1]:
            logger.error("Failed to parse session information - incorrect format")
            print("\nError: Could not parse session information.")
            print("Please use format: (session_id, phone_number, alternate_number)")
            return False

        phone_number = parts[1]
        masked_number = phone_number
        if phone_number[-11:] in self.test_numbers:
            masked_number = '01568725958'

        # Millisecond timestamp makes each run a fresh conversation.
        current_time = int(time.time() * 1000)
        self.sender_id = f"{current_time}_{self.via_number}_{masked_number}"
        logger.info(f"Session initialized with sender_id: {self.sender_id}")
        print("\nSession initialized successfully.")
        print("Type your messages below. Type 'exit' to quit.")
        print("User: ", end="", flush=True)
        return True

    async def start_interactive_session(self):
        """Run the health check, session setup and the interactive message loop.

        The loop ends on 'exit'/'quit'/'bye' (case-insensitive), Ctrl-C, or
        end of input (Ctrl-D / closed stdin). The HTTP session is always
        closed on the way out.
        """
        try:
            if not await self.check_server_health():
                print(f"Error: Cannot connect to Rasa (port {self.rasa_port}) or Action (port {self.action_port}) server")
                print("Please ensure both servers are running")
                return

            if not await self.initialize_session():
                print("Failed to initialize session. Exiting.")
                return

            await self.connect()

            while True:
                try:
                    # The prompt has already been printed.
                    user_input = input()
                    if user_input.lower() in ["exit", "quit", "bye"]:
                        print("Exiting Rasa Interactive CLI. Goodbye!")
                        break

                    # send_message handles every Exception itself (including
                    # network errors) and returns a placeholder on failure.
                    response_data = await self.send_message(user_input)
                    bot_text = self.get_bot_response_text(response_data)
                    print(f"Bot: {bot_text}")
                    self.print_tracker_state(response_data)
                    print("User: ", end="", flush=True)
                except EOFError:
                    # stdin is closed: input() would raise again on every
                    # iteration, so stop instead of looping forever.
                    print("\nEnd of input. Exiting Rasa Interactive CLI. Goodbye!")
                    break
                except KeyboardInterrupt:
                    print("\nExiting due to keyboard interrupt.")
                    break
                except Exception as e:
                    logger.error(f"Error in message loop: {e}")
                    print(f"Error processing message: {e}")
                    print("User: ", end="", flush=True)
        except Exception as e:
            logger.error(f"Critical error in interactive session: {e}")
            print(f"Critical error: {e}")
        finally:
            await self.close()
async def main():
    """Parse command-line options and run a single interactive RasaCLI session.

    Exits with status 1 if the session raises an unhandled exception. The
    client's HTTP session is closed in every case.
    """
    parser = argparse.ArgumentParser(description="Interactive CLI for Rasa")
    # Port options share the same shape: (flag, default port, help text).
    port_options = (
        ('--rasa-port', 5005, 'Rasa server port (default: 5005)'),
        ('--action-port', 5054, 'Action server port (default: 5054)'),
    )
    for flag, default_port, help_text in port_options:
        parser.add_argument(flag, type=int, default=default_port, help=help_text)
    parser.add_argument('--debug', action='store_true', help='Enable debug output')
    options = parser.parse_args()

    if options.debug:
        # Lower the root logger so DEBUG records from every logger are emitted.
        logging.getLogger().setLevel(logging.DEBUG)

    client = RasaCLI(rasa_port=options.rasa_port, action_port=options.action_port)
    try:
        await client.start_interactive_session()
    except Exception as e:
        logger.error(f"Unhandled exception: {e}")
        sys.exit(1)
    finally:
        await client.close()
if __name__ == "__main__":
    # Script entry point: run the async main() on a fresh event loop.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nProgram terminated by user")
    except Exception as e:
        # Report on stderr and exit non-zero so shells and CI can detect the
        # failure; printing alone would leave the exit status at 0.
        print(f"Fatal error: {str(e)}", file=sys.stderr)
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment