Last active
April 14, 2024 00:54
-
-
Save brycepg/bab4fd81ae1c93660754746c8ffb4d01 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from typing import Optional, List, Mapping, Any | |
from llama_index.core import SimpleDirectoryReader, SummaryIndex | |
from llama_index.core.callbacks import CallbackManager | |
from llama_index.core.llms import ( | |
CustomLLM, | |
CompletionResponse, | |
CompletionResponseGen, | |
LLMMetadata, | |
) | |
from llama_index.core.llms.callbacks import llm_completion_callback | |
from llama_index.core import Settings | |
import g4f | |
class OurLLM(CustomLLM):
    """llama-index ``CustomLLM`` adapter that routes completions through g4f.

    Only the completion endpoints are implemented; chat is handled by
    llama-index's default completion-to-chat adaptation.
    """

    # Advertised prompt-window size in tokens (reported, not enforced here).
    context_window: int = 3900
    # Advertised maximum number of output tokens.
    num_output: int = 256
    # Name this wrapper reports to llama-index.
    model_name: str = "g4f"
    # Underlying model identifier forwarded to g4f.ChatCompletion.create.
    g4f_model: str = "gpt-4"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata.

        Fix: the original also passed ``g4f_model=...`` here, but
        ``LLMMetadata`` declares no such field, so the keyword was at best
        silently dropped (pydantic ``extra`` handling) — it is not forwarded.
        """
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.num_output,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Return a single non-streaming completion for *prompt*.

        ``**kwargs`` is accepted for interface compatibility but is not
        forwarded to g4f. (Leftover debug ``print`` calls were removed.)
        """
        response = g4f.ChatCompletion.create(
            model=self.g4f_model,
            messages=[{"role": "user", "content": prompt}],
        )
        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseGen:
        """Yield incremental completions for *prompt*.

        Each yielded ``CompletionResponse`` carries the accumulated text so
        far in ``text`` and the newest token in ``delta``, per the
        llama-index streaming convention.
        """
        streaming_tokens = g4f.ChatCompletion.create(
            model=self.g4f_model,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )
        response = ""
        for token in streaming_tokens:
            response += token
            yield CompletionResponse(text=response, delta=token)

    def get_provider(self):
        """Return the class name of the g4f provider used by the last call."""
        return g4f.get_last_provider().__name__
def cli_entrypoint():
    """Parse command-line arguments and run :func:`main`.

    Bug fix: for a *positional* argument, argparse ignores ``default``
    unless ``nargs="?"`` is given, so the original code errored when no
    argument was supplied instead of falling back to "My response".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "text",
        type=str,
        nargs="?",
        default="My response",
        help="Prompt text to send to the model.",
    )
    args = parser.parse_args()
    main(args.text)
def main(text):
    """Stream a completion for *text* to stdout and report the g4f provider."""
    # Build the g4f-backed LLM wrapper.
    print("Creating LLM")
    llm = OurLLM(g4f_model="gpt-4")
    print("Generating response")
    # Print each new token as it arrives rather than waiting for the
    # full reply; `delta` holds only the newest fragment.
    for chunk in llm.stream_complete(text):
        print(chunk.delta, end="")
    print(f"Using {llm.get_provider()} provider")
# Standard script guard: run the CLI only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    cli_entrypoint()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment