# Created April 9, 2024 05:49
# Source: GitHub gist BrandonStudio/638a629911e47fee29175ca5c0b7430c
# (can be saved locally or opened in GitHub Desktop).
# LangChain batch job progress bar callback
# Note: this file may contain hidden or bidirectional Unicode text that can be
# interpreted or compiled differently than it appears. To review it, open the
# file in an editor that reveals hidden Unicode characters.
# (See GitHub's documentation on bidirectional Unicode characters.)
from typing import Any, Dict
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from tqdm.auto import tqdm
class BatchCallback(BaseCallbackHandler):
    """LangChain callback that advances a tqdm progress bar once per LLM response.

    Intended for batch jobs: size the bar with ``total`` (typically
    ``len(inputs)``) and pass the instance in the run config's ``callbacks``
    list. Usable as a context manager so the bar is closed when the ``with``
    block exits, whether normally or via an exception.

    NOTE(review): the bar advances on every ``on_llm_end`` event. A chain that
    makes more than one LLM call per input will therefore overshoot ``total``.
    An input whose LLM call errors never advances the bar, because only
    ``on_llm_end`` is handled.

    Attributes:
        count: number of ``on_llm_end`` events received so far.
        progress_bar: the underlying ``tqdm`` instance.
    """

    def __init__(self, total: int, **tqdm_kwargs: Any) -> None:
        """Create the callback and its progress bar.

        Args:
            total: expected number of LLM responses (usually ``len(inputs)``).
            **tqdm_kwargs: extra keyword arguments forwarded to ``tqdm``
                (e.g. ``desc="Summarising"``); optional.
        """
        super().__init__()
        self.count = 0
        self.progress_bar = tqdm(total=total, **tqdm_kwargs)

    # Called by LangChain after every completed LLM response.
    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        """Record one finished LLM call and advance the bar by one step."""
        # NOTE(review): batch runs may invoke this from worker threads; the
        # increment below is not lock-protected. TODO confirm whether that matters.
        self.count += 1
        self.progress_bar.update(1)

    def __enter__(self) -> "BatchCallback":
        """Enter the bar's context and return this callback for the ``with`` target."""
        self.progress_bar.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        """Close the bar via tqdm's own ``__exit__``; exceptions still propagate."""
        self.progress_bar.__exit__(exc_type, exc_value, exc_traceback)

    def __del__(self) -> None:
        """Close the progress bar when the callback is garbage-collected."""
        # ``progress_bar`` is absent if ``__init__`` failed before assigning it
        # (e.g. tqdm rejected a keyword), so look it up defensively instead of
        # letting an AttributeError escape from a finalizer. Call the public
        # ``close()`` rather than another object's ``__del__``.
        progress_bar = getattr(self, "progress_bar", None)
        if progress_bar is not None:
            progress_bar.close()
def batch_with_progress(chain: Any, inputs: list[Any], **batch_kwargs: Any) -> list[Any]:
    """Run ``chain.batch(inputs)`` while showing a ``BatchCallback`` progress bar.

    Replaces the ad-hoc top-level snippet, which referenced undefined
    ``chain``/``inputs`` names at import time and discarded the batch results.
    Equivalent inline usage::

        with BatchCallback(len(inputs)) as cb:  # init callback
            results = chain.batch(inputs, config={"callbacks": [cb]})

    Args:
        chain: object exposing ``batch(inputs, config=...)``; presumably a
            LangChain runnable (not checked here).
        inputs: the batch inputs; ``len(inputs)`` sizes the progress bar.
        **batch_kwargs: extra keyword arguments forwarded to ``chain.batch``.
            Must not include ``config``, which is built here.

    Returns:
        Whatever ``chain.batch`` returns.
    """
    # The ``with`` block guarantees the bar is closed even if the batch raises.
    with BatchCallback(len(inputs)) as cb:
        return chain.batch(inputs, config={"callbacks": [cb]}, **batch_kwargs)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.
# Reference: langchain-ai/langchain#6053 (comment)