Created
May 19, 2023 12:15
-
-
Save morganmcg1/a77b6494ef5888e902d6ad540a247672 to your computer and use it in GitHub Desktop.
Monkey patching Lanarky for WandbTracer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import langchain | |
import wandb | |
from typing import Any, Awaitable, Callable, Dict, Optional, Union | |
from fastapi.responses import StreamingResponse as _StreamingResponse | |
from langchain.chains.base import Chain | |
from starlette.background import BackgroundTask | |
from starlette.types import Send | |
from dotenv import load_dotenv | |
from fastapi import FastAPI | |
from langchain import ConversationChain | |
from langchain.chat_models import ChatOpenAI | |
from pydantic import BaseModel | |
from lanarky.responses import StreamingResponse | |
import uvicorn | |
from lanarky.callbacks import get_streaming_callback | |
from lanarky.responses import StreamingResponse | |
from datetime import datetime | |
from wandb.integration.langchain import WandbTracer | |
# Settings handed to WandbTracer for every chain execution (see
# new_create_chain_executor below).
wandb_config = dict(
    project="my-project",
    # entity="my-entity",
    tags=["hi wandb team"],
    # The timestamp is evaluated once, when this module is imported.
    name=f"my_log_{datetime.now():%Y-%m-%d-%H-%M-%S}",
)
def new_create_chain_executor(
    chain: Chain,
    inputs: Union[Dict[str, Any], Any]
) -> Callable[[Send], Awaitable[Any]]:
    """Build the coroutine function that runs ``chain`` for one response.

    Args:
        chain: The chain to execute.
        inputs: Inputs forwarded unchanged to ``chain.acall``.

    Returns:
        An async callable taking the ASGI ``send`` channel. When awaited it
        runs ``chain.acall`` with two callbacks: the Lanarky streaming
        callback bound to ``send`` and a ``WandbTracer`` built from the
        module-level ``wandb_config``. It returns whatever ``acall`` returns.
    """
    # Printed once, when the executor is created.
    print(f"Debug 1: Executing chain with inputs: {inputs}")  # Debug print

    async def wrapper(send: Send):
        # Printed each time the executor is actually awaited.
        print(f"Debug 2: Executing chain with inputs: {inputs}")  # Debug print
        streaming_callback = get_streaming_callback(chain, send=send)
        tracer = WandbTracer(wandb_config)
        result = await chain.acall(
            inputs=inputs,
            callbacks=[streaming_callback, tracer],
        )
        return result

    return wrapper
# Monkey patching: swap Lanarky's StreamingResponse._create_chain_executor for
# new_create_chain_executor so that chain runs also receive a WandbTracer
# callback. It is wrapped in staticmethod because the replacement takes no
# self/cls argument.
StreamingResponse._create_chain_executor = staticmethod(new_create_chain_executor)

# The FastAPI application that the handlers below are registered on.
app = FastAPI()
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Call ``WandbTracer.finish()`` when the FastAPI app shuts down."""
    WandbTracer.finish()
# Set up server endpoint | |
class Request(BaseModel):
    # Request body accepted by the /chat endpoint.
    # query: the user's text, used to build the LLM prompt.
    query: str
import langchain

# Print the installed wandb and langchain versions at import time.
print(wandb.__version__)
print(langchain.__version__)
@app.post("/chat")
async def chat(
    request: Request
):
    """Stream an LLM completion for ``request.query`` as server-sent events.

    The query is passed to the chain as the value of a fixed ``{query}``
    template variable instead of being used as the template text itself.
    Using raw user text as the template meant any ``{`` or ``}`` in the query
    was parsed as a placeholder, and the chain then failed because no input
    variables were supplied.

    Args:
        request: Request body carrying the user's ``query`` string.

    Returns:
        The Lanarky ``StreamingResponse`` produced by
        ``StreamingResponse.from_chain`` with media type
        ``text/event-stream``.
    """
    # Kept as function-scope imports, as in the original handler.
    from langchain import PromptTemplate
    from langchain.chains import LLMChain

    # Streaming chat model. The W&B tracer is attached per execution by the
    # patched chain executor, not here.
    llm = ChatOpenAI(
        streaming=True,  # must stream
        openai_api_key="XXX",  # placeholder value; supply a real key
    )

    # Single-variable prompt: the user's text is substituted verbatim, so
    # braces inside it are never interpreted as template syntax.
    llm_chain = LLMChain(
        llm=llm,
        prompt=PromptTemplate.from_template("{query}"),
    )

    return StreamingResponse.from_chain(
        llm_chain,  # pass the chain
        {"query": request.query},  # the user's text fills the template
        media_type="text/event-stream",  # should stream
    )
if __name__ == "__main__":
    # Start the uvicorn server on localhost:8010 only when this file is run
    # directly as a script.
    uvicorn.run(app, host="127.0.0.1", port=8010)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment