Last active
March 2, 2024 02:19
-
-
Save thistleknot/138ec02bb1812c6be4915d960d325ce3 to your computer and use it in GitHub Desktop.
text-generation-webui extension: RAG via Google/DuckDuckGo search (async) with FAISS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#for data txt files see: https://github.com/TheCynosure/smmry_impl | |
#example use | |
""" | |
Search_web("history of Taco Tuesday") | |
Tell me about this. | |
""" | |
# Get Google API keys here:
# https://console.cloud.google.com/apis/dashboard
# https://programmablesearchengine.google.com/controlpanel/all
# This could easily be retooled to use only duckduckgo_search instead of Google, which avoids having to obtain API keys.
# Note: num_sentences=n is an important parameter.
import asyncio | |
import re | |
import requests | |
import gradio as gr | |
#import asyncio | |
from lxml import html | |
import aiohttp | |
import numpy as np | |
from aiohttp import ClientSession | |
import os | |
from dotenv import load_dotenv | |
import re | |
from bs4 import BeautifulSoup | |
import faiss | |
import pandas as pd | |
from modules import chat, shared | |
from modules.text_generation import ( | |
decode, | |
encode, | |
generate_reply, | |
) | |
from typing import Dict, List, Optional | |
from duckduckgo_search import DDGS | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
from textblob import TextBlob | |
# Sentence-embedding model shared by the retrieval code; loaded once at import time.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Google Custom Search credentials are read from environment variables;
# load_dotenv() first makes the values of a local .env file visible to os.getenv.
load_dotenv()
api_key, cse_id = os.getenv('api_key'), os.getenv('cse_id')

# Extension settings: display name and whether it is shown as its own tab.
params = dict(display_name="Google Search Integration", is_tab=False)
class BSTNode:
    """One word entry of the summarizer's word-frequency tree."""

    def __init__(self, data: str, word_len: int):
        self.data = data          # word text
        self.word_len = word_len  # number of leading characters of `data` that form the key
        self.score = 1            # occurrence count; a new node has been seen once
        self.left_child = None    # subtree of keys that sort before this one
        self.right_child = None   # subtree of keys that sort after this one


class Bst:
    """Unbalanced binary search tree that counts word occurrences.

    A word's key is the first ``word_len`` characters of ``data``, and keys are
    compared as whole strings. So "apple", "app" and "avocado" are three
    distinct entries (the previous version compared only the first character).
    """

    def __init__(self):
        self.root = None

    @staticmethod
    def _key(data: str, word_len: int) -> str:
        """Return the comparison key: the first *word_len* characters of *data*."""
        return data[:word_len]

    def _find(self, key: str):
        """Locate *key* in the tree.

        Returns ``(node, parent, went_left)``. *node* is None when the key is
        absent; *parent* and *went_left* then describe where it would be attached.
        """
        parent = None
        went_left = False
        curr = self.root
        while curr is not None:
            curr_key = self._key(curr.data, curr.word_len)
            if key == curr_key:
                return curr, parent, went_left
            parent = curr
            went_left = key < curr_key
            curr = curr.left_child if went_left else curr.right_child
        return None, parent, went_left

    def add(self, data: str, word_len: int) -> int:
        """Record one occurrence of a word.

        Returns 1 when a new node was allocated, 0 when an existing node's
        score was incremented.
        """
        node, parent, went_left = self._find(self._key(data, word_len))
        if node is not None:
            node.score += 1
            return 0
        new_node = BSTNode(data, word_len)
        if parent is None:
            self.root = new_node
        elif went_left:
            parent.left_child = new_node
        else:
            parent.right_child = new_node
        return 1

    def get_score(self, data: str, word_len: int) -> int:
        """Return the stored score of a word, or 0 when it is not in the tree."""
        node = self._find(self._key(data, word_len))[0]
        return node.score if node is not None else 0

    def get_node(self, data: str, word_len: int) -> Optional[BSTNode]:
        """Return the node holding a word, or None when it is not in the tree."""
        return self._find(self._key(data, word_len))[0]

    def free_node(self, node: Optional[BSTNode]):
        """Detach every child link in the subtree rooted at *node*.

        Iterative, so a degenerate (list-shaped) tree cannot hit the recursion limit.
        """
        stack = [node] if node is not None else []
        while stack:
            curr = stack.pop()
            for child in (curr.left_child, curr.right_child):
                if child is not None:
                    stack.append(child)
            curr.left_child = None
            curr.right_child = None

    def free_bst(self) -> None:
        """Release all nodes and leave the tree empty."""
        self.free_node(self.root)
        self.root = None
class SentenceNode:
    """A sentence in the summarizer's linked list, with its relevance score."""

    def __init__(self, data: str, score: int = 0):
        self.data = data    # sentence text
        self.link = None    # next SentenceNode, or None at the tail
        self.score = score  # accumulated word-frequency score


class LinkedList:
    """Singly linked list of SentenceNode with O(1) append at the tail."""

    def __init__(self):
        self.head = None
        self.tail = None
        self.size = 0

    def insert(self, data: str):
        """Append a new SentenceNode holding *data* at the end of the list."""
        node = SentenceNode(data)
        if self.head is None:
            self.head = node
        else:
            self.tail.link = node
        self.tail = node
        self.size += 1

    def free_list(self):
        """Unlink every node and reset the list to empty, including ``size``."""
        current_node = self.head
        while current_node is not None:
            next_node = current_node.link
            current_node.link = None
            current_node = next_node
        self.head = None
        self.tail = None
        self.size = 0  # previously left stale, so an emptied list still reported its old length
def strip_newlines_tabs(text_buffer: str) -> str:
    """Return *text_buffer* with every newline and tab turned into a single space."""
    whitespace_to_space = str.maketrans('\n\t', '  ')
    return text_buffer.translate(whitespace_to_space)
def load_array(file_path: str) -> List[str]:
    """Read a one-entry-per-line word file and return its entries.

    Lines are whitespace-stripped, and blank lines are skipped. An empty entry
    would otherwise match every one-letter word in is_title, and it would
    break BST lookups. The file is read as UTF-8 rather than the
    platform-default encoding.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [stripped for stripped in (line.strip() for line in f) if stripped]
def is_title(text_buffer: str, word_len: int, titles: List[str]) -> bool:
    """Report whether the word before a '.' should be treated as a title/abbreviation.

    True when some entry of *titles* is exactly ``word_len - 1`` characters long
    and is a prefix of *text_buffer*.
    NOTE(review): the ``word_len - 1`` length rule looks like a port artifact
    from the C original; confirm it against the contents of titles.txt.
    """
    wanted_len = word_len - 1
    return any(
        len(candidate) == wanted_len and text_buffer.startswith(candidate)
        for candidate in titles
    )
def sentence_chop(text_buffer: str, titles: List[str]) -> LinkedList:
    """Split *text_buffer* into sentences, returned in order as a LinkedList.

    A '.' ends a sentence when all of the following hold:
      * it is not inside parentheses;
      * the next character is a space, a NUL, '(' or '[', or the '.' is the
        last character of the text;
      * is_title does not flag the word before it (e.g. an abbreviation
        listed in *titles*).
    Text after the last boundary is kept as the final sentence, including a
    last sentence without a closing period; previously it was dropped.
    Sentences are whitespace-stripped, and empty ones are skipped.
    """
    sen_list = LinkedList()
    text_len = len(text_buffer)
    current_word_len = 0
    new_sentence = False
    last_sentence = 0
    inside_paren = False
    for i, c in enumerate(text_buffer):
        if c == ' ':
            current_word_len = 0
        elif c == '.':
            # None marks the end of the text, standing in for C's '\0' terminator.
            next_char = text_buffer[i + 1] if i + 1 < text_len else None
            word_start = i - current_word_len
            if (not inside_paren
                    and next_char in (None, ' ', '\0', '(', '[')
                    and not is_title(text_buffer[word_start:i], current_word_len, titles)):
                new_sentence = True
                current_word_len = 0
                continue
        else:
            current_word_len += 1
            if c == '(':
                inside_paren = True
            elif c == ')':
                inside_paren = False
        if new_sentence:
            # The boundary '.' was the previous character; cut just after it.
            sentence = text_buffer[last_sentence:i].strip()
            if sentence:
                sen_list.insert(sentence)
            last_sentence = i
            new_sentence = False
    remainder = text_buffer[last_sentence:].strip()
    if remainder:
        sen_list.insert(remainder)
    return sen_list
def load_list(file_path: str) -> Dict[str, str]:
    """Load a two-column word-mapping file into a dict.

    Each usable line holds exactly two whitespace-separated fields,
    ``<word> <base_word>``. The result maps word -> base_word, with base_word
    lower-cased. Blank lines are skipped. Lines with any other number of fields
    are also skipped (and counted in a printed warning) instead of aborting the
    load with ValueError. The file is read as UTF-8.
    """
    word_dict = {}
    malformed = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 2:
                malformed += 1
                continue
            synonym, base_word = fields
            word_dict[synonym] = base_word.lower()
    if malformed:
        print(f"load_list: skipped {malformed} malformed line(s) in {file_path}")
    return word_dict
def _sentence_words(sentence: str, synonyms: Dict[str, str], irreg_nouns: Dict[str, str]):
    """Yield the normalised words of *sentence*, the single tokeniser for every scoring pass.

    The sentence is split on whitespace. Each token is reduced to its
    alphanumeric characters and lower-cased, and empty results are dropped.
    The word is then mapped through *irreg_nouns* and *synonyms*
    (word -> base word, as built by load_list), so variants count as one word.
    """
    for token in sentence.split():
        word = ''.join(ch for ch in token if ch.isalnum()).lower()
        if not word:
            continue
        word = irreg_nouns.get(word, word)
        yield synonyms.get(word, word)


def add_words_to_bst(word_bst: Bst, l: LinkedList, synonyms: Dict[str, str], irreg_nouns: Dict[str, str]):
    """Count every word of every sentence in *l* into *word_bst*.

    Previously this only read scores from the (still empty) tree, so every
    sentence scored 0 and the "summary" was just the first N sentences.
    """
    current_node = l.head
    while current_node is not None:
        for word in _sentence_words(current_node.data, synonyms, irreg_nouns):
            word_bst.add(word, len(word))
        current_node = current_node.link


def tally_top_scorers(word_bst: Bst, l: LinkedList, return_num: int, synonyms: Dict[str, str], irreg_nouns: Dict[str, str]) -> List[SentenceNode]:
    """Score every sentence of *l* and return the *return_num* best, highest first.

    A sentence's score (stored on its node) is the sum of the tree counts of
    its words. Only counts above 1 contribute, so words seen once, and words
    zeroed by get_rid_of_simples, add nothing. Equal scores keep document
    order. The result has length *return_num* (empty if it is <= 0) and is
    padded with None when there are fewer sentences.
    """
    nodes = []
    current_node = l.head
    while current_node is not None:
        current_node.score = 0
        for word in _sentence_words(current_node.data, synonyms, irreg_nouns):
            score = word_bst.get_score(word, len(word))
            if score > 1:
                current_node.score += score
        nodes.append(current_node)
        current_node = current_node.link
    if return_num <= 0:
        return []
    # sorted() is stable (also with reverse=True), so ties stay in document order.
    top_scorers = sorted(nodes, key=lambda node: node.score, reverse=True)[:return_num]
    return top_scorers + [None] * (return_num - len(top_scorers))


def get_rid_of_simples(word_bst: Bst, simples: List[str]):
    """Zero the count of each listed simple word so it no longer adds to sentence scores.

    Entries are stripped and lower-cased to match the tokeniser; blank entries
    are ignored. NOTE(review): simple words are not mapped through the
    synonym/irregular-noun tables, so a simple word that is itself a mapped
    synonym is not zeroed.
    """
    for word in simples:
        word = word.strip().lower()
        if not word:
            continue
        node = word_bst.get_node(word, len(word))
        if node is not None:
            node.score = 0
def summarize(text: str, num_sentences: int,
              data_dir: str = "/data/text-generation-webui/extensions/google_search/data") -> list:
    """Return up to *num_sentences* highest-scoring sentences of *text* (extractive summary).

    Args:
        text: Input text; newlines and tabs are treated as spaces.
        num_sentences: Maximum number of sentences to return; values <= 0 yield [].
        data_dir: Directory containing titles.txt, formattedcommonsyns.txt,
            formattedirregnouns.txt and simplewords.txt. The default is the
            path the extension previously hard-coded.

    Returns:
        A list of sentence strings, best first.
    """
    if num_sentences <= 0:
        return []
    text_buffer = strip_newlines_tabs(text)
    titles = load_array(os.path.join(data_dir, "titles.txt"))
    sentences = sentence_chop(text_buffer, titles)
    num_sentences = min(num_sentences, sentences.size)
    if num_sentences == 0:
        return []
    synonyms = load_list(os.path.join(data_dir, "formattedcommonsyns.txt"))
    irreg_nouns = load_list(os.path.join(data_dir, "formattedirregnouns.txt"))
    word_bst = Bst()
    add_words_to_bst(word_bst, sentences, synonyms, irreg_nouns)
    simples = load_array(os.path.join(data_dir, "simplewords.txt"))
    get_rid_of_simples(word_bst, simples)
    top_scorers = tally_top_scorers(word_bst, sentences, num_sentences, synonyms, irreg_nouns)
    # tally_top_scorers pads with None when short; never return those.
    return [node.data for node in top_scorers if node is not None]
def dequote(s):
    """Remove one pair of matching surrounding quotes (single or double) from *s*.

    Strings shorter than two characters, or not wrapped in a matching pair,
    are returned unchanged. Previously an empty string raised IndexError.
    """
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
        return s[1:-1]
    return s
def google_search(query, api_key=api_key, cse_id=cse_id, timeout=10, **kwargs):
    """Collect result URLs for *query* from DuckDuckGo and Google Custom Search.

    Args:
        query: Search phrase.
        api_key, cse_id: Google Custom Search credentials. When either is
            missing, the Google request is skipped.
        timeout: Seconds to wait for the Google API request.
        **kwargs: Extra query parameters for the Google Custom Search API.

    Returns:
        A NumPy array of unique URL strings, possibly empty. DuckDuckGo results
        (up to 10) come first, then Google's; each URL keeps its first position.
        Previously np.unique sorted them alphabetically, which lost the result ranking.

    Each backend is best-effort: a failing backend is reported and skipped
    instead of aborting the whole search.
    """
    urls = []
    try:
        with DDGS() as ddgs:
            urls.extend(result['href'] for result in ddgs.text(query, max_results=10))
    except Exception as exc:  # third-party backend (rate limits etc.); degrade to Google-only
        print(f"DuckDuckGo search failed: {exc}")

    if api_key and cse_id:
        query_params = {'q': query, 'key': api_key, 'cx': cse_id, **kwargs}
        try:
            response = requests.get("https://www.googleapis.com/customsearch/v1",
                                    params=query_params, timeout=timeout)
            response.raise_for_status()
            urls.extend(item['link'] for item in response.json().get('items', []))
        except (requests.RequestException, ValueError, KeyError) as exc:
            print(f"Google search failed: {exc}")
    else:
        print("Google search skipped: api_key/cse_id not configured.")

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return np.array(list(dict.fromkeys(urls)), dtype=str)
async def fetch_text(url, timeout=15):
    """Download *url* and return the text of each of its <p> elements.

    Args:
        url: Page to fetch.
        timeout: Total seconds allowed for the request.

    Returns:
        A list of paragraph strings. On a network/client error or timeout, the
        error is printed and an empty list is returned. This keeps one bad URL
        from aborting a batch gathered with asyncio.gather.

    The body is decoded as UTF-8, falling back to ISO-8859-1 (which accepts any byte).
    """
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    try:
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(url) as response:
                byte_content = await response.read()
    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
        print(f"Failed to fetch {url}: {exc!r}")
        return []
    try:
        content = byte_content.decode('utf-8')
    except UnicodeDecodeError:
        content = byte_content.decode('iso-8859-1')
    soup = BeautifulSoup(content, 'lxml')
    return [element.text for element in soup.find_all('p')]
async def get_texts(urls):
    """Fetch all *urls* concurrently and return their paragraph texts, minus length outliers.

    Each page's paragraphs are joined with newlines. Pages whose fetch raised,
    or that produced no text, are dropped before the outlier test, so they
    cannot skew the quartiles. Pages whose character length lies outside
    [Q1 - 1.5*IQR, Q3 + 1.5*IQR] are then removed.

    Returns:
        A NumPy array of strings. It is empty when nothing usable was fetched
        (previously an empty input crashed np.quantile).
    """
    results = await asyncio.gather(*(fetch_text(url) for url in urls), return_exceptions=True)
    texts = []
    for paragraphs in results:
        if isinstance(paragraphs, BaseException):
            print(f"Skipping a page that failed to load: {paragraphs!r}")
            continue
        text = '\n'.join(paragraphs)
        if text.strip():
            texts.append(text)
    if not texts:
        return np.array([], dtype=str)

    lens = np.array([len(t) for t in texts])
    q1, q3 = np.quantile(lens, q=[0.25, 0.75])
    iqr = q3 - q1
    keep = (lens >= q1 - 1.5 * iqr) & (lens <= q3 + 1.5 * iqr)
    return np.array(texts)[keep]
def embed_text(examples):
    """Embed ``examples['text']`` with the module-level SentenceTransformer ``model``.

    ``examples['text']`` may be a list of strings (a batch, as build_faiss_index
    passes via ``map(batched=True)``) or a single string (as search_faiss_index
    passes).

    Returns:
        ``{"embedding": numpy array}``: one row per input string, or a 1-D
        vector for a single string.

    The previous body referenced ``tokenizer``, ``torch`` and ``mean_pooling``,
    none of which exist in this file, and called the SentenceTransformer as if
    it were a raw HF model. It could not run.
    """
    embeddings = model.encode(examples['text'], convert_to_tensor=False)
    return {"embedding": np.asarray(embeddings)}
def search_faiss_index(dataset, query, model, tokenizer, top_n=10):
    """Return the 'text' values of the *top_n* rows nearest to *query* in the 'embedding' index.

    *model* and *tokenizer* are accepted but unused; the query is embedded by
    embed_text.
    """
    query_vector = embed_text({'text': query})['embedding']
    _scores, nearest = dataset.get_nearest_examples('embedding', query_vector, k=top_n)
    return nearest['text']
def build_faiss_index(texts):
    """Put *texts* in a one-column ('text') dataset, embed it, and index the 'embedding' column.

    NOTE(review): ``Dataset`` is not imported anywhere in this file (presumably
    ``datasets.Dataset`` was intended), so this raises NameError as written.
    Confirm before use.
    """
    frame = pd.DataFrame(texts, columns=['text'])
    dataset = Dataset.from_pandas(frame).map(embed_text, batched=True, batch_size=16)
    dataset.add_faiss_index("embedding")
    return dataset
def get_reponse(prompt, max_tokens, url="http://127.0.0.1:5000/v1/completions", headers=None):
    """Request a completion of *prompt* from the local text-generation-webui API.

    Args:
        prompt: Text sent as the completion prompt.
        max_tokens: Sent as ``max_new_tokens`` and as ``truncation_length``.
        url: Completions endpoint.
        headers: HTTP headers; None means a JSON content-type header. The
            default is no longer a shared mutable dict.

    Returns:
        The stripped text of ``choices[0]``. On a connection error, a non-200
        status, or a malformed reply, the details are printed and None is
        returned, instead of the exception propagating.
    """
    if headers is None:
        headers = {"Content-Type": "application/json"}
    data = {
        'prompt': prompt,
        'max_new_tokens': max_tokens,
        'preset': 'None',
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'epsilon_cutoff': 0,  # In units of 1e-4
        'eta_cutoff': 0,  # In units of 1e-4
        'tfs': 1,
        'top_a': 0,
        'repetition_penalty': 1.18,
        'repetition_penalty_range': 0,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': True,
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
        'seed': -1,
        'add_bos_token': True,
        'truncation_length': max_tokens,
        'ban_eos_token': False,
        'skip_special_tokens': True,
        'stopping_strings': []
    }
    try:
        response = requests.post(url, headers=headers, json=data)
    except requests.RequestException as exc:
        print("Failed to obtain a response from the API.")
        print("Error:", exc)
        return None
    if response.status_code != 200:
        print("Failed to obtain a response from the API.")
        print("Status Code:", response.status_code)
        print("Response:", response.text)
        return None
    try:
        return response.json()['choices'][0]['text'].strip()
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        print("Unexpected response format from the API:", exc)
        return None
def get_search_phrase(query):
    """Ask the local completion API to turn *query* into a concise Google search phrase.

    Returns:
        The stripped completion text. On a connection error, a non-200 status,
        or a reply without ``choices[0]['text']``, the details are printed and
        None is returned (previously only the non-200 case returned None; the
        others raised).
    """
    url = "http://127.0.0.1:5000/v1/completions"
    headers = {
        "Content-Type": "application/json"
    }
    data = {
        "prompt": f"""
Instruction:
Provide a Google search phrase that efficiently finds information relevant to the instruction: '{query}'
Once you provide a succinct search phrase, do not provide any further response.
Response:
Phrase:
""",
        "max_tokens": 200,
        "temperature": 1,
        "top_p": 0.9,
        "seed": 10
    }
    try:
        response = requests.post(url, headers=headers, json=data)
    except requests.RequestException as exc:
        print("Failed to obtain search phrase from the API.")
        print("Error:", exc)
        return None
    if response.status_code != 200:
        print("Failed to obtain search phrase from the API.")
        print("Status Code:", response.status_code)
        print("Response:", response.text)
        return None
    try:
        return response.json()['choices'][0]['text'].strip()
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        print("Unexpected response format from the API:", exc)
        return None
def process_search_results(texts, query, search_phrase, num=13):
    """Ask the local completion API to pick the *num* most relevant sentences from *texts*.

    Args:
        texts: Search-result text inserted as the prompt's context.
        query: The user's original request.
        search_phrase: The phrase that was searched for.
        num: Number of sentences the model is asked to extract.

    Returns:
        The stripped completion text. On a connection error, a non-200 status,
        or a reply without ``choices[0]['text']``, the details are printed and
        None is returned.
    """
    url = "http://127.0.0.1:5000/v1/completions"
    headers = {
        "Content-Type": "application/json"
    }
    data = {
        "prompt": f"""
