"""
Script to create external tables in Greenplum based on Parquet schemas from S3
"""
import logging
import sys
from enum import Enum
from pathlib import PurePosixPath
from typing import List, Set
from urllib.parse import quote_plus
import psycopg2
from boto3 import Session
from mypy_boto3_s3.client import S3Client
from mypy_boto3_s3.type_defs import ObjectTypeDef
from pyarrow import parquet
from pyarrow.fs import S3FileSystem

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)


class GreenplumConnection:
    """Manages connection and transactions with Greenplum database"""

    def __init__(self, user: str, password: str, host: str, dbname: str, connect_timeout: int = 10):
        self.user = user
        self.password = password
        self.host = host
        self.dbname = dbname
        self.connect_timeout = connect_timeout
        self.conn = None

    def __enter__(self):
        self.connect()
        return self.conn

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.conn:
            self.conn.close()
            logging.info('Greenplum connection closed')

    def connect(self):
        """Establish database connection"""
        try:
            self.conn = psycopg2.connect(
                dbname=self.dbname,
                user=self.user,
                password=self.password,
                host=self.host,
                connect_timeout=self.connect_timeout
            )
            logging.info('Successfully connected to Greenplum')
        except psycopg2.Error as e:
            logging.error('Failed to connect to Greenplum: %s', e)
            raise
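
# Usage sketch (placeholder values): __enter__ yields the raw psycopg2 connection,
# so callers drive cursors, commits and rollbacks themselves:
#   with GreenplumConnection(user='gpadmin', password='...', host='gp-master', dbname='dwh') as conn:
#       with conn.cursor() as cur:
#           cur.execute('SELECT 1')
#       conn.commit()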


class DataTypeMapping(Enum):
    """Maps Parquet data types to Greenplum equivalents"""
    DECIMAL = 'numeric'
    STRING = 'varchar'
    TIMESTAMP_NS = 'timestamp'
    BINARY = 'bytea'
    INT64 = 'bigint'
    INT32 = 'int'
    INT16 = 'smallint'
    DATE32_DAY = 'date'

    @classmethod
    def _get_type_mapping(cls):
        return {
            'decimal': cls.DECIMAL,
            'string': cls.STRING,
            'timestamp[ns]': cls.TIMESTAMP_NS,
            'binary': cls.BINARY,
            'int64': cls.INT64,
            'int32': cls.INT32,
            'int16': cls.INT16,
            'date32[day]': cls.DATE32_DAY
        }

    @classmethod
    def get_gp_type(cls, parquet_type: str) -> str:
        """Get Greenplum type for given Parquet type"""
        return cls._get_type_mapping().get(
            parquet_type.lower(),
            cls.STRING
        ).value

    @classmethod
    def has_mapping_for(cls, parquet_type: str) -> bool:
        """Check if type mapping exists"""
        return parquet_type.lower() in cls._get_type_mapping()
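
# For example, given the mapping above:
#   DataTypeMapping.get_gp_type('timestamp[ns]') -> 'timestamp'
#   DataTypeMapping.get_gp_type('int32')         -> 'int'
#   DataTypeMapping.get_gp_type('float64')       -> 'varchar'  (unmapped types fall back to STRING)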


class S3ClientManager:
    """Manages S3 connections and file operations"""

    def __init__(self, access_key: str, secret_key: str, endpoint: str):
        self.access_key = access_key
        self.secret_key = secret_key
        self.endpoint = endpoint
        self.client: S3Client

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass  # The client doesn't need explicit closing

    def connect(self):
        """Create and configure S3 client"""
        try:
            self.client = Session().client(
                service_name='s3',
                endpoint_url=f'https://{self.endpoint}',
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_key
            )
            logging.info('Successfully connected to S3')
        except Exception as e:
            logging.error('Failed to connect to S3: %s', e)
            raise

    def list_objects(self, bucket: str, prefix: str) -> List[ObjectTypeDef]:
        """Retrieve objects from S3 with pagination"""
        try:
            paginator = self.client.get_paginator('list_objects_v2')
            pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
            return [
                obj
                for page in pages
                for obj in page.get('Contents', [])
                if obj.get('Size', 0) > 0
            ]
        except Exception as e:
            logging.error('Failed to retrieve S3 objects: %s', e)
            raise

    def get_s3_folders(self, bucket: str, base_folder: str) -> Set[str]:
        """Extract unique folder names from S3 keys"""
        objects = self.list_objects(bucket, base_folder)
        return {
            self._extract_dir_name(obj['Key'], base_folder)  # type: ignore
            for obj in objects if obj.get('Key')
        }

    @staticmethod
    def _extract_dir_name(s3_key: str, base_folder: str) -> str:
        """Parse directory name from S3 key"""
        if not base_folder.endswith('/'):
            base_folder += '/'
        if not s3_key.startswith(base_folder):
            raise ValueError(f"S3 key {s3_key} does not start with base folder {base_folder}")
        return PurePosixPath(s3_key[len(base_folder):]).parts[0]
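
# For example (illustrative key), with base_folder 'events/':
#   S3ClientManager._extract_dir_name('events/orders/2025/part-0000.parquet', 'events/') -> 'orders'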


class ParquetSchemaParser:
    """Handles Parquet schema parsing and SQL generation"""

    @staticmethod
    def generate_column_definitions(schema: parquet.ParquetSchema) -> str:
        """Convert Parquet schema to SQL column definitions"""
        columns = []
        # Iterate over the Arrow view of the schema: Arrow fields expose type
        # strings such as 'timestamp[ns]' and 'date32[day]', which is what
        # DataTypeMapping keys on (ParquetSchema columns have no .type attribute).
        for field in schema.to_arrow_schema():
            pyarrow_type = str(field.type).lower()
            gp_type = DataTypeMapping.get_gp_type(pyarrow_type)
            if not DataTypeMapping.has_mapping_for(pyarrow_type):
                logging.warning('Unmapped Parquet type: %s (column: %s), using varchar',
                                pyarrow_type, field.name)
            columns.append(f"{field.name.lower()} {gp_type}")
        return ',\n'.join(columns)
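
# Illustrative output of generate_column_definitions for a file whose Arrow
# schema is (id: int64, created_at: timestamp[ns], payload: binary):
#   id bigint,
#   created_at timestamp,
#   payload bytea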


class ExternalTableCreator:
    """Orchestrates the external table creation process"""

    BUCKET_CONFIG = {
        'events': ('dwh', 'events/'),
    }

    def __init__(self, s3_client: S3ClientManager, gp_conn, entity_type: str, s3_mask: str, stg_owner: str,
                 postfix: str = '', partition_suffix: str = ''):
        self.s3_client = s3_client
        self.gp_conn = gp_conn
        self.entity_type = entity_type.lower()
        self.s3_mask = s3_mask.lower()
        self.postfix = postfix
        self.partition_suffix = partition_suffix
        self.stg_owner = stg_owner
        if self.entity_type not in self.BUCKET_CONFIG:
            raise ValueError(f'Invalid entity type: {entity_type}')
        self.bucket, self.base_folder = self.BUCKET_CONFIG[self.entity_type]

    def create_external_tables(self):
        """Main method to create tables for all matching entities"""
        for entity_name in self._get_entity_names():
            self._process_entity(entity_name)

    def _get_entity_names(self) -> List[str]:
        """Determine entities to process based on S3 mask"""
        if self.s3_mask == 'all':
            return list(self.s3_client.get_s3_folders(self.bucket, self.base_folder))
        return [self.s3_mask]

    def _process_entity(self, entity_name: str):
        """Process individual entity table creation"""
        table_name = f"{entity_name.lower()}{self.postfix}"
        logging.info('Processing table: %s', table_name)
        try:
            # Get latest Parquet file
            latest_file = self._get_latest_parquet_file(entity_name)
            if not latest_file:
                return
            # Generate SQL definitions
            columns_sql = self._read_parquet_schema(latest_file)
            create_sql = self._generate_create_sql(columns_sql, entity_name, table_name)
            # Execute SQL statements
            with self.gp_conn.cursor() as cursor:
                cursor.execute(f'DROP EXTERNAL TABLE IF EXISTS stg.{table_name}')
                cursor.execute(create_sql)
                cursor.execute(f'ALTER TABLE stg.{table_name} OWNER TO {self.stg_owner};')
            self.gp_conn.commit()
            logging.info('Successfully created table: %s', table_name)
        except Exception as e:
            logging.error('Failed to process entity %s: %s', entity_name, e)
            self.gp_conn.rollback()
            raise

    def _get_latest_parquet_file(self, entity_name: str) -> ObjectTypeDef:
        """Find most recent Parquet file for entity"""
        objects = self.s3_client.list_objects(
            self.bucket, f"{self.base_folder}{entity_name}/"
        )
        valid_objects = [obj for obj in objects if obj.get('Key') and obj.get('LastModified')]
        if not valid_objects:
            raise ValueError(f'No valid files found for mask: {entity_name}')
        return max(valid_objects, key=lambda x: x['LastModified'])  # type: ignore

    def _read_parquet_schema(self, s3_object: ObjectTypeDef) -> str:
        """Read schema from Parquet file on S3"""
        fs = S3FileSystem(
            access_key=self.s3_client.access_key,
            secret_key=self.s3_client.secret_key,
            endpoint_override=self.s3_client.endpoint,
            scheme="https"
        )
        parquet_path = f"{self.bucket}/{s3_object['Key']}"  # type: ignore
        try:
            parquet_file = parquet.ParquetFile(fs.open_input_file(parquet_path))
            return ParquetSchemaParser.generate_column_definitions(parquet_file.schema)
        except Exception as e:
            logging.error('Failed to read Parquet schema: %s', e)
            raise

    def _generate_create_sql(self, columns_sql: str, entity_name: str, table_name: str) -> str:
        """Generate CREATE EXTERNAL TABLE SQL statement"""
        params = [
            'PROFILE=s3:parquet',
            'COMPRESSION_CODEC=gz',
            f'accesskey={quote_plus(self.s3_client.access_key)}',
            f'secretkey={quote_plus(self.s3_client.secret_key)}',
            f'endpoint={self.s3_client.endpoint}'
        ]
        query = '&'.join(params)
        s3_path = f"{self.base_folder}{entity_name}/"
        if self.partition_suffix:
            s3_path += f"{self.partition_suffix}/"
        return f"""
            CREATE EXTERNAL TABLE stg.{table_name} (
                {columns_sql}
            )
            LOCATION (
                'pxf://{self.bucket}/{s3_path}*?{query}'
            )
            ON ALL
            FORMAT 'CUSTOM' (FORMATTER='pxfwritable_import')
            ENCODING 'UTF8';"""
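
# Illustrative LOCATION URL produced by _generate_create_sql for entity 'orders'
# with no partition suffix (credentials elided):
#   pxf://dwh/events/orders/*?PROFILE=s3:parquet&COMPRESSION_CODEC=gz&accesskey=...&secretkey=...&endpoint=<s3-endpoint>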
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Create external tables in Greenplum from S3 Parquet files')
parser.add_argument('--entity-type', required=True, choices=['events'], help='Type of entity to process')
parser.add_argument('--s3-endpoint', required=True, help='S3 endpoint')
parser.add_argument('--s3-mask', required=True, help='S3 object name or "all" to process all tables')
parser.add_argument('--aws-key', required=True, help='AWS access key ID')
parser.add_argument('--aws-secret', required=True, help='AWS secret access key')
parser.add_argument('--gp-user', required=True, help='Greenplum username')
parser.add_argument('--gp-password', required=True, help='Greenplum password')
parser.add_argument('--gp-dbname', required=True, help='Greenplum database name')
parser.add_argument('--gp-host', required=True, help='Greenplum database host')
parser.add_argument('--gp-stg-owner', required=True, help='Greenplum stg table owner')
args = parser.parse_args()
try:
with (S3ClientManager(
access_key=args.aws_key,
secret_key=args.aws_secret,
endpoint=args.s3_endpoint,
) as s3_client,
GreenplumConnection(
user=args.gp_user,
password=args.gp_password,
dbname=args.gp_dbname,
host=args.gp_host,
) as gp_conn):
creator = ExternalTableCreator(
s3_client=s3_client,
gp_conn=gp_conn,
entity_type=args.entity_type,
s3_mask=args.s3_mask,
stg_owner=args.stg_owner,
)
creator.create_external_tables()
except Exception as e:
logging.error('Fatal error occurred: %s', e)
sys.exit(1)