LangChain x Flyte: Streamlined Ingestion

import json
import os
from functools import partial

import flytekit
from flytekit import ImageSpec, Resources, Secret, dynamic, map_task, task, workflow

# Container images for each step of the pipeline, built declaratively with ImageSpec.
load_data_image = ImageSpec(
    name="langchain-flyte-load-data",
    packages=[
        "langchain",
        "yt_dlp",
        "pydub",
        "openai",
    ],
    apt_packages=["ffmpeg"],
    registry="ghcr.io/samhita-alla",
    base_image="ghcr.io/flyteorg/flytekit:py3.11-1.7.0",
)

split_data_image = ImageSpec(
    name="langchain-flyte-split-data",
    packages=["langchain"],
    registry="ghcr.io/samhita-alla",
    base_image="ghcr.io/flyteorg/flytekit:py3.11-1.7.0",
)

store_in_vectordb_image = ImageSpec(
    name="langchain-flyte-vectordb",
    packages=[
        "langchain",
        "pinecone-client",
        "huggingface_hub",
        "sentence_transformers",
    ],
    registry="ghcr.io/samhita-alla",
    base_image="ghcr.io/flyteorg/flytekit:py3.11-1.7.0",
)

query_image = ImageSpec(
    name="langchain-flyte-query",
    packages=[
        "langchain",
        "pinecone-client",
        "huggingface_hub",
        "sentence_transformers",
        "openai",
        "spacy",
        "textstat",
        "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0.tar.gz",
    ],
    registry="ghcr.io/samhita-alla",
    base_image="ghcr.io/flyteorg/flytekit:py3.11-1.7.0",
)

# OpenAI and Pinecone credentials live in a single AWS Secrets Manager secret.
SECRET_GROUP = "arn:aws:secretsmanager:us-east-2:356633062068:secret"
SECRET_KEY = "flyte_langchain-YtD8OW"


@task(
    cache=True,
    cache_version="1.0",
    container_image=load_data_image,
    requests=Resources(mem="5Gi"),
    secret_requests=[
        Secret(
            group=SECRET_GROUP,
            key=SECRET_KEY,
            mount_requirement=Secret.MountType.FILE,
        ),
    ],
)
def load_data(url: str) -> str:
    import openai
    from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
    from langchain.document_loaders.generic import GenericLoader
    from langchain.document_loaders.parsers import OpenAIWhisperParser

    openai.api_key = json.loads(
        flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_KEY)
    )["openai_api_key"]

    # Directory to save audio files
    save_dir = os.path.join(flytekit.current_context().working_directory, "youtube")

    # Transcribe the videos to text
    loader = GenericLoader(YoutubeAudioLoader([url], save_dir), OpenAIWhisperParser())
    docs = loader.load()
    combined_docs = [doc.page_content for doc in docs]
    text = " ".join(combined_docs)
    return text


@task(cache=True, cache_version="1.0", container_image=split_data_image)
def split_data(text: str) -> list[str]:
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
    splits = text_splitter.split_text(text)
    return splits


@task(
    cache=True,
    cache_version="1.0",
    container_image=store_in_vectordb_image,
    requests=Resources(mem="2Gi"),
    secret_requests=[
        Secret(
            group=SECRET_GROUP,
            key=SECRET_KEY,
            mount_requirement=Secret.MountType.FILE,
        ),
    ],
)
def store_in_vectordb(splits: list[str], index_name: str) -> str:
    import pinecone
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import Pinecone

    pinecone.init(
        api_key=json.loads(
            flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_KEY)
        )["pinecone_api_key"],
        environment=json.loads(
            flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_KEY)
        )["pinecone_environment"],
    )

    # Embed the text chunks locally and upsert them into the Pinecone index
    huggingface_embeddings = HuggingFaceEmbeddings(
        cache_folder=os.path.join(
            flytekit.current_context().working_directory, "embeddings-cache-folder"
        )
    )
    Pinecone.from_texts(
        texts=splits, embedding=huggingface_embeddings, index_name=index_name
    )
    return "Data is stored in the vectordb."


@task(
    disable_deck=False,
    secret_requests=[
        Secret(
            group=SECRET_GROUP,
            key=SECRET_KEY,
            mount_requirement=Secret.MountType.FILE,
        ),
    ],
    container_image=query_image,
    requests=Resources(mem="5Gi"),
)
def query_vectordb(index_name: str, query: str) -> str:
    import pinecone
    from langchain.callbacks import FlyteCallbackHandler
    from langchain.chains import RetrievalQA
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import Pinecone

    pinecone.init(
        api_key=json.loads(
            flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_KEY)
        )["pinecone_api_key"],
        environment=json.loads(
            flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_KEY)
        )["pinecone_environment"],
    )

    huggingface_embeddings = HuggingFaceEmbeddings(
        cache_folder=os.path.join(
            flytekit.current_context().working_directory, "embeddings-cache-folder"
        )
    )

    # Retrieve the two most similar chunks and answer the query with a "stuff" QA chain;
    # FlyteCallbackHandler logs the chain's execution to Flyte Decks (hence disable_deck=False).
    vectordb = Pinecone.from_existing_index(index_name, huggingface_embeddings)
    retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 2})
    qa_chain = RetrievalQA.from_chain_type(
        llm=ChatOpenAI(
            model_name="gpt-3.5-turbo",
            callbacks=[FlyteCallbackHandler()],
            temperature=0,
            openai_api_key=json.loads(
                flytekit.current_context().secrets.get(SECRET_GROUP, SECRET_KEY)
            )["openai_api_key"],
        ),
        chain_type="stuff",
        retriever=retriever,
    )
    result = qa_chain.run(query)
    return result


@workflow
def flyte_youtube_embed_wf(
    index_name: str = "flyte-youtube-data",
    urls: list[str] = [
        "https://youtu.be/CNmO1q3MamM",
        "https://youtu.be/8rLj_YVOpzE",
        "https://youtu.be/sGqS8PFQz6c",
        "https://youtu.be/1668vZczslw",
        "https://youtu.be/NrFOXQKrREA",
        "https://youtu.be/4ktHNeT8kq4",
        "https://youtu.be/gMyTz8gKWVc",
    ],
) -> list[str]:
    # Fan out over the URLs with map tasks: transcribe, chunk, and index each video in parallel
    text = map_task(load_data)(url=urls)
    splits = map_task(split_data)(text=text)
    partial_store_in_vectordb = partial(store_in_vectordb, index_name=index_name)
    return map_task(partial_store_in_vectordb)(splits=splits)
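
The file defines query_vectordb but never wires it into a workflow. As a minimal sketch (the workflow name flyte_youtube_query_wf and the default question below are illustrative additions, not part of the original gist), a companion workflow can expose the query path once flyte_youtube_embed_wf has populated the Pinecone index:

@workflow
def flyte_youtube_query_wf(
    index_name: str = "flyte-youtube-data",
    query: str = "What is Flyte?",
) -> str:
    # Reuses the query_vectordb task defined above; assumes flyte_youtube_embed_wf
    # has already written the transcript embeddings to the Pinecone index.
    return query_vectordb(index_name=index_name, query=query)

Either workflow can then be launched with the Flyte CLI, for example pyflyte run --remote langchain_flyte.py flyte_youtube_embed_wf, where the module name is whatever you saved this file as.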