Context:
{texts}
Instruction:
Given the search phrase '{search_phrase}' to support the user query '{query}', the following search results were extracted.
Not all search results are relevant. From the result set, extract the most poignant {num} sentences relevant to the user's query listing them from most relevant to least:
Response:
""",
        "max_tokens": 4096,
        "temperature": 1,
        "top_p": 0.9,
        "seed": 10
    }
    try:
        response = requests.post(url, headers=headers, json=data)
    except requests.RequestException as exc:
        print("Failed to obtain results from the API.")
        print("Error:", exc)
        return None
    if response.status_code != 200:
        print("Failed to obtain results from the API.")
        print("Status Code:", response.status_code)
        print("Response:", response.text)
        return None
    try:
        return response.json()['choices'][0]['text'].strip()
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        print("Unexpected response format from the API:", exc)
        return None
def _relevant_sentences(text, query_embeddings, threshold=0.0):
    """Return the sentences of *text* most similar to any row of *query_embeddings*, in document order.

    Sentences are split by TextBlob and embedded with the module-level
    ``model``. The embeddings are L2-normalised and indexed in a FAISS
    inner-product index, so scores are cosine similarities. For each query row,
    the top half of the sentences (rounded, but at least 1) are retrieved, and
    hits scoring below *threshold* are dropped.
    """
    sentences = [sentence.raw for sentence in TextBlob(text).sentences]
    if not sentences:
        return []
    embeddings = np.asarray(model.encode(sentences, convert_to_tensor=False))
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    # Previously a one-sentence page gave k == 0, which made the search fail and dropped the page.
    k = max(1, int(np.round(len(sentences) / 2)))
    similarities, indices = index.search(query_embeddings, k)
    keep = {
        int(i)
        for row_sims, row_ids in zip(similarities, indices)
        for sim, i in zip(row_sims, row_ids)
        if i >= 0 and sim >= threshold  # FAISS pads missing hits with -1
    }
    return [sentences[i] for i in sorted(keep)]


def _sentence_budgets(lengths, num):
    """Split a budget of *num* sentences across result sets in proportion to *lengths* (rounded)."""
    total = sum(lengths)
    if total == 0:
        return [0] * len(lengths)
    return [int(np.round((length / total) * num)) for length in lengths]


async def get_search_results(urls, query, search_phrase, num=55):
    """Fetch *urls* and condense the pages into a summary relevant to the query.

    For each page, the sentences closest to *query* or *search_phrase* are
    kept. Those are summarised with ``summarize``, giving each page a share of
    roughly *num* sentences proportional to how many relevant sentences it
    contributed. Per-page summaries are separated by blank lines.

    Returns:
        The summary text, with runs of spaces collapsed to one. Previously
        ``.replace(' ', '')`` deleted every space.
    """
    texts = await get_texts(urls)
    query_embeddings = model.encode([query, search_phrase], convert_to_tensor=False)
    query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
    query_embeddings = query_embeddings.astype(np.float32)

    result_set = []
    for page_text in texts:
        print('len flattened_texts', len(page_text))
        try:
            relevant = _relevant_sentences(page_text, query_embeddings)
        except Exception as exc:  # best-effort per page: one bad page must not abort the search
            print(f"Skipping a page that could not be indexed: {exc!r}")
            continue
        if relevant:
            result_set.append(relevant)

    lengths = [len(sentences) for sentences in result_set]
    print('lens', lengths)
    print('total_length', sum(lengths))
    budgets = _sentence_budgets(lengths, num)
    print('ratios', budgets)

    summaries = [summarize('\n'.join(sentences), num_sentences=budget)
                 for sentences, budget in zip(result_set, budgets) if budget > 0]
    summary_text = '\n\n'.join('\n'.join(summary) for summary in summaries if summary)
    return re.sub(r' {2,}', ' ', summary_text)
def history_modifier(history):
    """Chat-mode hook for editing the conversation history; this extension returns it untouched."""
    unchanged_history = history
    return unchanged_history
def state_modifier(state):
    """Hook for editing the UI-state dict (slider/checkbox values); returned as-is here."""
    unchanged_state = state
    return unchanged_state
def _web_search_context(query):
    """Run the search pipeline for *query* and return the context preamble, or None when nothing was found."""
    initial_response = get_reponse(query, max_tokens=2048)
    raw_phrase = get_search_phrase(query)
    first_line = raw_phrase.split('\n')[0] if raw_phrase else ''
    search_phrase = dequote(first_line) if first_line else ''
    if not search_phrase.strip():
        # The model gave no usable phrase (or the API failed): search for the query itself.
        search_phrase = query
    print("Derived search phrase:", search_phrase)

    urls = google_search(search_phrase)
    print(urls)
    if len(urls) == 0:
        print("No search results found.")
        return None

    search_results = run_asyncio_coroutine(get_search_results(urls, query, search_phrase))
    filtered_summary = process_search_results(search_results, query, search_phrase, num=13)
    print("filtered_summary", filtered_summary)
    if not filtered_summary:
        return None
    known = initial_response if initial_response is not None else ''
    return f'Start Internal Monologue: What you currently know about the subject:\n\n{known}\n\n.Web search results for search phrase \'{search_phrase}\', given query \'{query}\':\n\n{filtered_summary}\n\n'


def chat_input_modifier(text, visible_text, state):
    """Expand a ``Search_web("...")`` command in the user's message into web-search context.

    When *text* contains ``Search_web("<query>")`` (single or double quotes,
    non-empty query), the following steps run:
      1. The local model is asked what it already knows.
      2. It derives a search phrase.
      3. DuckDuckGo/Google are searched.
      4. The pages are condensed.
    The results are prepended to *text* as an "internal monologue";
    *visible_text* is never changed. If the search fails for any reason, the
    error is printed and the message is passed through unchanged instead of
    breaking the chat turn.
    """
    match = re.search(r'Search_web\(["\'](.+)["\']\)', text)
    if not match:
        return text, visible_text
    print('search requested')
    query = match.group(1)
    try:
        context = _web_search_context(query)
    except Exception as exc:  # extension boundary: never let a search failure drop the user's message
        print(f"Web search failed; sending the message without search context: {exc!r}")
        return text, visible_text
    if context:
        text = context + 'End Web search results. End Internal Monologue:\n\nSynthesize a comprehensive holistic response using the internal monologue to support the request, be sure to be as comprehensive as possible and include all relevant sources.\n\nRequest:\n\n' + text
    return text, visible_text
def input_modifier(string, state, is_chat=False):
    """Prompt hook.

    In default/notebook modes it receives the whole prompt. In chat mode it
    sees only the internal text ("string"), not the visible text. This
    extension returns it untouched.
    """
    untouched = string
    return untouched
def bot_prefix_modifier(string, state):
    """Chat-mode hook for the bot-reply prefix (e.g. "Bot Name:"); left unchanged here."""
    prefix = string
    return prefix
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    """Hook for editing input ids/embeddings, used by transformers-based loaders.

    This extension passes all three values through unchanged.
    """
    passthrough = (prompt, input_ids, input_embeds)
    return passthrough
# Could use "drugs" to modify the model's internal representations here.
def logits_processor_modifier(processor_list, input_ids):
    """Add this extension's logits processor to *processor_list*, when one exists.

    Only used by transformers-based loaders. ``MyLogits`` is neither defined
    nor imported in this file, so appending it unconditionally raised
    NameError on every call. It is now looked up at call time and appended
    only if a module-level ``MyLogits`` exists; otherwise the list is returned
    unchanged.
    """
    logits_cls = globals().get('MyLogits')
    if logits_cls is not None:
        processor_list.append(logits_cls())
    return processor_list
def output_modifier(string, state, is_chat=True):
    """Hook for editing the model's reply before display; this extension returns it as-is.

    In chat mode the returned value is the visible history entry, while the
    original reply is kept as the internal one.
    """
    reply = string
    return reply
def custom_generate_chat_prompt(user_input, state, **kwargs):
    """Chat-mode prompt builder; delegates unchanged to chat.generate_chat_prompt."""
    return chat.generate_chat_prompt(user_input, state, **kwargs)
def run_asyncio_coroutine(coroutine):
    """Run *coroutine* to completion from synchronous code and return its result.

    If no event loop is running in the calling thread, ``asyncio.run`` is used
    (a fresh loop, closed afterwards). This replaces the deprecated
    ``get_event_loop()`` fallback. If a loop is already running (e.g. when
    called from inside a coroutine or Jupyter), it cannot be re-entered.
    Instead, the coroutine runs on a worker thread with its own loop, and this
    call blocks until it finishes. Previously that case raised RuntimeError
    and left the coroutine un-awaited.
    """
    import concurrent.futures

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coroutine)
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coroutine).result()
def setup():
    """Extension setup hook; this extension needs no initialisation."""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment