Created
March 17, 2024 01:08
-
-
Save ugai/9154d3b0a57cb39c6fcf554d53ec4c86 to your computer and use it in GitHub Desktop.
Conversation Retrieval Chain Example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Conversation Retrieval Chain | |
https://python.langchain.com/docs/get_started/quickstart#conversation-retrieval-chain | |
""" | |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_community.document_loaders import WebBaseLoader | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.messages import AIMessage, HumanMessage | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
def run_conversation_retrieval_chain( | |
url: str, chat_history: list, chat_input: str | |
) -> dict: | |
llm = ChatOpenAI() | |
embeddings = OpenAIEmbeddings() | |
# Webページのテキストを読み込んで、文章を適度な大きさに分割 | |
loader = WebBaseLoader(url) | |
docs = loader.load() | |
text_splitter = RecursiveCharacterTextSplitter() | |
documents = text_splitter.split_documents(docs) | |
# 文章を特徴ベクトルに変換し、インメモリのベクトルストアに取り込み | |
vector = FAISS.from_documents(documents, embeddings) | |
retriever = vector.as_retriever() # 他のチェインから使えるように変換 | |
# 会話履歴を踏まえてベクトルストアの情報を検索するチェイン | |
retriever_chain_prompt = ChatPromptTemplate.from_messages( | |
[ | |
MessagesPlaceholder(variable_name="chat_history"), | |
("user", "{input}"), | |
( | |
"user", | |
"上記の会話から、会話に関連する情報を取得するための検索クエリを生成してください。", | |
), | |
] | |
) | |
retriever_chain = create_history_aware_retriever( | |
llm, retriever, retriever_chain_prompt | |
) | |
# 検索結果を基にユーザの質問に答えるチェイン | |
document_chain_prompt = ChatPromptTemplate.from_messages( | |
[ | |
( | |
"system", | |
"下記のコンテキストを基に、ユーザの質問に答えてください。t:\n\n{context}", | |
), | |
MessagesPlaceholder(variable_name="chat_history"), | |
("user", "{input}"), | |
] | |
) | |
document_chain = create_stuff_documents_chain(llm, document_chain_prompt) | |
# 検索 -> 回答のチェインを作成 | |
retrieval_chain = create_retrieval_chain(retriever_chain, document_chain) | |
# 問い合わせ | |
result = retrieval_chain.invoke({"chat_history": chat_history, "input": chat_input}) | |
return result | |
def main(): | |
url = "https://huggingface.co/docs/transformers/v4.38.2/en/perf_train_gpu_many" | |
chat_history = [ | |
HumanMessage(content="HuggingFaceの並列実行について教えてください。"), | |
AIMessage(content="はい。もちろん!"), | |
] | |
chat_input = "概要とどんな仕組みなのかを中学生にもわかるように教えてください。その後、専門家向けの解説もしてください。最後に、全ての手法を箇条書きで挙げて、それぞれの良い点と悪い点を示してください。" | |
result = run_conversation_retrieval_chain(url, chat_history, chat_input) | |
for k, v in result.items(): | |
print(f"{k}: {v}") | |
# answer: | |
# ### 中学生向けの説明: | |
# Hugging Faceの並列実行は、モデルのトレーニングを複数のGPUで同時に行うことです。これにより、大きなモデルをより早くトレーニングすることができます。 | |
# 複数のGPUを使用することで、モデルの重みを複数の部分に分割して処理し、それぞれの部分を同時に学習させることができます。 | |
# | |
# ### 専門家向けの解説: | |
# Hugging Faceの並列実行では、複数のGPUを使用してトレーニングを効率化します。主な並列化戦略には、データ並列化、テンソル並列化、 | |
# パイプライン並列化などがあります。これらの戦略を組み合わせることで、効果的なトレーニングが可能となります。 | |
# | |
# ### 手法の比較: | |
# | |
# - データ並列化: | |
# - 良い点:複数のGPUを使用してデータを並列に処理し、トレーニング時間を短縮できる。 | |
# - 悪い点:通信のオーバーヘッドが発生しやすく、大規模なモデルには適しているが、小規模なモデルでは効果が限定される。 | |
# | |
# - テンソル並列化: | |
# - 良い点:モデルの特定の部分を異なるGPUで処理し、メモリの使用効率を向上させる。 | |
# - 悪い点:複雑な実装が必要であり、効果がモデルの構造に依存する。 | |
# | |
# - パイプライン並列化: | |
# - 良い点:モデルを複数のステージに分割し、同時に処理することでトレーニング速度を向上させる。 | |
# - 悪い点:実装が複雑であり、適切なパイプラインの設計が求められる。 | |
# | |
# 以上が、Hugging Faceの並列実行における主な手法の比較とそれぞれの良い点と悪い点です。 | |
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment