Last active: August 18, 2024 09:56
from typing import List, Dict, Any, Tuple, Optional, Union

import pandas as pd
from sqlalchemy import MetaData, Table, create_engine, text
from sqlalchemy.engine import Engine, ResultProxy
from pyspark.sql import SparkSession, DataFrame as SparkDataFrame


# Base class DataAccess using SQLAlchemy
class DataAccess:
    def __init__(self, conn_str: str):
        self.engine: Engine = create_engine(conn_str)

    def read_dict(self, query: str) -> List[Dict[str, Any]]:
        with self.engine.connect() as conn:
            result: ResultProxy = conn.execute(text(query))
            return [dict(row) for row in result.fetchall()]

    def read_pandas(self, query: str) -> pd.DataFrame:
        with self.engine.connect() as conn:
            return pd.read_sql(query, conn)

    def write_dict(self, data: List[Dict[str, Any]], table_name: str) -> None:
        # Reflect the target table by name so an INSERT statement can be built from it
        table = Table(table_name, MetaData(), autoload_with=self.engine)
        with self.engine.connect() as conn:
            conn.execute(table.insert(), data)

    def write_pandas(self, df: pd.DataFrame, table_name: str) -> None:
        df.to_sql(table_name, self.engine, if_exists='append', index=False)

    def execute_query(self, query: str) -> Union[int, List[Dict[str, Any]]]:
        """
        Execute an INSERT, UPDATE, or SELECT query.
        Returns the number of rows affected for INSERT/UPDATE,
        or the result of the SELECT query as a list of dictionaries.
        """
        with self.engine.connect() as conn:
            result: ResultProxy = conn.execute(text(query))
            if query.strip().lower().startswith("select"):
                return [dict(row) for row in result.fetchall()]
            else:
                return result.rowcount

    def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
        raise NotImplementedError("get_columns method must be implemented by subclasses")


# Subclass PostgresqlDataAccess
class PostgresqlDataAccess(DataAccess):
    def __init__(self, conn_str: str):
        super().__init__(conn_str)

    def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
        # Use a bound parameter rather than string interpolation to avoid SQL injection
        query = """
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_name = :table_name;
        """
        with self.engine.connect() as conn:
            result = conn.execute(text(query), {"table_name": table_name})
            return [(row["column_name"], row["data_type"]) for row in result.fetchall()]


# Subclass SparkDataAccess
class SparkDataAccess:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def read_spark(self, table_name: Optional[str] = None, query: Optional[str] = None) -> SparkDataFrame:
        if (table_name is None and query is None) or (table_name is not None and query is not None):
            raise ValueError("You must provide either a table name or a query, but not both.")
        if table_name:
            return self.spark.read.table(table_name)
        return self.spark.sql(query)

    def write_spark(self, table_name: str, df: SparkDataFrame, mode: str = "append") -> None:
        df.write.mode(mode).saveAsTable(table_name)


# Subclass JdbcSparkDataAccess
class JdbcSparkDataAccess(SparkDataAccess):
    def __init__(self, spark: SparkSession, jdbc_url: str, conn_props: Dict[str, Any]):
        super().__init__(spark)
        self.jdbc_url = jdbc_url
        self.conn_props = conn_props

    def read_spark(self, table_name: Optional[str] = None, query: Optional[str] = None) -> SparkDataFrame:
        if (table_name is None and query is None) or (table_name is not None and query is not None):
            raise ValueError("You must provide either a table name or a query, but not both.")
        if table_name:
            return self.spark.read.jdbc(self.jdbc_url, table_name, properties=self.conn_props)
        # Push the query down to the JDBC source by wrapping it as a derived table
        return self.spark.read.jdbc(self.jdbc_url, f"({query}) AS subquery", properties=self.conn_props)

    def write_spark(self, table_name: str, df: SparkDataFrame, mode: str = "append") -> None:
        df.write.jdbc(self.jdbc_url, table_name, mode=mode, properties=self.conn_props)


# Subclass PostgresqlSparkDataAccess
class PostgresqlSparkDataAccess(JdbcSparkDataAccess):
    def __init__(self, spark: SparkSession, jdbc_url: str, conn_props: Dict[str, Any]):
        # Specify PostgreSQL JDBC driver class name
        conn_props["driver"] = "org.postgresql.Driver"
        super().__init__(spark, jdbc_url, conn_props)


# Example Usage:
if __name__ == "__main__":
    # Example usage for SQLAlchemy-based operations
    conn_str = "postgresql://username:password@host:port/database"
    pg_access = PostgresqlDataAccess(conn_str)

    # Example: Reading data as a dictionary
    query = "SELECT * FROM users"
    result_dict = pg_access.read_dict(query)
    print(result_dict)

    # Example usage for Spark-based operations
    spark = SparkSession.builder.appName("example").getOrCreate()
    jdbc_url = "jdbc:postgresql://host:port/database"
    conn_props = {"user": "your_username", "password": "your_password"}

    # Create an instance of PostgresqlSparkDataAccess
    postgres_spark_access = PostgresqlSparkDataAccess(spark, jdbc_url, conn_props)

    # Example: Read from Spark table or SQL query
    df_table = postgres_spark_access.read_spark(table_name="my_table")
    df_query = postgres_spark_access.read_spark(query="SELECT * FROM my_table WHERE id > 100")

    # Example: Write to a Spark table
    postgres_spark_access.write_spark("my_table", df_query, mode="append")
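A further hedged sketch of the SQLAlchemy side, assuming a reachable PostgreSQL database and an existing users table; the connection string, table, and values are placeholders:

pg_access = PostgresqlDataAccess("postgresql://username:password@host:5432/database")

# Inspect column names and types via information_schema
print(pg_access.get_columns("users"))

# Insert rows from a list of dictionaries (the table is reflected by name)
pg_access.write_dict([{"id": 1, "name": "Alice"}], "users")

# execute_query returns rows for SELECT statements and a rowcount otherwise
updated = pg_access.execute_query("UPDATE users SET name = 'Bob' WHERE id = 1")
print(updated)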

# --- Next file in this gist: DeltaSparkDataAccess ---

from typing import List, Tuple

from pyspark.sql import DataFrame


# Builds on the SparkDataAccess base class (table_name/sql interface) defined later in this gist
class DeltaSparkDataAccess(SparkDataAccess):
    def __init__(self, spark):
        """
        Initialize with a SparkSession object.
        """
        super().__init__(spark)

    def read_spark(self, table_name: str = None, sql: str = None, **options) -> DataFrame:
        """
        Read data from a Delta table or execute a SQL query.
        :param table_name: The path or table name of the Delta table (default None)
        :param sql: The SQL query to read data from (default None)
        :param options: Additional options for reading the data
        :return: Spark DataFrame
        """
        return super().read_spark(table_name=table_name, sql=sql, **options)

    def write_spark(self, df: DataFrame, destination: str, mode: str = "overwrite", **options):
        """
        Write data to a Delta table.
        :param df: Spark DataFrame to be written
        :param destination: The path or table name of the Delta table
        :param mode: Write mode (e.g., 'overwrite', 'append')
        :param options: Additional options for the write method
        """
        super().write_spark(df, destination, format="delta", mode=mode, **options)

    def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
        """
        Get the column names and their types for a given Delta table.
        :param table_name: The path or table name of the Delta table
        :return: List of tuples containing column names and types
        """
        # Load the Delta table and inspect its schema
        df = self.read_spark(table_name=table_name)
        return [(field.name, field.dataType.simpleString()) for field in df.schema.fields]
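A minimal usage sketch for DeltaSparkDataAccess, assuming a SparkSession with Delta Lake support and a registered Delta table; the table name and output path are placeholders:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("delta-example").getOrCreate()
delta_access = DeltaSparkDataAccess(spark)

# Read a registered Delta table (hypothetical name) and inspect its schema
events_df = delta_access.read_spark(table_name="events")
print(delta_access.get_columns("events"))

# Write the DataFrame back out as a Delta dataset at a placeholder path
delta_access.write_spark(events_df, "/tmp/events_copy", mode="overwrite")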

# --- Next file in this gist: SparkJdbcDataAccess ---

from typing import List, Tuple

from pyspark.sql import DataFrame


class SparkJdbcDataAccess(SparkDataAccess):
    def __init__(self, spark, jdbc_url: str, db_table: str, user: str, password: str, driver: str):
        """
        Initialize with a SparkSession object and JDBC connection details.
        :param spark: SparkSession object
        :param jdbc_url: JDBC URL for the database connection
        :param db_table: The table name in the database
        :param user: Database username
        :param password: Database password
        :param driver: The JDBC driver class name
        """
        super().__init__(spark)
        self.jdbc_url = jdbc_url
        self.db_table = db_table
        self.user = user
        self.password = password
        self.driver = driver

    def read_spark(self, table_name: str = None, sql: str = None, **options) -> DataFrame:
        """
        Read data from a database table using JDBC, or push a SQL query down to the database.
        :param table_name: The table name in the database (default None)
        :param sql: The SQL query to read data from (default None)
        :param options: Additional options for the read method
        :return: Spark DataFrame
        """
        # Common JDBC connection options, including the driver class name
        options.update({
            "url": self.jdbc_url,
            "user": self.user,
            "password": self.password,
            "driver": self.driver,
        })
        if sql is not None:
            # Push the SQL query down to the database via the 'query' option
            options["query"] = sql
        else:
            # Treat table_name as a table name in the database
            options["dbtable"] = table_name
        return self.spark.read.format("jdbc").options(**options).load()

    def write_spark(self, df: DataFrame, mode: str = "overwrite", **options):
        """
        Write data to a database table using JDBC.
        :param df: Spark DataFrame to be written
        :param mode: Write mode (e.g., 'overwrite', 'append')
        :param options: Additional options for the write method
        """
        options.update({
            "url": self.jdbc_url,
            "dbtable": self.db_table,
            "user": self.user,
            "password": self.password,
            "driver": self.driver,
        })
        df.write.format("jdbc").mode(mode).options(**options).save()

    def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
        """
        Get the column names and their types for a given JDBC table.
        This is a placeholder to be implemented in subclasses for specific databases.
        :param table_name: The table name
        :return: List of tuples containing column names and types
        """
        raise NotImplementedError("This method should be implemented in subclasses.")
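A minimal sketch of using the generic JDBC accessor directly, assuming a database reachable over JDBC; the URL, credentials, table name, and driver class are placeholders (the PostgreSQL driver is reused here only because it appears elsewhere in this gist):

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("jdbc-example").getOrCreate()

jdbc_access = SparkJdbcDataAccess(
    spark,
    jdbc_url="jdbc:postgresql://localhost:5432/mydatabase",
    db_table="orders",
    user="myuser",
    password="mypassword",
    driver="org.postgresql.Driver",
)

# Read the configured table, or push an arbitrary query down to the database
orders_df = jdbc_access.read_spark(table_name="orders")
recent_df = jdbc_access.read_spark(sql="SELECT * FROM orders WHERE order_date > '2024-01-01'")

# Append the filtered rows back into the configured table
jdbc_access.write_spark(recent_df, mode="append")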

# --- Next file in this gist: PostgresqlSparkDataAccess (JDBC variant) ---

from typing import List, Tuple


class PostgresqlSparkDataAccess(SparkJdbcDataAccess):
    def __init__(self, spark, jdbc_url: str, db_table: str, user: str, password: str):
        """
        Initialize with a SparkSession object and PostgreSQL connection details.
        This class specifies the PostgreSQL JDBC driver class name.
        :param spark: SparkSession object
        :param jdbc_url: JDBC URL for the PostgreSQL connection
        :param db_table: The table name in the PostgreSQL database
        :param user: Database username
        :param password: Database password
        """
        # Specify the PostgreSQL driver class name
        driver = "org.postgresql.Driver"
        super().__init__(spark, jdbc_url, db_table, user, password, driver)

    def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
        """
        Get the column names and their types for a given PostgreSQL table.
        :param table_name: The table name
        :return: List of tuples containing column names and types
        """
        query = f"""
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_name = '{table_name}'
        """
        # Push the query down to PostgreSQL and collect the result locally
        df = self.read_spark(sql=query)
        return [(row['column_name'], row['data_type']) for row in df.collect()]

# --- Next file in this gist: SparkDataAccess base class (format/save variant) ---

from typing import List, Tuple, Union

from pyspark.sql import DataFrame


class SparkDataAccess:
    def __init__(self, spark):
        """
        Initialize with a SparkSession object.
        :param spark: SparkSession object
        """
        self.spark = spark

    def read_spark(self, table_name: str = None, sql: str = None, **options) -> DataFrame:
        """
        Read data from a Spark data source, which can be a table or a SQL query.
        :param table_name: The table name or path to read data from (default None)
        :param sql: The SQL query to read data from (default None)
        :param options: Additional options for reading the data
        :return: Spark DataFrame
        :raises ValueError: If both or neither of table_name and sql are provided
        """
        if (table_name is not None and sql is not None) or (table_name is None and sql is None):
            raise ValueError("Exactly one of table_name or sql must be provided.")
        if sql is not None:
            # Execute SQL query and return DataFrame
            df = self.spark.sql(sql)
        else:
            # Treat table_name as a catalog table name and read it
            df = self.spark.read.options(**options).table(table_name)
        return df

    def write_spark(self, df: DataFrame, destination: str, format: str, mode: str = "overwrite", **options):
        """
        Write data to a Spark data source.
        :param df: Spark DataFrame to be written
        :param destination: The path or table name where the data will be written
        :param format: The format of the data (e.g., 'delta', 'parquet', etc.)
        :param mode: Write mode (e.g., 'overwrite', 'append')
        :param options: Additional options for the write method
        """
        df.write.format(format).mode(mode).options(**options).save(destination)

    def execute_query(self, query: str) -> Union[DataFrame, int]:
        """
        Execute a Spark SQL statement. This can be a SELECT, UPDATE, DELETE, or INSERT statement.
        :param query: The SQL query string
        :return: If the statement produces columns (a SELECT), returns a DataFrame.
                 For other statements, returns the row count of the (usually empty)
                 result DataFrame; note that Spark SQL does not expose an
                 affected-row count for DML statements.
        """
        # Execute the statement
        result = self.spark.sql(query)
        # Check whether it produced a result set by inspecting the schema
        if result.columns:
            return result
        # Otherwise it was an update, delete, insert, or DDL statement
        return result.count()

    def get_columns(self, table_name: str) -> List[Tuple[str, str]]:
        """
        Get the column names and their types for a given table name.
        This is a placeholder method to be implemented in subclasses.
        :param table_name: The table name or path
        :return: List of tuples containing column names and types
        """
        raise NotImplementedError("This method should be implemented in subclasses.")
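A small usage sketch for this base class, assuming an active SparkSession and a table already registered in the catalog; the table name and output path are placeholders:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("spark-data-access-example").getOrCreate()
access = SparkDataAccess(spark)

# Read a registered catalog table (hypothetical name) ...
sales_df = access.read_spark(table_name="sales")

# ... and write it out as Parquet to a placeholder path
access.write_spark(sales_df, "/tmp/sales_copy", format="parquet", mode="overwrite")

# execute_query returns a DataFrame for statements that produce columns
summary = access.execute_query("SELECT COUNT(*) AS n FROM sales")
summary.show()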

# --- Next file in this gist: SparkDataAccess (dict/pandas variant) ---

from typing import List, Dict, Any

import pandas as pd
from pyspark.sql import SparkSession, DataFrame as SparkDataFrame


# Subclass SparkDataAccess
class SparkDataAccess:
    def __init__(self, spark: SparkSession):
        self.spark = spark

    def read_dict(self, query: str) -> List[Dict[str, Any]]:
        # Execute the query using Spark SQL and return the result as a list of dictionaries
        df: SparkDataFrame = self.spark.sql(query)
        return [row.asDict() for row in df.collect()]

    def read_pandas(self, query: str) -> pd.DataFrame:
        # Execute the query using Spark SQL and convert the result to a Pandas DataFrame
        df: SparkDataFrame = self.spark.sql(query)
        return df.toPandas()

    def write_dict(self, data: List[Dict[str, Any]], table_name: str) -> None:
        # Convert the list of dictionaries to a Spark DataFrame and write to the specified table
        df: SparkDataFrame = self.spark.createDataFrame(data)
        df.write.mode("append").saveAsTable(table_name)

    def write_pandas(self, df: pd.DataFrame, table_name: str) -> None:
        # Convert the Pandas DataFrame to a Spark DataFrame and write to the specified table
        spark_df: SparkDataFrame = self.spark.createDataFrame(df)
        spark_df.write.mode("append").saveAsTable(table_name)


# Example Usage:
if __name__ == "__main__":
    # Initialize Spark session
    spark = SparkSession.builder.appName("example").getOrCreate()

    # Create an instance of SparkDataAccess
    spark_access = SparkDataAccess(spark)

    # Example: Read data as a list of dictionaries
    query = "SELECT * FROM my_table"
    result_dict = spark_access.read_dict(query)
    print(result_dict)

    # Example: Read data as a Pandas DataFrame
    result_pandas = spark_access.read_pandas(query)
    print(result_pandas)

    # Example: Write a list of dictionaries to a Spark table
    data = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
    spark_access.write_dict(data, "my_table")

    # Example: Write a Pandas DataFrame to a Spark table
    df = pd.DataFrame(data)
    spark_access.write_pandas(df, "my_table")

# --- Next file in this gist: SparkDataAccess (JDBC temp-view variant) ---

import re

from pyspark.sql import SparkSession


class SparkDataAccess:
    def __init__(self, spark: SparkSession, jdbc_url: str, jdbc_properties: dict):
        self.spark = spark
        self.jdbc_url = jdbc_url
        self.jdbc_properties = jdbc_properties

    def read_spark(self, table_name=None, query=None):
        if table_name:
            return self._load_table(table_name)
        elif query:
            return self._execute_query(query)
        else:
            raise ValueError("Either table_name or query must be provided.")

    def _load_table(self, table_name):
        # Load the table over JDBC and register it as a temp view so SQL can reference it
        df = self.spark.read.jdbc(self.jdbc_url, table_name, properties=self.jdbc_properties)
        df.createOrReplaceTempView(table_name)
        return df

    def _execute_query(self, query):
        # Extract table names from the query, excluding names defined by WITH clauses
        table_names = self._extract_table_names(query)
        # Create temp views for the identified table names
        for table_name in table_names:
            self._load_table(table_name)
        # Execute the query against the registered temp views
        return self.spark.sql(query)

    def _extract_table_names(self, query):
        # Simple regex to extract table names after FROM or JOIN keywords
        pattern = re.compile(r"(FROM|JOIN)\s+(\w+)", re.IGNORECASE)
        matches = pattern.findall(query)
        table_names = set(match[1] for match in matches)
        # Handle WITH clauses: CTE names are defined inside the query itself,
        # so they must not be loaded from the database as tables
        with_pattern = re.compile(r"WITH\s+(\w+)\s+AS\s*\(", re.IGNORECASE)
        with_matches = with_pattern.findall(query)
        table_names.difference_update(with_matches)
        return table_names
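A brief illustration of the table-name extraction, assuming a query with one CTE; the JDBC URL, credentials, table names, and columns are made up for the example, and nothing connects to the database until a table is actually loaded:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("extract-tables-example").getOrCreate()
access = SparkDataAccess(spark, "jdbc:postgresql://localhost:5432/mydb", {"user": "u", "password": "p"})

query = """
WITH recent AS (
    SELECT * FROM orders WHERE order_date > '2024-01-01'
)
SELECT c.name, r.total
FROM recent r
JOIN customers c ON c.id = r.customer_id
"""

# Only the real database tables are returned; the CTE name 'recent' is excluded
print(access._extract_table_names(query))  # e.g. {'orders', 'customers'}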

# --- Next file in this gist: example main() driver ---

from pyspark.sql import SparkSession

# DeltaSparkDataAccess and PostgresqlSparkDataAccess are defined in the other files of this gist


def main():
    # Initialize SparkSession
    spark = SparkSession.builder \
        .appName("Spark Data Access Example") \
        .getOrCreate()

    # Example for DeltaSparkDataAccess
    delta_accessor = DeltaSparkDataAccess(spark)
    delta_df = delta_accessor.read_spark(table_name="delta_table")
    print("Delta Table Columns:", delta_accessor.get_columns("delta_table"))
    # DeltaSparkDataAccess already writes in Delta format, so no format argument is passed here
    delta_accessor.write_spark(delta_df, "output_delta_table")

    # Example for PostgresqlSparkDataAccess
    jdbc_url = "jdbc:postgresql://localhost:5432/mydatabase"
    user = "myuser"
    password = "mypassword"
    postgresql_accessor = PostgresqlSparkDataAccess(spark, jdbc_url, "my_table", user, password)
    pg_df = postgresql_accessor.read_spark(table_name="my_table")
    print("PostgreSQL Table Columns:", postgresql_accessor.get_columns("my_table"))
    postgresql_accessor.write_spark(pg_df, mode="append")


if __name__ == "__main__":
    main()