Skip to content

Instantly share code, notes, and snippets.

@chezou
Created January 29, 2024 01:21
Show Gist options
  • Save chezou/2a6a58cd6bd21b001ca4d3c8d1318e14 to your computer and use it in GitHub Desktop.
Save chezou/2a6a58cd6bd21b001ca4d3c8d1318e14 to your computer and use it in GitHub Desktop.
RAG Chatbot using Confluence

confluence-rag

This is a sample application for RAG (retrieval augmented generation) with Confluence data.

It depends on rye for package management.

Since a gist cannot contain directories, create a confluence_rag directory yourself and move the Python files into it.

mkdir confluence_rag
mv confluence_rag-app.py confluence_rag/app.py
mv confluence_rag-qa.py confluence_rag/qa.py
mv confluence_rag-__init__.py confluence_rag/__init__.py

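Before launching, edit confluence_rag/app.py and replace the placeholder Confluence username and space key with your own. The Confluence base URL is hard-coded in confluence_rag/qa.py, so change it to your own site as well.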

Also, you have to set your Atlassian API token in the ATL_TOKEN environment variable.

You can launch a web app with the following command:

rye sync
export ATL_TOKEN="...."
rye run rag

Alternatively, you can launch it from a plain virtual environment without rye:

python -m venv .venv
source .venv/bin/activate
pip install -e .
python confluence_rag/app.py
#!/usr/bin/env python3
import os
import gradio as gr
from confluence_rag.qa import ConfluenceQA
# Atlassian API token taken from the environment; None when ATL_TOKEN is unset.
token = os.environ.get("ATL_TOKEN")

with gr.Blocks(title="Confluence Chatbot") as demo:
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    # Prepare the question-answering backend once, before any event fires.
    qa = ConfluenceQA(profile_name="rd")
    qa.load_confluence_documents(
        persist_directory="tmp/en",
        space_key="REPLACE_TO_YOUR_SPACE_NAME",
        username="REPLACE_TO_YOUR_EMAIL_ADDRESS",
        token=token,
    )
    qa.retrieval_qa_chain()

    def user(user_message, history):
        """Append the submitted message as an unanswered turn.

        Returns an empty string (the new textbox value) and the history
        extended by ``[user_message, None]``.
        """
        pending_turn = [user_message, None]
        return "", history + [pending_turn]

    def bot(history):
        """Answer the question of the last turn and store the reply in it."""
        latest_turn = history[-1]
        question = latest_turn[0]
        print("Question: ", question)
        bot_message = qa.answer_confluence(question)
        print("Response: ", bot_message)
        latest_turn[1] = ""
        latest_turn[1] += bot_message
        return history

    # First record the user's message, then let the bot fill in the reply.
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    submit_event.then(bot, chatbot, chatbot)

    # Returning None to the chatbot output empties the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue()
    demo.launch()
import os
from typing import Optional
import boto3
from langchain.chains import RetrievalQA
from langchain.document_loaders import ConfluenceLoader
from langchain.embeddings import BedrockEmbeddings
from langchain.llms.bedrock import Bedrock
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
class ConfluenceQA:
    """Retrieval-augmented question answering over a Confluence space.

    Intended call order:

    1. ``load_confluence_documents`` builds, or reopens, the Chroma store.
    2. ``retrieval_qa_chain`` assembles the RetrievalQA chain on top of it.
    3. ``answer_confluence`` answers individual questions.
    """

    # Default Confluence site; override through the ``url`` argument of
    # ``load_confluence_documents``.
    DEFAULT_CONFLUENCE_URL = "https://treasure-data.atlassian.net/wiki"

    def __init__(
        self,
        profile_name: str,
    ):
        """Create the Bedrock runtime client, the LLM and the embedding model.

        Args:
            profile_name: AWS profile used for the boto3 session. A falsy
                value makes boto3 use its default session.

        Raises:
            ValueError: If the AWS client cannot be created.
        """
        self.vectordb = None
        # Populated by retrieval_qa_chain(); kept as None until then so that
        # out-of-order calls can be reported clearly.
        self.retriever = None
        self.qa = None
        self.client = self._set_client(profile_name)
        self.llm = Bedrock(
            model_id="anthropic.claude-v2",
            client=self.client,
            model_kwargs={"max_tokens_to_sample": 1000},
        )
        self.embedding = BedrockEmbeddings(
            model_id="amazon.titan-embed-text-v1", client=self.client
        )

    def _set_client(
        self,
        profile_name: str,
        region_name: Optional[str] = None,
        endpoint_url: Optional[str] = None,
    ):
        """Create a ``bedrock-runtime`` client, store it on ``self`` and return it.

        Args:
            profile_name: AWS profile name; falsy means the default session.
            region_name: Optional region passed to ``session.client``.
            endpoint_url: Optional endpoint passed to ``session.client``.

        Raises:
            ValueError: If creating the session or the client fails for any
                reason; the original exception is chained as the cause.
        """
        try:
            if profile_name:
                session = boto3.Session(profile_name=profile_name)
            else:
                session = boto3.Session()
            # Only forward the optional settings that were actually given.
            client_params = {}
            if region_name:
                client_params["region_name"] = region_name
            if endpoint_url:
                client_params["endpoint_url"] = endpoint_url
            self.client = session.client("bedrock-runtime", **client_params)
            return self.client
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client."
            ) from e

    def load_confluence_documents(
        self,
        persist_directory: str,
        space_key: Optional[str] = None,
        username: Optional[str] = None,
        token: Optional[str] = None,
        max_pages: int = 2000,
        force_reload: bool = False,
        url: str = DEFAULT_CONFLUENCE_URL,
    ) -> None:
        """Populate ``self.vectordb`` from disk or from Confluence.

        When ``persist_directory`` already exists and ``force_reload`` is
        false, the persisted Chroma store is reopened. Otherwise the space is
        downloaded, split into chunks and embedded into a new store.

        Args:
            persist_directory: Directory of the Chroma store.
            space_key: Confluence space to load.
            username: Confluence account name used together with ``token``.
            token: Confluence API token.
            max_pages: Upper bound on the number of pages to load.
            force_reload: Rebuild from Confluence even if a store exists.
            url: Base URL of the Confluence wiki.

        Returns:
            None. The resulting store is available as ``self.vectordb``.
        """
        if persist_directory and os.path.exists(persist_directory) and not force_reload:
            self.vectordb = Chroma(
                persist_directory=persist_directory, embedding_function=self.embedding
            )
            return

        loader = ConfluenceLoader(
            url=url,
            username=username,
            api_key=token,
        )
        documents = loader.load(space_key=space_key, max_pages=max_pages)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
        )
        docs = text_splitter.split_documents(documents)
        self.vectordb = Chroma.from_documents(
            documents=docs,
            embedding=self.embedding,
            persist_directory=persist_directory,
        )

    def retrieval_qa_chain(self):
        """Build the retriever and the RetrievalQA chain from ``self.vectordb``.

        Raises:
            RuntimeError: If no vector store has been loaded yet.
        """
        if getattr(self, "vectordb", None) is None:
            raise RuntimeError(
                "No vector store loaded; call load_confluence_documents() first."
            )
        self.retriever = self.vectordb.as_retriever()
        self.qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True,
        )

    def answer_confluence(self, question: str) -> str:
        """Return the chain's answer text for ``question``.

        Raises:
            RuntimeError: If ``retrieval_qa_chain`` has not been called yet.
        """
        chain = getattr(self, "qa", None)
        if chain is None:
            raise RuntimeError(
                "QA chain not built; call retrieval_qa_chain() first."
            )
        # The chain is created with return_source_documents=True, so it has
        # two output keys. Chain.run() only supports single-output chains and
        # raises ValueError here, so call the chain and pick the answer text.
        response = chain({"query": question})
        return response["result"]
[project]
name = "confluence-rag"
version = "0.1.0"
description = "RAG based chat app based on confluence"
authors = [
{ name = "Aki Ariga", email = "[email protected]" }
]
dependencies = [
"langchain>=0.0.27",
"atlassian-python-api>=3.41.3",
"beautifulsoup4>=4.12.2",
"lxml>=4.9.3",
"boto3>=1.28.79",
"chromadb>=0.4.15",
"gradio>=4.1.1",
]
readme = "README.md"
requires-python = ">= 3.8"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"ruff>=0.1.4",
]
[tool.rye.scripts]
rag = { cmd = ["python", "confluence_rag/app.py"] }
[tool.hatch.metadata]
allow-direct-references = true
[tool.ruff.lint]
# Enable the isort rules.
extend-select = ["I"]
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
-e file:.
aiofiles==23.2.1
aiohttp==3.8.6
aiosignal==1.3.1
altair==5.1.2
annotated-types==0.6.0
anyio==3.7.1
async-timeout==4.0.3
atlassian-python-api==3.41.3
attrs==23.1.0
backoff==2.2.1
bcrypt==4.0.1
beautifulsoup4==4.12.2
boto3==1.28.79
botocore==1.31.79
cachetools==5.3.2
certifi==2023.7.22
charset-normalizer==3.3.2
chroma-hnswlib==0.7.3
chromadb==0.4.15
click==8.1.7
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.2.0
cycler==0.12.1
dataclasses-json==0.6.1
deprecated==1.2.14
fastapi==0.104.1
ffmpy==0.3.1
filelock==3.13.1
flatbuffers==23.5.26
fonttools==4.44.0
frozenlist==1.4.0
fsspec==2023.10.0
google-auth==2.23.4
googleapis-common-protos==1.61.0
gradio==4.1.1
gradio-client==0.7.0
grpcio==1.59.2
h11==0.14.0
httpcore==1.0.1
httptools==0.6.1
httpx==0.25.1
huggingface-hub==0.17.3
humanfriendly==10.0
idna==3.4
importlib-metadata==6.8.0
importlib-resources==6.1.1
jinja2==3.1.2
jmespath==1.0.1
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
kiwisolver==1.4.5
kubernetes==28.1.0
langchain==0.0.331
langsmith==0.0.59
lxml==4.9.3
markdown-it-py==3.0.0
markupsafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.1
mdurl==0.1.2
monotonic==1.6
mpmath==1.3.0
multidict==6.0.4
mypy-extensions==1.0.0
numpy==1.26.1
oauthlib==3.2.2
onnxruntime==1.16.1
opentelemetry-api==1.20.0
opentelemetry-exporter-otlp-proto-common==1.20.0
opentelemetry-exporter-otlp-proto-grpc==1.20.0
opentelemetry-proto==1.20.0
opentelemetry-sdk==1.20.0
opentelemetry-semantic-conventions==0.41b0
orjson==3.9.10
overrides==7.4.0
packaging==23.2
pandas==2.1.2
pillow==10.1.0
posthog==3.0.2
protobuf==4.25.0
pulsar-client==3.3.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pydantic==2.4.2
pydantic-core==2.10.1
pydub==0.25.1
pygments==2.16.1
pyparsing==3.1.1
pypika==0.48.9
python-dateutil==2.8.2
python-dotenv==1.0.0
python-multipart==0.0.6
pytz==2023.3.post1
pyyaml==6.0.1
referencing==0.30.2
requests==2.31.0
requests-oauthlib==1.3.1
rich==13.6.0
rpds-py==0.12.0
rsa==4.9
ruff==0.1.4
s3transfer==0.7.0
semantic-version==2.10.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
sqlalchemy==2.0.23
starlette==0.27.0
sympy==1.12
tenacity==8.2.3
tokenizers==0.14.1
tomlkit==0.12.0
toolz==0.12.0
tqdm==4.66.1
typer==0.9.0
typing-extensions==4.8.0
typing-inspect==0.9.0
tzdata==2023.3
urllib3==1.26.18
uvicorn==0.24.0.post1
uvloop==0.19.0
watchfiles==0.21.0
websocket-client==1.6.4
websockets==11.0.3
wrapt==1.15.0
yarl==1.9.2
zipp==3.17.0
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
-e file:.
aiofiles==23.2.1
aiohttp==3.8.6
aiosignal==1.3.1
altair==5.1.2
annotated-types==0.6.0
anyio==3.7.1
async-timeout==4.0.3
atlassian-python-api==3.41.3
attrs==23.1.0
backoff==2.2.1
bcrypt==4.0.1
beautifulsoup4==4.12.2
boto3==1.28.79
botocore==1.31.79
cachetools==5.3.2
certifi==2023.7.22
charset-normalizer==3.3.2
chroma-hnswlib==0.7.3
chromadb==0.4.15
click==8.1.7
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.2.0
cycler==0.12.1
dataclasses-json==0.6.1
deprecated==1.2.14
fastapi==0.104.1
ffmpy==0.3.1
filelock==3.13.1
flatbuffers==23.5.26
fonttools==4.44.0
frozenlist==1.4.0
fsspec==2023.10.0
google-auth==2.23.4
googleapis-common-protos==1.61.0
gradio==4.1.1
gradio-client==0.7.0
grpcio==1.59.2
h11==0.14.0
httpcore==1.0.1
httptools==0.6.1
httpx==0.25.1
huggingface-hub==0.17.3
humanfriendly==10.0
idna==3.4
importlib-metadata==6.8.0
importlib-resources==6.1.1
jinja2==3.1.2
jmespath==1.0.1
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
kiwisolver==1.4.5
kubernetes==28.1.0
langchain==0.0.331
langsmith==0.0.59
lxml==4.9.3
markdown-it-py==3.0.0
markupsafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.1
mdurl==0.1.2
monotonic==1.6
mpmath==1.3.0
multidict==6.0.4
mypy-extensions==1.0.0
numpy==1.26.1
oauthlib==3.2.2
onnxruntime==1.16.1
opentelemetry-api==1.20.0
opentelemetry-exporter-otlp-proto-common==1.20.0
opentelemetry-exporter-otlp-proto-grpc==1.20.0
opentelemetry-proto==1.20.0
opentelemetry-sdk==1.20.0
opentelemetry-semantic-conventions==0.41b0
orjson==3.9.10
overrides==7.4.0
packaging==23.2
pandas==2.1.2
pillow==10.1.0
posthog==3.0.2
protobuf==4.25.0
pulsar-client==3.3.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pydantic==2.4.2
pydantic-core==2.10.1
pydub==0.25.1
pygments==2.16.1
pyparsing==3.1.1
pypika==0.48.9
python-dateutil==2.8.2
python-dotenv==1.0.0
python-multipart==0.0.6
pytz==2023.3.post1
pyyaml==6.0.1
referencing==0.30.2
requests==2.31.0
requests-oauthlib==1.3.1
rich==13.6.0
rpds-py==0.12.0
rsa==4.9
s3transfer==0.7.0
semantic-version==2.10.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
sqlalchemy==2.0.23
starlette==0.27.0
sympy==1.12
tenacity==8.2.3
tokenizers==0.14.1
tomlkit==0.12.0
toolz==0.12.0
tqdm==4.66.1
typer==0.9.0
typing-extensions==4.8.0
typing-inspect==0.9.0
tzdata==2023.3
urllib3==1.26.18
uvicorn==0.24.0.post1
uvloop==0.19.0
watchfiles==0.21.0
websocket-client==1.6.4
websockets==11.0.3
wrapt==1.15.0
yarl==1.9.2
zipp==3.17.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment