Created
October 22, 2024 18:44
-
-
Save matthewhand/51d73222a6e88120f07dd8a202fa98a2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| title: Unified Pipe for Flowise and Ollama with Corrected Ollama Endpoint and Response Parsing | |
| version: 1.0 | |
| """ | |
| from typing import Optional, Callable, Awaitable, Dict, Any, List, Union, Generator | |
| import aiohttp | |
| import json | |
| import time | |
| import asyncio | |
| from pydantic import BaseModel, Field | |
class Pipe:
    """
    Unified pipeline for managing interactions with both Flowise and Ollama APIs.

    - Routes summarization requests (identified by a configurable prompt
      prefix) to Ollama.
    - Routes all other requests to Flowise.

    Requires:
        - Flowise service: https://github.com/FlowiseAI/Flowise
        - Ollama service: https://ollama.com/

    NOTE:
        - Ensure Flowise is deployed on a different port (e.g., 3030) to avoid conflicts.
    """

    class Valves(BaseModel):
        """User-tunable configuration for the pipe (pydantic-validated)."""

        # --- Configuration for Flowise ---
        FLOWISE_API_ENDPOINT: str = Field(
            default="http://host.docker.internal:3030/",
            description="Base URL for the Flowise API endpoint.",
        )
        FLOWISE_USERNAME: Optional[str] = Field(
            default=None, description="Username for Flowise API auth."
        )
        FLOWISE_PASSWORD: Optional[str] = Field(
            default=None, description="Password for Flowise API auth."
        )
        FLOWISE_CHATFLOW_ID: str = Field(
            default="", description="Chatflow ID for the Flowise API."
        )
        # --- Configuration for Ollama ---
        OLLAMA_API_ENDPOINT: str = Field(
            default="http://host.docker.internal:11435",
            description="Base URL for the Ollama API endpoint.",
        )
        OLLAMA_API_KEY: Optional[str] = Field(
            default=None, description="API key for Ollama API auth."
        )
        OLLAMA_MODEL_ID: str = Field(
            default="llama3.2", description="Model ID for the Ollama API."
        )
        # --- Summarization routing ---
        SUMMARIZATION_PROMPT_PREFIX: str = Field(
            default="Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language.",
            description="Prefix that identifies a summarization request.",
        )
        # --- Common settings ---
        emit_interval: float = Field(
            default=1.0, description="Interval between status emissions."
        )
        enable_status_indicator: bool = Field(
            default=True, description="Enable/disable status indicator."
        )
        request_timeout: int = Field(
            default=300, description="HTTP client timeout in seconds."
        )
        debug: bool = Field(
            default=True, description="Enable or disable debug logging."
        )

    def __init__(self):
        self.valves = self.Valves()
        # Retained for backward compatibility with external callers of
        # emit_periodic_status(); pipe() now uses a per-call event instead,
        # so concurrent pipe() invocations no longer stop each other's emitter.
        self.stop_emitter = asyncio.Event()
        # Per-user session state: {user_id: {"chat_id": ..., "history": [...]}}
        self.chat_sessions: Dict[str, Dict[str, Any]] = {}

    def log(self, message: str) -> None:
        """Print *message* when debug logging is enabled."""
        if self.valves.debug:
            print(f"[DEBUG] {message}")

    def clean_response_text(self, text: str) -> str:
        """
        Remove unnecessary surrounding quotes from *text* and strip whitespace.

        Handles responses that arrive wrapped in one or more pairs of
        double quotes (e.g. '"\"title\""').
        """
        self.log(f"Original text before cleaning: {text!r}")
        # Peel off matched outer quote pairs, one layer at a time.
        while text.startswith('"') and text.endswith('"'):
            text = text[1:-1]
            self.log(f"Text after stripping quotes: {text!r}")
        cleaned_text = text.strip()
        self.log(f"Final cleaned text: {cleaned_text!r}")
        return cleaned_text

    def _get_latest_user_message(self, messages: List[Dict[str, str]]) -> Optional[str]:
        """
        Return the content of the most recent user message, or None.

        Strips a leading "User: " prefix if present (some front-ends
        prepend the role to the content).
        """
        for message in reversed(messages):
            if message.get("role") == "user":
                content = message.get("content", "").strip()
                if content.startswith("User: "):
                    content = content[len("User: ") :].strip()
                    self.log(f"Stripped 'User: ' prefix. Content: {content}")
                else:
                    self.log(f"No 'User: ' prefix found. Content: {content}")
                if content:
                    self.log(f"Latest user question extracted: {content}")
                    return content
        self.log("No user message found in the messages.")
        return None

    def _get_combined_prompt(self, messages: List[Dict[str, str]]) -> str:
        """
        Combine user and assistant messages into a structured prompt.

        Example:
            User: hi
            Assistant: How can I assist you today?
            User: 5 words to describe ai
        """
        prompt_parts = [
            f"{message.get('role', 'user').capitalize()}: {message.get('content', '')}"
            for message in messages
        ]
        combined_prompt = "\n".join(prompt_parts)
        self.log(f"Combined prompt:\n{combined_prompt}")
        return combined_prompt

    async def emit_periodic_status(
        self,
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]],
        message: str,
        interval: float,
        stop_event: Optional[asyncio.Event] = None,
    ):
        """
        Periodically emit status updates until *stop_event* is set.

        *stop_event* defaults to the shared ``self.stop_emitter`` so existing
        callers keep working; ``pipe()`` passes a per-call event so that
        concurrent requests do not cancel each other's emitters.
        """
        if stop_event is None:
            stop_event = self.stop_emitter
        # monotonic() is immune to wall-clock adjustments — correct for durations.
        start_time = time.monotonic()
        try:
            while not stop_event.is_set():
                elapsed_time = time.monotonic() - start_time
                await self.emit_status(
                    __event_emitter__,
                    "info",
                    f"{message} (elapsed: {elapsed_time:.1f}s)",
                    False,
                )
                await asyncio.sleep(interval)
        except asyncio.CancelledError:
            self.log("Periodic status emission cancelled.")

    async def emit_status(
        self,
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]],
        level: str,
        message: str,
        done: bool,
    ):
        """Emit a single status event through *__event_emitter__* if callable."""
        if callable(__event_emitter__):
            event = {"type": "status", "data": {"description": message, "done": done}}
            self.log(f"Emitting status event: {event}")
            await __event_emitter__(event)
        else:
            self.log("No valid event emitter provided. Skipping event emission.")

    async def pipe(
        self,
        body: dict,
        __user__: Optional[dict] = None,
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]] = None,
    ) -> Union[str, Generator[str, None, None]]:
        """
        Main entry point: route the request to Ollama or Flowise.

        - Summarization requests (latest user message starts with the
          configured prefix) go to Ollama.
        - Everything else goes to Flowise.

        Returns the response text, or an "Error: ..." string on failure.
        """
        status_task = None
        # Per-call event: fixes a race where two concurrent pipe() calls
        # shared self.stop_emitter and stopped each other's status emitter.
        stop_event = asyncio.Event()
        start_time = time.monotonic()
        try:
            # Emit periodic status if enabled.
            if callable(__event_emitter__) and self.valves.enable_status_indicator:
                self.log("Starting periodic status emitter...")
                status_task = asyncio.create_task(
                    self.emit_periodic_status(
                        __event_emitter__,
                        "Processing request...",
                        self.valves.emit_interval,
                        stop_event,
                    )
                )
            else:
                self.log("No valid event emitter provided. Skipping periodic status.")

            # Extract messages from the request body.
            messages = body.get("messages", [])
            self.log(f"Messages extracted: {messages}")
            if not messages:
                self.log("No messages found in the request body.")
                return "Error: No messages found."

            # NOTE: the original also built a combined conversation prompt here
            # via _get_combined_prompt(), but it was never sent to either
            # backend — that dead work has been removed. The helper is kept
            # for external callers.

            # Extract the latest user message (this is what gets routed).
            latest_user_question = self._get_latest_user_message(messages)
            if not latest_user_question:
                self.log("No user message found in the messages.")
                return "Error: No user message found."

            # Route based on the summarization prefix.
            if self.is_summarization_request(latest_user_question):
                self.log("Summarization request detected. Routing to Ollama.")
                response = await self.handle_ollama_request(
                    latest_user_question, __user__, __event_emitter__
                )
            else:
                self.log("Regular request detected. Routing to Flowise.")
                response = await self.handle_flowise_request(
                    latest_user_question, __user__, __event_emitter__
                )

            # Emit final status.
            elapsed_time = time.monotonic() - start_time
            await self.emit_status(
                __event_emitter__,
                "info",
                f"Pipe Completed in {elapsed_time:.1f}s",
                True,
            )
            return response
        except Exception as e:
            # Top-level boundary: log and return an error string rather than raise.
            self.log(f"Error during pipe execution: {str(e)}")
            return f"Error: {e}"
        finally:
            if status_task:
                stop_event.set()
                await status_task

    def is_summarization_request(self, question: str) -> bool:
        """
        Return True when *question* starts with the configured
        SUMMARIZATION_PROMPT_PREFIX (i.e. should be routed to Ollama).
        """
        is_match = question.startswith(self.valves.SUMMARIZATION_PROMPT_PREFIX)
        self.log(f"Is summarization request (starts with prefix): {is_match}")
        return is_match

    async def handle_ollama_request(
        self,
        question: str,
        __user__: Optional[dict],
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]],
    ) -> Union[str, Generator[str, None, None]]:
        """
        Send a summarization request to Ollama's OpenAI-compatible
        /v1/completions endpoint and return the cleaned response text.

        Returns an "Error: ..." string on any failure.
        """
        try:
            # Payload for the OpenAI-compatible completions API.
            payload = {
                "prompt": question,
                "model": self.valves.OLLAMA_MODEL_ID,  # Ensure this matches Ollama's API requirements
            }
            self.log(f"Payload for Ollama: {payload}")

            # Construct the full Ollama API URL by appending '/v1/completions'.
            url = f"{self.valves.OLLAMA_API_ENDPOINT.rstrip('/')}/v1/completions"
            headers = {"Content-Type": "application/json"}
            # Bearer-token auth, if configured.
            if self.valves.OLLAMA_API_KEY:
                headers["Authorization"] = f"Bearer {self.valves.OLLAMA_API_KEY}"
                self.log("Ollama authentication enabled.")

            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.valves.request_timeout)
            ) as session:
                async with session.post(url, json=payload, headers=headers) as response:
                    response_text = await response.text()
                    self.log(f"Ollama response status: {response.status}")
                    self.log(f"Ollama response text: {response_text}")
                    if response.status != 200:
                        self.log(
                            f"Ollama API call failed with status: {response.status}"
                        )
                        return f"Error: Ollama API call failed with status {response.status}"

                    # Parse the JSON body.
                    try:
                        data = json.loads(response_text)
                        self.log(f"Parsed Ollama response data: {data}")
                    except json.JSONDecodeError:
                        self.log("Failed to decode JSON from Ollama's response.")
                        return "Error: Invalid JSON response from Ollama."

                    # Extract the generated text from 'choices'.
                    choices = data.get("choices", [])
                    if not choices:
                        self.log("No choices found in Ollama's response.")
                        return "Error: No choices found in Ollama's response."
                    first_choice = choices[0]
                    raw_text = first_choice.get("text", "")
                    self.log(f"Raw text from first choice: {raw_text!r}")
                    text = self.clean_response_text(raw_text)
                    if not text:
                        self.log("No valid text found in Ollama's response.")
                        return "Error: Empty response from Ollama."
                    self.log(f"Extracted text from Ollama: {text!r}")

                    # If the response carries a chat_id, persist it for the user.
                    new_chat_id = data.get("chat_id")
                    if new_chat_id:
                        user_id = (
                            __user__.get("user_id", "default_user")
                            if __user__
                            else "default_user"
                        )
                        if user_id not in self.chat_sessions:
                            self.chat_sessions[user_id] = {
                                "chat_id": None,
                                "history": [],
                            }
                        self.chat_sessions[user_id]["chat_id"] = new_chat_id
                    return text
        except Exception as e:
            self.log(f"Error during Ollama request handling: {str(e)}")
            return f"Error: {e}"

    async def handle_flowise_request(
        self,
        question: str,
        __user__: Optional[dict],
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]],
    ) -> Union[str, Generator[str, None, None]]:
        """
        Send a regular request to the Flowise prediction endpoint and return
        the cleaned response text; track the chat session per user.

        Returns an "Error: ..." string on any failure.
        """
        try:
            payload: Dict[str, Any] = {"question": question}

            # Resolve the per-user session and reuse its chatId if present.
            user_id = (
                __user__.get("user_id", "default_user") if __user__ else "default_user"
            )
            chat_session = self.chat_sessions.get(user_id, {})
            chat_id = chat_session.get("chat_id")
            if chat_id:
                payload["chatId"] = chat_id
            self.log(f"Payload for Flowise: {payload}")

            # Build the prediction URL for the configured chatflow.
            endpoint = self.valves.FLOWISE_API_ENDPOINT.rstrip("/")
            url = f"{endpoint}/api/v1/prediction/{self.valves.FLOWISE_CHATFLOW_ID}"
            headers = {"Content-Type": "application/json"}

            # Basic auth, if configured.
            auth = None
            if self.valves.FLOWISE_USERNAME and self.valves.FLOWISE_PASSWORD:
                auth = aiohttp.BasicAuth(
                    self.valves.FLOWISE_USERNAME, self.valves.FLOWISE_PASSWORD
                )
                self.log("Flowise authentication enabled.")

            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.valves.request_timeout),
                auth=auth,
            ) as session:
                async with session.post(url, json=payload, headers=headers) as response:
                    response_text = await response.text()
                    self.log(f"Flowise response status: {response.status}")
                    self.log(f"Flowise response text: {response_text}")
                    if response.status != 200:
                        self.log(
                            f"Flowise API call failed with status: {response.status}"
                        )
                        return f"Error: Flowise API call failed with status {response.status}"

                    # Parse the JSON body.
                    try:
                        data = json.loads(response_text)
                        self.log(f"Parsed Flowise response data: {data}")
                    except json.JSONDecodeError:
                        self.log("Failed to decode JSON from Flowise's response.")
                        return "Error: Invalid JSON response from Flowise."

                    raw_text = data.get("text", "")
                    text = self.clean_response_text(raw_text)
                    new_chat_id = data.get("chatId", chat_id)
                    if not text:
                        self.log("No valid text found in Flowise's response.")
                        return "Error: Empty response from Flowise."
                    self.log(f"Extracted text from Flowise: {text!r}")
                    self.log(f"New chat ID from Flowise: {new_chat_id}")

                    # Update the chat session.
                    if user_id not in self.chat_sessions:
                        self.chat_sessions[user_id] = {"chat_id": None, "history": []}
                    self.chat_sessions[user_id]["chat_id"] = new_chat_id
                    # Record BOTH turns; previously only the assistant reply
                    # was stored, leaving the history one-sided.
                    self.chat_sessions[user_id]["history"].append(
                        {"role": "user", "content": question}
                    )
                    self.chat_sessions[user_id]["history"].append(
                        {"role": "assistant", "content": text}
                    )
                    return text
        except Exception as e:
            self.log(f"Error during Flowise request handling: {str(e)}")
            return f"Error: {e}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Good afternoon! Could you please explain how to use this script?