@adriaanslechten
Created March 31, 2021 10:37
glue_dathub_source
from functools import lru_cache
from typing import Any, List, Mapping, Union

import boto3
from botocore import client


@lru_cache(maxsize=1)
def get_glue_client() -> client:
    """Instantiates a client."""
    return boto3.client("glue", region_name="eu-west-1")


def get_glue_tables(
    db_name: str,
    catalog_id: str,
    pagination_config: Mapping[str, Union[int, str]] = {"PageSize": 1000},
    glue_client: client = get_glue_client(),
) -> List[Mapping[str, Any]]:
    """Returns all the tables given a DB."""
    response_iterator = glue_client.get_paginator("get_tables").paginate(
        CatalogId=catalog_id, DatabaseName=db_name, PaginationConfig=pagination_config
    )
    return [table for response_page in response_iterator for table in response_page["TableList"]]


def get_glue_databases(
    catalog_id: str,
    pagination_config: Mapping[str, Union[int, str]] = {"PageSize": 1000},
    glue_client: client = get_glue_client(),
) -> List[Mapping[str, Any]]:
    """Gets the Glue databases.

    Take note: this can be a pretty heavy call if MaxItems is set too high.
    """
    response_iterator = glue_client.get_paginator("get_databases").paginate(
        CatalogId=catalog_id, PaginationConfig=pagination_config
    )
    return [database for response_page in response_iterator for database in response_page["DatabaseList"]]
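

# Example usage -- an illustrative sketch only: it assumes AWS credentials for
# eu-west-1 are configured and uses a placeholder catalog ID; it is not part of
# the helper module's interface.
if __name__ == "__main__":
    CATALOG_ID = "123456789012"  # placeholder account / catalog ID
    for database in get_glue_databases(catalog_id=CATALOG_ID):
        tables = get_glue_tables(db_name=database["Name"], catalog_id=CATALOG_ID)
        print(f"{database['Name']}: {len(tables)} tables")

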
"""Glue source module."""

import logging
import time
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Optional

from datahub.configuration.common import AllowDenyPattern, ConfigModel
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.source import Source, SourceReport
from datahub.ingestion.source.metadata_common import MetadataWorkUnit
from datahub.ingestion.source.sql_common import SQLSourceReport
from datahub.metadata import BooleanTypeClass, NullTypeClass, NumberTypeClass, StringTypeClass
from datahub.metadata.com.linkedin.pegasus2avro.common import AuditStamp
from datahub.metadata.com.linkedin.pegasus2avro.dataset import DatasetProperties
from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField, SchemaFieldDataType, SchemaMetadata

# get_glue_databases and get_glue_tables are imported from the Glue helper module
# above in the original project (the exact import path is not shown in this gist).

LOG = logging.getLogger(__name__)


@dataclass(frozen=True)
class GlueStorageDescriptor:
    """Storage descriptor."""

    columns: List[Mapping[str, Any]]
    location: Optional[str] = None
    compressed: Optional[bool] = None
    number_of_buckets: Optional[int] = None
    serde_info: Optional[str] = None
    sort_columns: Optional[str] = None
    stored_as_sub_directories: Optional[bool] = None

    @property
    def storage_aspect(self) -> Dict[str, Any]:
        """Returns the non-column storage fields, stringified, so they can be
        ingested as extra custom properties."""
        return {k: str(v) for k, v in asdict(self).items() if v is not None and k != "columns"}


@dataclass(frozen=True)
class GlueData:
    """Data holder."""

    name: str
    database_name: str
    create_time: str
    update_time: str
    retention: str
    partition_keys: str
    table_type: str
    catalog_id: str
    parameters: Mapping
    storage_descriptor: GlueStorageDescriptor
    is_registered_with_lake_formation: Optional[bool] = None

    @property
    def dataset_name(self) -> str:
        """Dataset name getter."""
        return f"{self.database_name}.{self.name}"

    @classmethod
    def from_dict(cls, input_dict: Mapping[str, Any]) -> "GlueData":
        """Factory method."""
        descriptor = input_dict["StorageDescriptor"]
        storage_descriptor = GlueStorageDescriptor(
            columns=descriptor["Columns"],
            location=descriptor.get("Location"),
            compressed=descriptor.get("Compressed"),
            number_of_buckets=descriptor.get("NumberOfBuckets"),
            sort_columns=descriptor.get("SortColumns"),
            stored_as_sub_directories=descriptor.get("StoredAsSubDirectories"),
            serde_info=descriptor.get("SerdeInfo"),
        )
        return cls(
            name=input_dict["Name"],
            database_name=input_dict["DatabaseName"],
            create_time=input_dict.get("CreateTime", ""),
            update_time=input_dict.get("UpdateTime", ""),
            retention=input_dict.get("Retention", ""),
            storage_descriptor=storage_descriptor,
            partition_keys=input_dict.get("PartitionKeys", ""),
            table_type=input_dict.get("TableType", ""),
            is_registered_with_lake_formation=input_dict.get("IsRegisteredWithLakeFormation"),
            catalog_id=input_dict.get("CatalogId", ""),
            parameters=input_dict.get("Parameters", {}),
        )

    @property
    def properties_aspect(self) -> Dict[str, Any]:
        """Returns the table fields and Glue parameters, stringified, so they can be
        ingested as extra custom properties."""
        aspects = {**asdict(self), **self.parameters}
        return {
            k: str(v)
            for k, v in aspects.items()
            if v is not None and v != "" and k not in ["storage_descriptor", "parameters"]
        }


class GlueConfig(ConfigModel):
    """Glue config."""

    catalog_id: str
    schema_pattern: AllowDenyPattern = AllowDenyPattern.allow_all()
    table_pattern: AllowDenyPattern = AllowDenyPattern.allow_all()
    user: str = "urn:li:corpuser:etl"
    platform_urn: str = "urn:li:dataPlatform:glue"
    version: int = 0
    fabric_type: str = "DEV"


@dataclass
class GlueSource(Source):
    """Glue source class."""

    config: GlueConfig
    report: SourceReport = field(default_factory=SQLSourceReport)

    @classmethod
    def create(cls, input_config: Dict[str, Any], ctx: PipelineContext) -> "GlueSource":
        """Factory method."""
        return cls(ctx, GlueConfig(**input_config))

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        """Standard work method: loops over databases and tables and creates a
        workunit for each allowed table."""
        databases = get_glue_databases(catalog_id=self.config.catalog_id, pagination_config={"PageSize": 1000})
        for database in databases:
            if not self.config.schema_pattern.allowed(database["Name"]):
                self.report.report_dropped(database["Name"])
                continue
            tables = get_glue_tables(
                db_name=database["Name"],
                catalog_id=self.config.catalog_id,
                pagination_config={"PageSize": 1000},
            )
            for table in tables:
                if self.allowed_and_validated_table(table):
                    glue_data = GlueData.from_dict(table)
                    schema_metadata = get_schema_metadata(self, glue_data, self.config)
                    urn = get_dataset_urn(
                        self.config.platform_urn, glue_data.database_name, glue_data.name, self.config.fabric_type
                    )
                    glue_properties_aspect = get_properties_aspect(glue_data)
                    dataset_snapshot = DatasetSnapshot(urn=urn, aspects=[schema_metadata, glue_properties_aspect])
                    mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
                    work_unit = MetadataWorkUnit(urn, mce)
                    self.report.report_table_scanned(glue_data.name)
                    self.report.report_workunit(work_unit)
                    yield work_unit

    def allowed_and_validated_table(self, table: Mapping[str, Any]) -> bool:
        """Assorted validation checks; some could be moved to Pydantic."""
        if not self.config.table_pattern.allowed(table["Name"]):
            self.report.report_dropped(table["Name"])
            return False
        if "StorageDescriptor" not in table:
            self.report.report_failure(table["Name"], "Could not find StorageDescriptor in table.")
            return False
        return True

    def get_report(self) -> SourceReport:
        """Obligatory getter."""
        return self.report


def get_schema_metadata(glue_source: GlueSource, glue_data: GlueData, glue_config: GlueConfig) -> SchemaMetadata:
    """Builds the schema metadata aspect for a table."""
    sys_time = int(time.time() * 1000)
    return SchemaMetadata(
        schemaName=glue_data.dataset_name,
        platform=glue_config.platform_urn,
        version=glue_config.version,
        created=AuditStamp(time=sys_time, actor=glue_config.user),
        lastModified=AuditStamp(time=sys_time, actor=glue_config.user),
        fields=get_schema_fields(glue_source, glue_data),
    )


def get_schema_fields(glue_source: GlueSource, glue_data: GlueData) -> List[SchemaField]:
    """Returns a list containing column information."""
    return [
        SchemaField(
            fieldPath=column["Name"],
            nativeDataType=column["Type"],
            type=get_column_type(glue_source, glue_data, column),
            recursive=False,
        )
        for column in glue_data.storage_descriptor.columns
    ]


def get_column_type(glue_source: GlueSource, glue_data: GlueData, column: Mapping[str, str]) -> SchemaFieldDataType:
    """Returns the column types.

    This probably should be expanded; potentially the extra type mappings could be
    taken from the config."""
    column_type = column["Type"].lower()
    if column_type not in GLUE_DATAHUB_TYPE_MAPPING:
        glue_source.report.report_warning(
            glue_data.dataset_name,
            f"Warning: column: {column['Name']}. Unable to map Glue type: {column_type} to a DataHub native type.",
        )
    return SchemaFieldDataType(type=GLUE_DATAHUB_TYPE_MAPPING.get(column_type, NullTypeClass()))


GLUE_DATAHUB_TYPE_MAPPING = {
    "int": NumberTypeClass(),
    "bigint": NumberTypeClass(),
    "string": StringTypeClass(),
    "str": StringTypeClass(),
    "bool": BooleanTypeClass(),
    "double": NumberTypeClass(),
    "float": NumberTypeClass(),
}


def get_properties_aspect(glue_data: GlueData) -> DatasetProperties:
    """Returns a merged properties aspect."""
    return DatasetProperties(
        description="Glue properties",
        customProperties={**glue_data.properties_aspect, **glue_data.storage_descriptor.storage_aspect},
    )


def get_dataset_urn(platform_urn: str, database_name: str, table_name: str, fabric_type: str) -> str:
    """Returns the general URN for a given platform, db name and table name."""
    return f"urn:li:dataset:({platform_urn},{database_name}.{table_name},{fabric_type})"
import datetime
from pathlib import Path
from unittest.mock import patch

import pytest
from dateutil.tz import tzlocal

from ingestion.model.glue_model import GlueData, GlueStorageDescriptor
from ingestion.source.glue_source import GlueSource, get_schema_fields


def test_get_schema_fields():
    out = get_schema_fields(
        GlueSource(config={}, ctx="test"),
        GlueData(
            name="test_table",
            database_name="test_db",
            create_time=datetime.datetime(2020, 1, 5, 16, 40, 26, tzinfo=tzlocal()),
            update_time=datetime.datetime(2020, 1, 5, 16, 58, 37, tzinfo=tzlocal()),
            retention=0,
            storage_descriptor=GlueStorageDescriptor(
                columns=[
                    {"Name": "event_id", "Type": "string"},
                    {"Name": "payload", "Type": "string"},
                ],
                location="",
            ),
            partition_keys=[],
            table_type="VIRTUAL_VIEW",
            is_registered_with_lake_formation=False,
            catalog_id="922587933573",
            parameters={"comment": "Presto View", "presto_view": "true"},
        ),
    )
    assert out == [
        {
            "fieldPath": "event_id",
            "jsonPath": None,
            "nullable": False,
            "description": None,
            "type": {"type": {}},
            "nativeDataType": "string",
            "recursive": False,
            "globalTags": None,
        },
        {
            "fieldPath": "payload",
            "jsonPath": None,
            "nullable": False,
            "description": None,
            "type": {"type": {}},
            "nativeDataType": "string",
            "recursive": False,
            "globalTags": None,
        },
    ]


# output from glue tables.
@pytest.mark.parametrize(
    "input, expected",
    [
        (
            {
                "Name": "test_table",
                "DatabaseName": "test_db",
                "CreateTime": datetime.datetime(2020, 1, 5, 16, 40, 26, tzinfo=tzlocal()),
                "UpdateTime": datetime.datetime(2020, 1, 5, 16, 58, 37, tzinfo=tzlocal()),
                "Retention": 0,
                "StorageDescriptor": {
                    "Columns": [
                        {"Name": "event_id", "Type": "string"},
                        {"Name": "payload", "Type": "string"},
                    ],
                    "Location": "",
                    "Compressed": False,
                    "NumberOfBuckets": 0,
                    "SerdeInfo": {},
                    "SortColumns": [],
                    "StoredAsSubDirectories": False,
                },
                "PartitionKeys": [],
                "ViewOriginalText": "/* Presto View: randomssha */",
                "ViewExpandedText": "/* Presto View */",
                "TableType": "VIRTUAL_VIEW",
                "Parameters": {"comment": "Presto View", "presto_view": "true"},
                "IsRegisteredWithLakeFormation": False,
                "CatalogId": "922587933573",
            },
            GlueData(
                name="test_table",
                database_name="test_db",
                create_time=datetime.datetime(2020, 1, 5, 16, 40, 26, tzinfo=tzlocal()),
                update_time=datetime.datetime(2020, 1, 5, 16, 58, 37, tzinfo=tzlocal()),
                retention=0,
                storage_descriptor=GlueStorageDescriptor(
                    columns=[
                        {"Name": "event_id", "Type": "string"},
                        {"Name": "payload", "Type": "string"},
                    ],
                    location="",
                    compressed=False,
                    number_of_buckets=0,
                    serde_info={},
                    sort_columns=[],
                    stored_as_sub_directories=False,
                ),
                partition_keys=[],
                table_type="VIRTUAL_VIEW",
                is_registered_with_lake_formation=False,
                catalog_id="922587933573",
                parameters={"comment": "Presto View", "presto_view": "true"},
            ),
        ),
    ],
)
def test_glue_model(input, expected):
    glue_data = GlueData.from_dict(input)
    assert glue_data == expected
    assert glue_data.name == "test_table"
    assert glue_data.storage_descriptor.columns == [
        {"Name": "event_id", "Type": "string"},
        {"Name": "payload", "Type": "string"},
    ]


def test_storage_descriptor():
    descriptor = GlueStorageDescriptor(
        columns=[
            {"Name": "event_id", "Type": "string"},
            {"Name": "payload", "Type": "string"},
        ],
        location="testloc",
        compressed=False,
        number_of_buckets=0,
        stored_as_sub_directories=False,
    )
    assert descriptor.location == "testloc"
    assert descriptor.storage_aspect == {
        "location": "testloc",
        "compressed": "False",
        "number_of_buckets": "0",
        "stored_as_sub_directories": "False",
    }


def test_glue_console(glue_db, glue_table):
    with patch(
        "ingestion.source.glue_source.get_glue_databases",
        lambda catalog_id, pagination_config: glue_db,
    ), patch(
        "ingestion.source.glue_source.get_glue_tables",
        lambda db_name, catalog_id, pagination_config: glue_table,
    ):
        out = klarna_custom_ingestion(Path("test/fixtures/glue-console.yml"))
        assert out.workunits_produced == 1  # We have one correct workunit.
        assert out.filtered[0] == "not_allowed_db"  # One which we filter.
        assert "failed_table" in out.failures  # One which fails.