Last active
May 10, 2025 17:10
-
-
Save igorvoltaic/28cdb770b64af7d095acb06e25b24772 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Script to create external tables in Greenplum based on Parquet schemas from S3 | |
""" | |
import logging | |
import sys | |
from enum import Enum | |
from pathlib import PurePosixPath | |
from typing import List, Set | |
from urllib.parse import quote_plus | |
import psycopg2 | |
from boto3 import Session | |
from mypy_boto3_s3.client import S3Client | |
from mypy_boto3_s3.type_defs import ObjectTypeDef | |
from pyarrow import parquet | |
from pyarrow.fs import S3FileSystem | |
# Send INFO-and-above log records to stdout with a timestamped format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
class GreenplumConnection:
    """Context manager wrapping a psycopg2 connection to Greenplum.

    Entering the context opens the connection and yields the raw psycopg2
    connection object; exiting closes it (if one was established).
    """

    def __init__(self, user: str, password: str, host: str, dbname: str, connect_timeout: int = 10):
        self.user = user
        self.password = password
        self.host = host
        self.dbname = dbname
        self.connect_timeout = connect_timeout
        self.conn = None  # set by connect()

    def __enter__(self):
        self.connect()
        return self.conn

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close only if connect() actually succeeded.
        if self.conn is not None:
            self.conn.close()
            logging.info('Greenplum connection closed')

    def connect(self):
        """Establish the database connection, logging and re-raising on failure."""
        credentials = dict(
            dbname=self.dbname,
            user=self.user,
            password=self.password,
            host=self.host,
            connect_timeout=self.connect_timeout,
        )
        try:
            self.conn = psycopg2.connect(**credentials)
            logging.info('Successfully connected to Greenplum')
        except psycopg2.Error as e:
            logging.error('Failed to connect to Greenplum: %s', e)
            raise
class DataTypeMapping(Enum):
    """Maps Parquet (pyarrow) data types to Greenplum column types.

    ``get_gp_type`` falls back to ``STRING`` (varchar) for unmapped types;
    use ``has_mapping_for`` to detect whether the fallback will be used.
    """
    DECIMAL = 'numeric'
    STRING = 'varchar'
    TIMESTAMP_NS = 'timestamp'
    BINARY = 'bytea'
    INT64 = 'bigint'
    INT32 = 'int'
    INT16 = 'smallint'
    DATE32_DAY = 'date'

    @classmethod
    def _get_type_mapping(cls):
        """Return the shared pyarrow-type-name -> member lookup table."""
        return _PARQUET_TO_GP

    @classmethod
    def get_gp_type(cls, parquet_type: str) -> str:
        """Return the Greenplum type name for *parquet_type* (varchar if unmapped)."""
        return cls._get_type_mapping().get(
            parquet_type.lower(),
            cls.STRING
        ).value

    @classmethod
    def has_mapping_for(cls, parquet_type: str) -> bool:
        """Return True if an explicit mapping exists for *parquet_type*."""
        return parquet_type.lower() in cls._get_type_mapping()


# Built once at import time; the original rebuilt this dict on every
# _get_type_mapping() call. Kept outside the Enum body so it is not
# turned into a member.
_PARQUET_TO_GP = {
    'decimal': DataTypeMapping.DECIMAL,
    'string': DataTypeMapping.STRING,
    'timestamp[ns]': DataTypeMapping.TIMESTAMP_NS,
    'binary': DataTypeMapping.BINARY,
    'int64': DataTypeMapping.INT64,
    'int32': DataTypeMapping.INT32,
    'int16': DataTypeMapping.INT16,
    'date32[day]': DataTypeMapping.DATE32_DAY,
}
class S3ClientManager:
    """Manages S3 connections and file operations."""

    def __init__(self, access_key: str, secret_key: str, endpoint: str):
        self.access_key = access_key
        self.secret_key = secret_key
        self.endpoint = endpoint
        self.client: S3Client  # created in connect()

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # boto3 clients hold no resources that need explicit closing.
        pass

    def connect(self):
        """Create and configure the boto3 S3 client."""
        try:
            self.client = Session().client(
                service_name='s3',
                endpoint_url=f'https://{self.endpoint}',
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_key
            )
            logging.info('Successfully connected to S3')
        except Exception as e:
            logging.error('Failed to connect to S3: %s', e)
            raise

    def list_objects(self, bucket: str, prefix: str) -> List[ObjectTypeDef]:
        """Retrieve all non-empty objects under *prefix*, following pagination."""
        try:
            found: List[ObjectTypeDef] = []
            paginator = self.client.get_paginator('list_objects_v2')
            for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
                for obj in page.get('Contents', []):
                    if obj.get('Size', 0) > 0:
                        found.append(obj)
            return found
        except Exception as e:
            logging.error('Failed to retrieve S3 objects: %s', e)
            raise

    def get_s3_folders(self, bucket: str, base_folder: str) -> Set[str]:
        """Extract the unique top-level folder names under *base_folder*."""
        folders: Set[str] = set()
        for obj in self.list_objects(bucket, base_folder):
            key = obj.get('Key')
            if key:
                folders.add(self._extract_dir_name(key, base_folder))
        return folders

    @staticmethod
    def _extract_dir_name(s3_key: str, base_folder: str) -> str:
        """Parse the first path component after *base_folder* from an S3 key.

        Raises:
            ValueError: if *s3_key* does not live under *base_folder*.
        """
        if not base_folder.endswith('/'):
            base_folder = base_folder + '/'
        if not s3_key.startswith(base_folder):
            raise ValueError(f"S3 key {s3_key} does not start with base folder {base_folder}")
        remainder = s3_key[len(base_folder):]
        return PurePosixPath(remainder).parts[0]
class ParquetSchemaParser:
    """Handles Parquet schema parsing and SQL generation."""

    @staticmethod
    def generate_column_definitions(schema: parquet.ParquetSchema) -> str:
        """Convert a Parquet schema into SQL column definitions, one per line.

        Unmapped pyarrow types are logged and mapped to varchar by
        DataTypeMapping's fallback.
        """
        definitions = []
        for field in schema:
            pyarrow_type = str(field.type).lower()
            if not DataTypeMapping.has_mapping_for(pyarrow_type):
                logging.warning('Unmapped Parquet type: %s (column: %s), using varchar',
                                pyarrow_type, field.name)
            definitions.append(f"{field.name.lower()} {DataTypeMapping.get_gp_type(pyarrow_type)}")
        return ',\n'.join(definitions)
class ExternalTableCreator:
    """Orchestrates the external table creation process.

    For each matching entity folder in S3, reads the schema of the most
    recently modified Parquet file and (re)creates a PXF-backed external
    table for it in the Greenplum ``stg`` schema.
    """

    # entity_type -> (S3 bucket, base folder prefix within the bucket)
    BUCKET_CONFIG = {
        'events': ('dwh', 'events/'),
    }

    def __init__(self, s3_client: S3ClientManager, gp_conn, entity_type: str, s3_mask: str, stg_owner: str,
                 postfix: str = '', partition_suffix: str = ''):
        """Store configuration and resolve bucket/base folder for the entity type.

        Args:
            s3_client: connected S3ClientManager used to list/read Parquet files.
            gp_conn: open psycopg2 connection to Greenplum.
            entity_type: key into BUCKET_CONFIG selecting bucket and base folder.
            s3_mask: single entity folder name, or 'all' to process every folder.
            stg_owner: role that will own the created stg tables.
            postfix: optional suffix appended to each table name.
            partition_suffix: optional sub-folder appended to the LOCATION path.

        Raises:
            ValueError: if ``entity_type`` has no entry in BUCKET_CONFIG.
        """
        self.s3_client = s3_client
        self.gp_conn = gp_conn
        self.entity_type = entity_type.lower()
        self.s3_mask = s3_mask.lower()
        self.postfix = postfix
        self.partition_suffix = partition_suffix
        self.stg_owner = stg_owner
        if self.entity_type not in self.BUCKET_CONFIG:
            raise ValueError(f'Invalid entity type: {entity_type}')
        self.bucket, self.base_folder = self.BUCKET_CONFIG[self.entity_type]

    def create_external_tables(self):
        """Main method to create tables for all matching entities."""
        for entity_name in self._get_entity_names():
            self._process_entity(entity_name)

    def _get_entity_names(self) -> List[str]:
        """Determine entities to process based on the S3 mask ('all' = every folder)."""
        if self.s3_mask == 'all':
            return list(self.s3_client.get_s3_folders(self.bucket, self.base_folder))
        return [self.s3_mask]

    def _process_entity(self, entity_name: str):
        """Drop, recreate, and chown the external table for one entity.

        Commits on success; rolls back and re-raises on any failure so
        partial DDL is never left committed.
        """
        table_name = f"{entity_name.lower()}{self.postfix}"
        logging.info('Processing table: %s', table_name)
        try:
            # Get latest Parquet file
            latest_file = self._get_latest_parquet_file(entity_name)
            if not latest_file:
                return
            # Generate SQL definitions
            columns_sql = self._read_parquet_schema(latest_file)
            create_sql = self._generate_create_sql(columns_sql, entity_name, table_name)
            # Execute SQL statements. NOTE(review): table_name and stg_owner are
            # interpolated as SQL identifiers (DDL cannot take bind parameters);
            # they derive from S3 folder names and operator CLI input, which are
            # assumed trusted — confirm before exposing to untrusted callers.
            with self.gp_conn.cursor() as cursor:
                cursor.execute(f'DROP EXTERNAL TABLE IF EXISTS stg.{table_name}')
                cursor.execute(create_sql)
                cursor.execute(f'ALTER TABLE stg.{table_name} OWNER TO {self.stg_owner};')
            self.gp_conn.commit()
            logging.info('Successfully created table: %s', table_name)
        except Exception as e:
            logging.error('Failed to process entity %s: %s', entity_name, e)
            self.gp_conn.rollback()
            raise

    def _get_latest_parquet_file(self, entity_name: str) -> ObjectTypeDef:
        """Find the most recently modified Parquet file for the entity.

        Raises:
            ValueError: if no non-empty objects with a LastModified stamp exist.
        """
        objects = self.s3_client.list_objects(
            self.bucket, f"{self.base_folder}{entity_name}/"
        )
        valid_objects = [obj for obj in objects if obj.get('Key') and obj.get('LastModified')]
        if not valid_objects:
            # FIX: was ValueError('No valid files found for mask: %s', entity_name),
            # which never interpolated the entity name into the message.
            raise ValueError(f'No valid files found for mask: {entity_name}')
        return max(valid_objects, key=lambda x: x['LastModified'])  # type:ignore

    def _read_parquet_schema(self, s3_object: ObjectTypeDef) -> str:
        """Read the schema of the given Parquet object and render SQL column defs."""
        fs = S3FileSystem(
            access_key=self.s3_client.access_key,
            secret_key=self.s3_client.secret_key,
            endpoint_override=self.s3_client.endpoint,
            scheme="https"
        )
        parquet_path = f"{self.bucket}/{s3_object['Key']}"  # type:ignore
        try:
            parquet_file = parquet.ParquetFile(fs.open_input_file(parquet_path))
            return ParquetSchemaParser.generate_column_definitions(parquet_file.schema)
        except Exception as e:
            logging.error('Failed to read Parquet schema: %s', e)
            raise

    def _generate_create_sql(self, columns_sql: str, entity_name: str, table_name: str) -> str:
        """Generate the CREATE EXTERNAL TABLE statement pointing at the PXF location."""
        # Credentials are URL-encoded because they travel in the LOCATION query string.
        params = [
            'PROFILE=s3:parquet',
            'COMPRESSION_CODEC=gz',
            f'accesskey={quote_plus(self.s3_client.access_key)}',
            f'secretkey={quote_plus(self.s3_client.secret_key)}',
            f'endpoint={self.s3_client.endpoint}'
        ]
        query = '&'.join(params)
        s3_path = f"{self.base_folder}{entity_name}/"
        if self.partition_suffix:
            s3_path += f"{self.partition_suffix}/"
        return f"""
        CREATE EXTERNAL TABLE stg.{table_name} (
        {columns_sql}
        )
        LOCATION (
            'pxf://{self.bucket}/{s3_path}*?{query}'
        )
        ON ALL
        FORMAT 'CUSTOM' (FORMATTER='pxfwritable_import')
        ENCODING 'UTF8';"""
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Create external tables in Greenplum from S3 Parquet files')
    parser.add_argument('--entity-type', required=True, choices=['events'], help='Type of entity to process')
    parser.add_argument('--s3-endpoint', required=True, help='S3 endpoint')
    parser.add_argument('--s3-mask', required=True, help='S3 object name or "all" to process all tables')
    parser.add_argument('--aws-key', required=True, help='AWS access key ID')
    parser.add_argument('--aws-secret', required=True, help='AWS secret access key')
    parser.add_argument('--gp-user', required=True, help='Greenplum username')
    parser.add_argument('--gp-password', required=True, help='Greenplum password')
    parser.add_argument('--gp-dbname', required=True, help='Greenplum database name')
    parser.add_argument('--gp-host', required=True, help='Greenplum database host')
    parser.add_argument('--gp-stg-owner', required=True, help='Greenplum stg table owner')
    args = parser.parse_args()

    try:
        # Open both resources together; either failing aborts the run.
        with (S3ClientManager(
                  access_key=args.aws_key,
                  secret_key=args.aws_secret,
                  endpoint=args.s3_endpoint,
              ) as s3_client,
              GreenplumConnection(
                  user=args.gp_user,
                  password=args.gp_password,
                  dbname=args.gp_dbname,
                  host=args.gp_host,
              ) as gp_conn):
            creator = ExternalTableCreator(
                s3_client=s3_client,
                gp_conn=gp_conn,
                entity_type=args.entity_type,
                s3_mask=args.s3_mask,
                # BUG FIX: argparse stores --gp-stg-owner as args.gp_stg_owner;
                # the original read args.stg_owner, which raised AttributeError
                # on every run before any work was done.
                stg_owner=args.gp_stg_owner,
            )
            creator.create_external_tables()
    except Exception as e:
        logging.error('Fatal error occurred: %s', e)
        sys.exit(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment