Dask DataFrame read_sql_table using sqlalchemy reflection to detect column types
from dask.dataframe import read_sql_table
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, schema
from config import database_config


# Copied from pandas with modifications
def _get_dtype(column, sqltype):
    from sqlalchemy.types import (Integer, Float, Boolean, DateTime,
                                  Date, TIMESTAMP)
    if isinstance(sqltype, Float):
        return float
    elif isinstance(sqltype, Integer):
        if column.nullable:
            # Nullable integers are read as floats so NULLs become NaN
            return float
        # TODO: Refine integer size.
        return np.dtype('int64')
    elif isinstance(sqltype, TIMESTAMP):
        # we have a timezone capable type
        if not sqltype.timezone:
            return np.dtype('datetime64[ns]')
        # Reflection only reports *whether* the column is tz-aware, not the
        # zone itself, so UTC is assumed here
        return pd.DatetimeTZDtype(tz='UTC')
    elif isinstance(sqltype, DateTime):
        # Caution: np.datetime64 is also a subclass of np.number.
        return np.dtype('datetime64[ns]')
    elif isinstance(sqltype, Date):
        # Treat DATE columns as datetimes in the meta frame
        return np.dtype('datetime64[ns]')
    elif isinstance(sqltype, Boolean):
        return bool
    return object


def database_table_request(db_type: str, db_server: str, database: str,
                           table: str, index_col: str = None,
                           npartitions: int = 1):
    db_dialect = database_config.engine(db_type)
    db_username = database_config.username(db_type)
    db_password = database_config.password(db_type)
    # Get database schema using sqlalchemy reflection
    db_uri = f'{db_dialect}://{db_username}:{db_password}@{db_server}/{database}'
    db_engine = create_engine(db_uri)
    db_metadata = schema.MetaData()
    db_metadata.reflect(bind=db_engine)
    db_tables = {k.lower(): v for k, v in db_metadata.tables.items()}
    db_table = db_tables[table.lower()]
    # Use the primary key as the index column if one hasn't been passed
    for column in db_table.columns:
        if column.primary_key and index_col is None:
            index_col = column.name
    if index_col is None:
        raise ValueError(f'No index_col given and no primary key found on {table}')
    # Now that we have an index column, build an empty pandas DataFrame
    # with the reflected dtypes for Dask's meta argument
    pd_df = pd.DataFrame(index=None)
    for column in db_table.columns:
        if column.name != index_col:
            pd_df[column.name] = pd.Series(
                dtype=_get_dtype(column, column.type))
    # Let Dask build and execute the partitioned query
    df = read_sql_table(db_table.name, db_uri, index_col,
                        meta=pd_df, npartitions=npartitions)
    return df
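A usage sketch follows. Everything here is illustrative: the 'mysql' db_type, server, database, table and column names are placeholders, and config.database_config is assumed to resolve the driver and credentials as in the file above.

# Hypothetical call; index_col is inferred from the table's primary key
df = database_table_request(db_type='mysql',
                            db_server='db.example.com',
                            database='sales',
                            table='orders',
                            npartitions=8)
print(df.dtypes)  # dtypes come from the reflected column types, not from sampling rows
totals = df.groupby('status').amount.sum().compute()  # 'status' and 'amount' are hypothetical columns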