Skip to content

Instantly share code, notes, and snippets.

@Princekrampah
Created September 1, 2024 07:40
Show Gist options
  • Save Princekrampah/fca985086a34462444810f2385d3d6de to your computer and use it in GitHub Desktop.
Save Princekrampah/fca985086a34462444810f2385d3d6de to your computer and use it in GitHub Desktop.
Main ETL code for GraphRAG Project Series
import os
import logging
from retry import retry
from dotenv import load_dotenv
from neo4j import GraphDatabase
import pandas as pd
# Load environment variables
load_dotenv()
# file paths
CATEGORY_CSV_FILEPATH = os.getenv("CATEGORY_CSV_FILE_PATH")
PRODUCT_CSV_FILE_PATH = os.getenv("PRODUCT_CSV_FILE_PATH")
SUPPLIER_CSV_FILE_PATH = os.getenv("SUPPLIER_CSV_FILE_PATH")
ORDER_CSV_FILE_PATH = os.getenv("ORDER_CSV_FILE_PATH")
ORDER_DETAILS_CSV_FILE_PATH = os.getenv("ORDER_DETAILS_CSV_FILE_PATH")
SHIPPER_CSV_FILE_PATH = os.getenv("SHIPPERS_CSV_FILE_PATH")
EMPLOYEE_CSV_FILE_PATH = os.getenv("EMPLOYEE_CSV_FILE_PATH")
CUSTOMER_CSV_FILE_PATH = os.getenv("CUSTOMER_CSV_FILE_PATH")
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s]: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
LOGGER = logging.getLogger(__name__)
NODES = ["Product", "Category", "Supplier",
"Order", "Employee", "Customer", "Shipper"]
@retry(tries=100, delay=10)
def set_uniqueness_constraints(tx, node):
"""
Creates a uniqueness constraint for a specified node label in Neo4j.
This function defines a Cypher query to create a uniqueness constraint for
the specified node label. The constraint ensures that the `id` property of
nodes with the given label is unique across the database.
The function uses the `CREATE CONSTRAINT IF NOT EXISTS` Cypher statement to
ensure that the constraint is only created if it does not already exist.
Args:
tx (neo4j.Transaction): The Neo4j transaction object used to execute the query.
node (str): The label of the node for which to create the uniqueness constraint.
"""
query = f"""CREATE CONSTRAINT IF NOT EXISTS FOR (n:{node})
REQUIRE n.id IS UNIQUE;"""
_ = tx.run(query, {})
@retry(tries=100, delay=10)
def create_uniqueness_constraints():
"""
Connects to the Neo4j database and creates uniqueness constraints on specified nodes.
This function establishes a connection to the Neo4j database using the provided
URI and authentication credentials. It then iterates over a list of node labels
and applies uniqueness constraints to each node type by calling the `set_uniqueness_constraints` function.
The uniqueness constraints ensure that the specified properties for each node type
are unique across the database, preventing duplicate entries.
This function retries up to 100 times with a delay of 10 seconds between attempts
if it encounters connection issues or other transient failures.
"""
LOGGER.info("Connecting to Neo4j")
driver = GraphDatabase.driver(
NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
)
LOGGER.info("Creating uniqueness constraints")
with driver.session() as session:
for node in NODES:
session.execute_write(set_uniqueness_constraints, node)
@retry(tries=100, delay=10)
def create_manager(tx, row):
"""
Creates or updates an Employee node in the Neo4j database.
This function uses the Cypher MERGE statement to ensure that an Employee node
with the specified properties is created. If an Employee node with the given
employeeID already exists, it will be updated with the provided properties.
"""
tx.run("""
MERGE (e:Employee {
employeeID: $employeeID,
lastName: $lastName,
firstName: $firstName,
title: $title,
titleOfCourtesy: $titleOfCourtesy,
birthDate: $birthDate,
hireDate: $hireDate,
address: $address_y,
city: $city_y,
region: $region_y,
postalCode: $postalCode_y,
country: $country_y,
homePhone: $homePhone,
extension: $extension,
photo: $photo,
notes: $notes,
photoPath: $photoPath
})
""", row)
def process_product_category_supplier_csv(
product_file_path: str,
category_file_path: str,
supplier_file_path: str,
) -> pd.DataFrame:
"""
Reads and processes CSV files containing product, category, and supplier data, merging them into a single DataFrame.
This function reads data from three CSV files (product, category, and supplier), merges the data based on the
'categoryID' and 'supplierID' columns, and cleans the resulting DataFrame by replacing missing values in specific
columns with 'Unknown'.
:param product_file_path: The file path to the CSV file containing product data.
:param category_file_path: The file path to the CSV file containing category data.
:param supplier_file_path: The file path to the CSV file containing supplier data.
:return: A pandas DataFrame containing the merged product, category, and supplier data.
"""
try:
LOGGER.info(f"Reading data from {product_file_path}")
product_df = pd.read_csv(product_file_path)
LOGGER.info(f"Reading data from {category_file_path}")
category_df = pd.read_csv(category_file_path)
LOGGER.info("Merging product and category data")
product_category_df = pd.merge(
product_df, category_df, on='categoryID')
LOGGER.info(f"Reading data from {supplier_file_path}")
supplier_df = pd.read_csv(supplier_file_path)
LOGGER.info("Merging product, category and supplier data")
product_category_supplier_df = pd.merge(
product_category_df, supplier_df, on='supplierID', how='left')
LOGGER.info("Cleaning data, replacing NA values with Unknown")
product_category_supplier_df["region"] = product_category_supplier_df["region"].replace({
pd.NA: "Unknown"})
product_category_supplier_df["fax"] = product_category_supplier_df["fax"].replace({
pd.NA: "Unknown"})
product_category_supplier_df["homePage"] = product_category_supplier_df["homePage"].replace({
pd.NA: "Unknown"})
return product_category_supplier_df
except Exception as e:
LOGGER.error(f"Error reading CSV data: {e}")
def insert_data(tx, row):
"""
Inserts product, category, and supplier data into a Neo4j graph database.
This function creates a product node, merges category and supplier nodes, and
establishes relationships between the product and its category and supplier in
the Neo4j graph. The data is passed as a dictionary in the `row` parameter.
:param tx: The transaction object used to execute the Cypher queries in the Neo4j database.
:param row: A dictionary containing the product, category, and supplier data to be inserted.
"""
tx.run('''
CREATE (product:Product {
productID: $productID,
productName: $productName,
supplierID: $supplierID,
categoryID: $categoryID,
quantityPerUnit: $quantityPerUnit,
unitPrice: $unitPrice,
unitsInStock: $unitsInStock,
unitsOnOrder: $unitsOnOrder,
reorderLevel: $reorderLevel,
discontinued: $discontinued
})
MERGE (category:Category {
categoryID: $categoryID,
categoryName: $categoryName,
description: $description,
picture: $picture
})
MERGE (supplier:Supplier {
supplierID: $supplierID,
companyName: $companyName,
contactName: $contactName,
contactTitle: $contactTitle,
address: $address,
city: $city,
region: $region,
postalCode: $postalCode,
country: $country,
phone: $phone,
fax: $fax,
homePage: $homePage
})
CREATE (product)-[:PART_OF]->(category)
CREATE (product)-[:SUPPLIED_BY]->(supplier)
''', row)
@retry(tries=100, delay=10)
def load_product_category_supply_into_graph(
product_category_supplier_df: pd.DataFrame,
):
"""
Loads product, category, and supplier data into a Neo4j graph database.
This function connects to a Neo4j database and inserts data from a pandas DataFrame
that contains product, category, and supplier information. The data is inserted
into the graph using a write transaction. The function uses retry logic to ensure
successful data insertion, retrying up to 100 times with a 10-second delay between attempts.
:param product_category_supplier_df: A pandas DataFrame containing the product, category,
and supplier data to be inserted into the graph.
"""
LOGGER.info("Connecting to Neo4j")
driver = GraphDatabase.driver(
NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
)
LOGGER.info("Inserting data into Neo4j")
with driver.session() as session:
for _, row in product_category_supplier_df.iterrows():
session.execute_write(insert_data, row.to_dict())
LOGGER.info("Data inserted into Neo4j")
driver.close()
def process_order_order_details_product_shipper_employee_customer_csv(
order_file_path: str,
order_details_file_path: str,
customer_file_path: str,
shipper_file_path: str,
employee_file_path: str,
) -> pd.DataFrame:
"""
Processes and merges data from multiple CSV files containing order, order details,
customer, shipper, and employee information, and prepares it for further analysis.
This function reads the provided CSV files, merges them into a single DataFrame,
and cleans the data by replacing missing values. The final DataFrame is returned
for use in other operations, such as inserting the data into a Neo4j database.
:param order_file_path: The file path to the orders CSV file.
:param order_details_file_path: The file path to the order details CSV file.
:param customer_file_path: The file path to the customers CSV file.
:param shipper_file_path: The file path to the shippers CSV file.
:param employee_file_path: The file path to the employees CSV file.
:return: A pandas DataFrame containing the merged and cleaned data.
"""
LOGGER.info(f"Creating")
driver = GraphDatabase.driver(
NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
)
LOGGER.info(f"Reading data...")
orders_df = pd.read_csv(order_file_path)
order_details_df = pd.read_csv(order_details_file_path)
customer_df = pd.read_csv(customer_file_path)
shipper_df = pd.read_csv(shipper_file_path)
employee_df = pd.read_csv(employee_file_path)
LOGGER.info("Merging order and order details data...")
orders_order_details_df = pd.merge(
orders_df,
order_details_df,
on='orderID',
how='left'
)
LOGGER.info("Merging order and order details data with customer data...")
orders_order_details_customer_df = pd.merge(
orders_order_details_df,
customer_df,
on='customerID',
how='left'
)
LOGGER.info(
"Merging order and order details data with customer data and shipper data...")
orders_order_details_customer_shipper_df = pd.merge(
orders_order_details_customer_df,
shipper_df,
left_on='shipVia',
right_on="shipperID",
how='left'
)
LOGGER.info(
"Merging order and order details data with customer data, shipper data and employee data...")
orders_order_details_customer_shipper_employee_df = pd.merge(
orders_order_details_customer_shipper_df,
employee_df,
left_on='employeeID',
right_on='employeeID',
how='left'
)
LOGGER.info("Cleaning data...")
orders_order_details_customer_shipper_employee_df.replace(
{pd.NA: "Unknown"}, inplace=True)
# Change to integer
orders_order_details_customer_shipper_employee_df["reportsTo"] = orders_order_details_customer_shipper_employee_df["reportsTo"].astype(
'Int64')
orders_order_details_customer_shipper_employee_df["reportsTo"]
# Replace missing values
orders_order_details_customer_shipper_employee_df["reportsTo"] = orders_order_details_customer_shipper_employee_df["reportsTo"].replace({
pd.NA: 2})
return orders_order_details_customer_shipper_employee_df
def insert_manager_record(
orders_order_details_customer_shipper_employee_df: pd.DataFrame
) -> None:
"""
Inserts records of employees with the title 'Vice President' into the Neo4j database.
This function connects to a Neo4j database and creates nodes for employees who have
the title 'Vice President'. It filters the provided DataFrame to find these records
and then inserts them into the database using a Neo4j transaction.
:param orders_order_details_customer_shipper_employee_df: A pandas DataFrame containing
employee records, including titles and other relevant information.
"""
LOGGER.info("Connecting to Neo4j")
driver = GraphDatabase.driver(
NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
)
LOGGER.info("Creating vice president node...")
vice_president = orders_order_details_customer_shipper_employee_df[
orders_order_details_customer_shipper_employee_df["title"] == "Vice President"]
LOGGER.info("Inserting data into Neo4j")
with driver.session() as session:
for _, row in vice_president.iterrows():
session.write_transaction(create_manager, row.to_dict())
LOGGER.info("Vice president node created")
driver.close()
def order_order_details_shippers_employees_and_customer_data_ingester(tx, row):
"""
Ingests order-related data into the Neo4j database, creating and merging nodes and relationships
for orders, products, customers, shippers, and employees.
This function handles the creation of nodes for Orders, Customers, Shippers, and Employees,
and establishes relationships between them based on the provided data. It also links employees
to their managers if applicable.
:param tx: The Neo4j transaction context to execute the query.
:param row: A dictionary containing the data for the order, product, customer, shipper,
and employee nodes, including relationships such as manager reporting.
"""
tx.run("""
CREATE (o:Order {
orderID: $orderID,
orderDate: $orderDate,
requiredDate: $requiredDate,
shippedDate: $shippedDate,
shipVia: $shipVia,
freight: $freight,
shipName: $shipName,
shipAddress: $shipAddress,
shipCity: $shipCity,
shipRegion: $shipRegion,
shipPostalCode: $shipPostalCode,
shipCountry: $shipCountry
})
WITH o
MATCH (p:Product { productID: $productID })
WITH p, o
MERGE (c:Customer {
customerID: $customerID,
companyName: $companyName_x,
contactName: $contactName,
contactTitle: $contactTitle,
address: $address_x,
city: $city_x,
region: $region_x,
postalCode: $postalCode_x,
country: $country_x,
phone: $phone_x,
fax: $fax
})
WITH c, p, o
MERGE (s:Shipper {
shipperID: $shipperID,
companyName: $companyName_y,
phone: $phone_y
})
WITH s, c, p, o
MERGE (e:Employee {
employeeID: $employeeID,
lastName: $lastName,
firstName: $firstName,
title: $title,
titleOfCourtesy: $titleOfCourtesy,
birthDate: $birthDate,
hireDate: $hireDate,
address: $address_y,
city: $city_y,
region: $region_y,
postalCode: $postalCode_y,
country: $country_y,
homePhone: $homePhone,
extension: $extension,
photo: $photo,
notes: $notes,
photoPath: $photoPath
})
WITH e, s, c, p, o
MATCH (m:Employee { employeeID: $reportsTo }) // Assuming reportsTo is the ID of the manager
WITH m, e, s, c, p, o
MERGE (e)-[:REPORTS_TO]->(m)
MERGE (o)-[:INCLUDES]->(p)
MERGE (o)-[:ORDERED_BY]->(c)
MERGE (o)-[:SHIPPED_BY]->(s)
MERGE (o)-[:PROCESSED_BY]->(e)
""", parameters=row)
@retry(tries=100, delay=10)
def load_order_order_details_shippers_employees_and_customer_data_into_graph(
orders_order_details_customer_shipper_employee_df: pd.DataFrame,
):
"""
Load order, order details, shippers, employees and customer data into Neo4j.
"""
LOGGER.info("Connecting to Neo4j")
driver = GraphDatabase.driver(
NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
)
LOGGER.info("Inserting data into Neo4j")
with driver.session() as session:
for _, row in orders_order_details_customer_shipper_employee_df.iterrows():
session.execute_write(
order_order_details_shippers_employees_and_customer_data_ingester, row.to_dict())
LOGGER.info("Data inserted into Neo4j")
driver.close()
def main():
LOGGER.info(
"\n\n+++++++++++++++++ Starting ETL process +++++++++++++++++\n\n")
create_uniqueness_constraints()
product_category_supplier_df = process_product_category_supplier_csv(
PRODUCT_CSV_FILE_PATH, CATEGORY_CSV_FILEPATH, SUPPLIER_CSV_FILE_PATH
)
load_product_category_supply_into_graph(product_category_supplier_df)
orders_order_details_customer_shipper_employee_df = process_order_order_details_product_shipper_employee_customer_csv(
ORDER_CSV_FILE_PATH,
ORDER_DETAILS_CSV_FILE_PATH,
CUSTOMER_CSV_FILE_PATH,
SHIPPER_CSV_FILE_PATH,
EMPLOYEE_CSV_FILE_PATH
)
insert_manager_record(orders_order_details_customer_shipper_employee_df)
load_order_order_details_shippers_employees_and_customer_data_into_graph(
orders_order_details_customer_shipper_employee_df[:250])
LOGGER.info(
"\n\n+++++++++++++++++ ETL process completed +++++++++++++++++\n\n")
if __name__ == "__main__":
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment