Skip to content

Instantly share code, notes, and snippets.

@morganmcg1
Created May 19, 2023 12:15
Show Gist options
  • Save morganmcg1/a77b6494ef5888e902d6ad540a247672 to your computer and use it in GitHub Desktop.
Save morganmcg1/a77b6494ef5888e902d6ad540a247672 to your computer and use it in GitHub Desktop.
Monkey patching Lanarky for WandbTracer
import langchain
import wandb
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from fastapi.responses import StreamingResponse as _StreamingResponse
from langchain.chains.base import Chain
from starlette.background import BackgroundTask
from starlette.types import Send
from dotenv import load_dotenv
from fastapi import FastAPI
from langchain import ConversationChain
from langchain.chat_models import ChatOpenAI
from pydantic import BaseModel
from lanarky.responses import StreamingResponse
import uvicorn
from lanarky.callbacks import get_streaming_callback
from lanarky.responses import StreamingResponse
from datetime import datetime
from wandb.integration.langchain import WandbTracer
# Run arguments handed to WandbTracer each time a chain is executed.
# The "name" value is computed once, when this module is imported, so every
# tracer created by this process receives the same name.
wandb_config = {
    "project": "my-project",
    # Optionally log under a specific W&B entity by adding:
    # "entity": "my-entity",
    "tags": ["hi wandb team"],
    "name": f"my_log_{datetime.now():%Y-%m-%d-%H-%M-%S}",
}
def new_create_chain_executor(
    chain: Chain,
    inputs: Union[Dict[str, Any], Any],
) -> Callable[[Send], Awaitable[Any]]:
    """Build the async executor that runs ``chain`` with streaming + W&B tracing.

    This module installs it as ``StreamingResponse._create_chain_executor``.
    The returned coroutine function calls ``chain.acall`` with two callbacks:
    Lanarky's streaming callback (bound to the ASGI ``send``) and a freshly
    constructed ``WandbTracer`` configured from ``wandb_config``.

    Args:
        chain: The LangChain chain to execute.
        inputs: Passed unchanged as ``inputs=`` to ``chain.acall``.

    Returns:
        An async callable that takes ``send`` and returns the awaited result
        of ``chain.acall``.
    """
    import logging  # local import, matching the file's local-import style

    logger = logging.getLogger(__name__)
    # Debug-level logging rather than print(): raw user inputs must not be
    # written to stdout unconditionally on every request.
    logger.debug("Creating chain executor with inputs: %s", inputs)

    async def wrapper(send: Send) -> Any:
        logger.debug("Executing chain with inputs: %s", inputs)
        return await chain.acall(
            inputs=inputs,
            callbacks=[
                get_streaming_callback(chain, send=send),
                # A new tracer is constructed for every execution.
                WandbTracer(wandb_config),
            ],
        )

    return wrapper
# Monkey patching: replace Lanarky's private chain-executor factory with
# new_create_chain_executor so a WandbTracer rides along with the streaming
# callback. Wrapped in staticmethod() so it is not bound as an instance method.
# NOTE(review): this relies on a private Lanarky attribute; presumably it is
# what StreamingResponse.from_chain calls — verify against the installed
# lanarky version, since a signature change there would break this patch.
setattr(
    StreamingResponse,
    "_create_chain_executor",
    staticmethod(new_create_chain_executor),
)

# The FastAPI application; served by uvicorn in the __main__ guard.
app = FastAPI()
@app.on_event("shutdown")
async def shutdown_event():
    """Call ``WandbTracer.finish()`` when the FastAPI app shuts down.

    Tracers are created per chain execution elsewhere in this file; this hook
    is where W&B tracing is finished for the process.
    """
    finish_tracing = WandbTracer.finish
    finish_tracing()
# Request schema for the server endpoint.
# NOTE(review): the name is easy to confuse with fastapi/starlette ``Request``.
class Request(BaseModel):
    """JSON body accepted by ``POST /chat``."""

    # The user's text; the endpoint builds the LLM prompt from it.
    query: str
# Re-import of a module already imported at the top of the file; harmless.
import langchain

# Print the installed wandb and langchain versions (one per line) at import
# time — handy when reporting version-specific integration problems.
print(wandb.__version__, langchain.__version__, sep="\n")
@app.post("/chat")
async def chat(request: Request):
    """Stream an LLM completion for ``request.query`` as ``text/event-stream``.

    The user's text is supplied as the *value* of a fixed ``{query}`` template
    variable. It is never used as the template itself, so braces in user input
    (JSON, code, ...) are passed through verbatim instead of being parsed as
    template variables.

    W&B tracing is not configured here; in this file it is attached by the
    patched ``StreamingResponse`` chain executor.
    """
    import os

    from langchain import PromptTemplate
    from langchain.chains import LLMChain

    llm = ChatOpenAI(
        streaming=True,  # must stream for StreamingResponse to emit tokens
        # Read the key from the environment instead of hard-coding it.
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
    )
    llm_chain = LLMChain(
        llm=llm,
        prompt=PromptTemplate.from_template("{query}"),
    )
    return StreamingResponse.from_chain(
        llm_chain,
        {"query": request.query},  # fills the single template variable
        media_type="text/event-stream",
    )
if __name__ == "__main__":
    # Load variables from a local .env file (e.g. API keys) before serving,
    # so credentials do not have to be hard-coded in this script.
    load_dotenv()
    uvicorn.run(app, host="127.0.0.1", port=8010)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment