@angusdev
Last active August 18, 2024 09:56
from typing import List, Dict, Any, Tuple, Optional, Union
from sqlalchemy import create_engine, text, Table, MetaData
from sqlalchemy.engine import Engine, ResultProxy
import pandas as pd
from pyspark.sql import SparkSession, DataFrame as SparkDataFrame
# Base class DataAccess using SQLAlchemy
class DataAccess:
def __init__(self, conn_str: str):
self.engine: Engine = create_engine(conn_str)
def read_dict(self, query: str) -> List[Dict[str, Any]]:
with self.engine.connect() as conn:
result: ResultProxy = conn.execute(text(query))
return [dict(row) for row in result.fetchall()]
def read_pandas(self, query: str) -> pd.DataFrame:
with self.engine.connect() as conn:
return pd.read_sql(query, conn)
    def write_dict(self, data: List[Dict[str, Any]], table_name: str) -> None:
        # Reflect the target table so its insert() construct can be used
        table = Table(table_name, MetaData(), autoload_with=self.engine)
        with self.engine.connect() as conn:
            conn.execute(table.insert(), data)
def write_pandas(self, df: pd.DataFrame, table_name: str) -> None:
df.to_sql(table_name, self.engine, if_exists='append', index=False)
    def execute_query(self, query: str) -> Union[int, List[Dict[str, Any]]]:
        """
        Execute an INSERT, UPDATE, DELETE, or SELECT query.
        Returns the result of a SELECT query as a list of dictionaries,
        otherwise the number of rows affected.
        """
        with self.engine.connect() as conn:
            result: ResultProxy = conn.execute(text(query))
            if query.strip().lower().startswith("select"):
                return [dict(row) for row in result.fetchall()]
            return result.rowcount
def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
raise NotImplementedError("get_columns method must be implemented by subclasses")
# Subclass PostgresqlDataAccess
class PostgresqlDataAccess(DataAccess):
def __init__(self, conn_str: str):
super().__init__(conn_str)
    def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
        # Use a bound parameter instead of string interpolation to avoid SQL injection
        query = """
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_name = :table_name
        """
        with self.engine.connect() as conn:
            result = conn.execute(text(query), {"table_name": table_name})
            return [(row["column_name"], row["data_type"]) for row in result.fetchall()]
# Subclass SparkDataAccess
class SparkDataAccess:
def __init__(self, spark: SparkSession):
self.spark = spark
def read_spark(self, table_name: Optional[str] = None, query: Optional[str] = None) -> SparkDataFrame:
if (table_name is None and query is None) or (table_name is not None and query is not None):
raise ValueError("You must provide either a table name or a query, but not both.")
if table_name:
return self.spark.read.table(table_name)
if query:
return self.spark.sql(query)
def write_spark(self, table_name: str, df: SparkDataFrame, mode: str = "append") -> None:
df.write.mode(mode).saveAsTable(table_name)
# Subclass JdbcSparkDataAccess
class JdbcSparkDataAccess(SparkDataAccess):
def __init__(self, spark: SparkSession, jdbc_url: str, conn_props: Dict[str, Any]):
super().__init__(spark)
self.jdbc_url = jdbc_url
self.conn_props = conn_props
def read_spark(self, table_name: Optional[str] = None, query: Optional[str] = None) -> SparkDataFrame:
if (table_name is None and query is None) or (table_name is not None and query is not None):
raise ValueError("You must provide either a table name or a query, but not both.")
if table_name:
return self.spark.read.jdbc(self.jdbc_url, table_name, properties=self.conn_props)
if query:
return self.spark.sql(query)
def write_spark(self, table_name: str, df: SparkDataFrame, mode: str = "append") -> None:
df.write.jdbc(self.jdbc_url, table_name, mode=mode, properties=self.conn_props)
# Subclass PostgresqlSparkDataAccess
class PostgresqlSparkDataAccess(JdbcSparkDataAccess):
    def __init__(self, spark: SparkSession, jdbc_url: str, conn_props: Dict[str, Any]):
        # Copy the properties and set the PostgreSQL JDBC driver class name,
        # so the caller's dictionary is not mutated
        conn_props = {**conn_props, "driver": "org.postgresql.Driver"}
        super().__init__(spark, jdbc_url, conn_props)
# Example Usage:
if __name__ == "__main__":
# Example usage for SQLAlchemy-based operations
conn_str = "postgresql://username:password@host:port/database"
pg_access = PostgresqlDataAccess(conn_str)
# Example: Reading data as a dictionary
query = "SELECT * FROM users"
result_dict = pg_access.read_dict(query)
print(result_dict)
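    # Illustrative additions (assumptions: the "users_backup" table and the "active"
    # column are placeholders, not part of the original gist).
    # Read the same query as a Pandas DataFrame and write it to another table
    result_pandas = pg_access.read_pandas(query)
    pg_access.write_pandas(result_pandas, "users_backup")
    # Inspect column names and types via the PostgreSQL information schema
    print(pg_access.get_columns("users"))
    # Run a DML statement; execute_query returns the number of affected rows
    print(pg_access.execute_query("UPDATE users SET active = TRUE WHERE id = 1"))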
# Example usage for Spark-based operations
spark = SparkSession.builder.appName("example").getOrCreate()
jdbc_url = "jdbc:postgresql://host:port/database"
conn_props = {"user": "your_username", "password": "your_password"}
# Create an instance of PostgresqlSparkDataAccess
postgres_spark_access = PostgresqlSparkDataAccess(spark, jdbc_url, conn_props)
# Example: Read from Spark table or SQL query
df_table = postgres_spark_access.read_spark(table_name="my_table")
df_query = postgres_spark_access.read_spark(query="SELECT * FROM my_table WHERE id > 100")
# Example: Write to a Spark table
postgres_spark_access.write_spark("my_table", df_query, mode="append")
from typing import List, Tuple
from pyspark.sql import DataFrame

class DeltaSparkDataAccess(SparkDataAccess):
def __init__(self, spark):
"""
Initialize with a SparkSession object.
"""
super().__init__(spark)
def read_spark(self, table_name: str = None, sql: str = None, **options) -> DataFrame:
"""
Read data from a Delta table or execute a SQL query.
:param table_name: The path or table name of the Delta table (default None)
:param sql: The SQL query to read data from (default None)
:param options: Additional options for reading the data
:return: Spark DataFrame
"""
return super().read_spark(table_name=table_name, sql=sql, **options)
def write_spark(self, df: DataFrame, destination: str, mode: str = "overwrite", **options):
"""
Write data to a Delta table.
:param df: Spark DataFrame to be written
:param destination: The path or table name of the Delta table
:param mode: Write mode (e.g., 'overwrite', 'append')
:param options: Additional options for the write method
"""
super().write_spark(df, destination, format="delta", mode=mode, **options)
def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
"""
Get the column names and their types for a given Delta table.
:param table_name: The path or table name of the Delta table
:return: List of tuples containing column names and types
"""
# Load the Delta table and get the schema
df = self.read_spark(table_name=table_name)
return [(field.name, field.dataType.simpleString()) for field in df.schema.fields]
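# A minimal sketch (assumption: the Delta Lake package is available in the Spark session)
# of reading a Delta table directly by filesystem path rather than by catalog name, since
# the base read_spark resolves table_name through the catalog via spark.read.table().
# The path below is a placeholder.
def read_delta_path_example(spark, path: str = "/path/to/delta_table"):
    # Load the Delta table from storage and return it as a DataFrame
    return spark.read.format("delta").load(path)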
class SparkJdbcDataAccess(SparkDataAccess):
def __init__(self, spark, jdbc_url: str, db_table: str, user: str, password: str, driver: str):
"""
Initialize with a SparkSession object and JDBC connection details.
:param spark: SparkSession object
:param jdbc_url: JDBC URL for the database connection
:param db_table: The table name in the database
:param user: Database username
:param password: Database password
:param driver: The JDBC driver class name
"""
super().__init__(spark)
self.jdbc_url = jdbc_url
self.db_table = db_table
self.user = user
self.password = password
self.driver = driver
    def read_spark(self, table_name: str = None, sql: str = None, **options) -> DataFrame:
        """
        Read data from a database table or a SQL query over JDBC.
        :param table_name: The table name in the database (default None)
        :param sql: The SQL query to read data from (default None)
        :param options: Additional options for the read method
        :return: Spark DataFrame
        """
        if sql is not None:
            # Push the SQL query down to the database via the JDBC "query" option
            options.update({
                "url": self.jdbc_url,
                "query": sql,
                "user": self.user,
                "password": self.password,
                "driver": self.driver  # JDBC driver class name
            })
        else:
            # Read the whole table via the JDBC "dbtable" option
            options.update({
                "url": self.jdbc_url,
                "dbtable": table_name,
                "user": self.user,
                "password": self.password,
                "driver": self.driver  # JDBC driver class name
            })
        return self.spark.read.format("jdbc").options(**options).load()
def write_spark(self, df: DataFrame, mode: str = "overwrite", **options):
"""
Write data to a database table using JDBC.
:param df: Spark DataFrame to be written
:param mode: Write mode (e.g., 'overwrite', 'append')
:param options: Additional options for the write method
"""
options.update({
"url": self.jdbc_url,
"dbtable": self.db_table,
"user": self.user,
"password": self.password,
"driver": self.driver # Specify the JDBC driver class name
})
super().write_spark(df, self.db_table, format="jdbc", mode=mode, **options)
def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
"""
Get the column names and their types for a given JDBC table.
This is a placeholder to be implemented in subclasses for specific databases.
:param table_name: The table name
:return: List of tuples containing column names and types
"""
raise NotImplementedError("This method should be implemented in subclasses.")
class PostgresqlSparkDataAccess(SparkJdbcDataAccess):
def __init__(self, spark, jdbc_url: str, db_table: str, user: str, password: str):
"""
Initialize with a SparkSession object and PostgreSQL connection details.
This class specifies the PostgreSQL JDBC driver class name.
:param spark: SparkSession object
:param jdbc_url: JDBC URL for PostgreSQL connection
:param db_table: The table name in the PostgreSQL database
:param user: Database username
:param password: Database password
"""
# Specify the PostgreSQL driver class name
driver = "org.postgresql.Driver"
super().__init__(spark, jdbc_url, db_table, user, password, driver)
def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
"""
Get the column names and their types for a given PostgreSQL table.
:param table_name: The table name
:return: List of tuples containing column names and types
"""
query = f"""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_name = '{table_name}'
"""
df = self.read_spark(sql=query)
# Extract columns and types
return [(row['column_name'], row['data_type']) for row in df.collect()]
from pyspark.sql import DataFrame
from typing import List, Tuple
class SparkDataAccess:
def __init__(self, spark):
"""
Initialize with a SparkSession object.
:param spark: SparkSession object
"""
self.spark = spark
def read_spark(self, table_name: str = None, sql: str = None, **options) -> DataFrame:
"""
Read data from a Spark data source, which can be a table or a SQL query.
:param table_name: The table name or path to read data from (default None)
:param sql: The SQL query to read data from (default None)
:param options: Additional options for reading the data
:return: Spark DataFrame
:raises ValueError: If both or neither of table_name and sql are provided
"""
if (table_name is not None and sql is not None) or (table_name is None and sql is None):
raise ValueError("Exactly one of table_name or sql must be provided.")
if sql is not None:
# Execute SQL query and return DataFrame
df = self.spark.sql(sql)
else:
# Treat table_name as a table name or file path and read data
df = self.spark.read.options(**options).table(table_name)
return df
def write_spark(self, df: DataFrame, destination: str, format: str, mode: str = "overwrite", **options):
"""
Write data to a Spark data source.
:param df: Spark DataFrame to be written
:param destination: The path or table name where the data will be written
:param format: The format of the data (e.g., 'delta', 'parquet', etc.)
:param mode: Write mode (e.g., 'overwrite', 'append')
:param options: Additional options for the write method
"""
df.write.format(format).mode(mode).options(**options).save(destination)
    def execute_query(self, query: str) -> DataFrame:
        """
        Execute a Spark SQL statement. This can be a SELECT, UPDATE, DELETE, or INSERT.
        :param query: The SQL query string
        :return: The DataFrame produced by spark.sql(). For SELECT statements it holds the
                 query result; for DML statements Spark returns a (usually small) DataFrame
                 whose contents depend on the underlying source (e.g. Delta Lake reports a
                 num_affected_rows column). There is no portable affected-row count.
        """
        return self.spark.sql(query)
def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
"""
Get the column names and their types for a given table name.
This is a placeholder method to be implemented in subclasses.
:param table_name: The table name or path
:return: List of tuples containing column names and types
"""
raise NotImplementedError("This method should be implemented in subclasses.")
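# A minimal usage sketch for the base class above (assumptions: the "events" table and
# the /tmp output path are placeholders, not part of the original gist).
def spark_data_access_example(spark):
    access = SparkDataAccess(spark)
    # Exactly one of sql or table_name must be supplied per call
    df_sql = access.read_spark(sql="SELECT * FROM events WHERE year = 2024")
    df_tbl = access.read_spark(table_name="events")
    # Write the SQL result out as Parquet
    access.write_spark(df_sql, "/tmp/events_2024", format="parquet", mode="overwrite")
    return df_tbl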
from typing import List, Dict, Any
import pandas as pd
from pyspark.sql import SparkSession, DataFrame as SparkDataFrame
# Subclass SparkDataAccess
class SparkDataAccess:
def __init__(self, spark: SparkSession):
self.spark = spark
def read_dict(self, query: str) -> List[Dict[str, Any]]:
# Execute the query using Spark SQL and return the result as a list of dictionaries
df: SparkDataFrame = self.spark.sql(query)
return [row.asDict() for row in df.collect()]
def read_pandas(self, query: str) -> pd.DataFrame:
# Execute the query using Spark SQL and convert the result to a Pandas DataFrame
df: SparkDataFrame = self.spark.sql(query)
return df.toPandas()
def write_dict(self, data: List[Dict[str, Any]], table_name: str) -> None:
# Convert the list of dictionaries to a Spark DataFrame and write to the specified table
df: SparkDataFrame = self.spark.createDataFrame(data)
df.write.mode("append").saveAsTable(table_name)
def write_pandas(self, df: pd.DataFrame, table_name: str) -> None:
# Convert the Pandas DataFrame to a Spark DataFrame and write to the specified table
spark_df: SparkDataFrame = self.spark.createDataFrame(df)
spark_df.write.mode("append").saveAsTable(table_name)
# Example Usage:
if __name__ == "__main__":
# Initialize Spark session
spark = SparkSession.builder.appName("example").getOrCreate()
# Create an instance of SparkDataAccess
spark_access = SparkDataAccess(spark)
# Example: Read data as a list of dictionaries
query = "SELECT * FROM my_table"
result_dict = spark_access.read_dict(query)
print(result_dict)
# Example: Read data as a Pandas DataFrame
result_pandas = spark_access.read_pandas(query)
print(result_pandas)
# Example: Write a list of dictionaries to a Spark table
data = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
spark_access.write_dict(data, "my_table")
# Example: Write a Pandas DataFrame to a Spark table
df = pd.DataFrame(data)
spark_access.write_pandas(df, "my_table")
from pyspark.sql import SparkSession
import re
class SparkDataAccess:
def __init__(self, spark: SparkSession, jdbc_url: str, jdbc_properties: dict):
self.spark = spark
self.jdbc_url = jdbc_url
self.jdbc_properties = jdbc_properties
def read_spark(self, table_name=None, query=None):
if table_name:
return self._load_table(table_name)
elif query:
return self._execute_query(query)
else:
raise ValueError("Either table_name or query must be provided.")
def _load_table(self, table_name):
df = self.spark.read.jdbc(self.jdbc_url, table_name, properties=self.jdbc_properties)
df.createOrReplaceTempView(table_name)
return df
def _execute_query(self, query):
# Extract table names from the query, including handling WITH clauses
table_names = self._extract_table_names(query)
# Create temp views for the identified table names
for table_name in table_names:
self._load_table(table_name)
# Execute the query
return self.spark.sql(query)
    def _extract_table_names(self, query):
        # Simple regex to extract table names after FROM or JOIN keywords
        pattern = re.compile(r"(FROM|JOIN)\s+(\w+)", re.IGNORECASE)
        matches = pattern.findall(query)
        table_names = set(match[1] for match in matches)
        # Exclude CTE names defined in a WITH clause - they are not database tables
        with_pattern = re.compile(r"(?:WITH|,)\s*(\w+)\s+AS\s*\(", re.IGNORECASE)
        with_matches = with_pattern.findall(query)
        table_names.difference_update(with_matches)
        return table_names
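# Illustrative check of the regex-based extraction above (assumption: the SQL text and
# the table names are placeholders). Only real source tables should remain once the
# CTE name defined in the WITH clause has been excluded.
def extract_table_names_example():
    sample_sql = """
        WITH recent AS (SELECT * FROM orders WHERE ts > '2024-01-01')
        SELECT c.id, r.total
        FROM customers c
        JOIN recent r ON r.customer_id = c.id
    """
    access = SparkDataAccess(spark=None, jdbc_url="jdbc:postgresql://host/db", jdbc_properties={})
    # Expected result: {'orders', 'customers'}; 'recent' is a CTE, not a database table
    return access._extract_table_names(sample_sql)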
def main():
# Initialize SparkSession
spark = SparkSession.builder \
.appName("Spark Data Access Example") \
.getOrCreate()
# Example for DeltaSparkDataAccess
delta_accessor = DeltaSparkDataAccess(spark)
delta_df = delta_accessor.read_spark(table_name="delta_table")
print("Delta Table Columns:", delta_accessor.get_columns("delta_table"))
    delta_accessor.write_spark(delta_df, "output_delta_table")
# Example for PostgresqlSparkDataAccess
jdbc_url = "jdbc:postgresql://localhost:5432/mydatabase"
user = "myuser"
password = "mypassword"
postgresql_accessor = PostgresqlSparkDataAccess(spark, jdbc_url, "my_table", user, password)
pg_df = postgresql_accessor.read_spark(table_name="my_table")
print("PostgreSQL Table Columns:", postgresql_accessor.get_columns("my_table"))
postgresql_accessor.write_spark(pg_df, mode="append")
if __name__ == "__main__":
main()