-
-
Save xydinesh/9672836 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re
import warnings

from sqlalchemy import MetaData
import sqlalchemy.types as sqltypes
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.dialects import registry
from sqlalchemy.dialects.postgresql.base import (
    PGDialect, PGTypeCompiler, PGCompiler, PGDDLCompiler, DOUBLE_PRECISION,
    INTERVAL, TIME, TIMESTAMP)
from sqlalchemy.engine import reflection
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import Select, text, bindparam
# pylint:disable=R0901,W0212 | |
class ST_GEOMETRY(sqltypes.Binary):
    """Netezza ST_GEOMETRY column type, modelled as a binary type.

    DDL rendering is handled by the type compiler's ``visit_ST_GEOMETRY``.
    """

    __visit_name__ = 'ST_GEOMETRY'
class BPCHAR(sqltypes.CHAR):
    """Netezza BPCHAR (blank-padded CHAR) column type."""

    __visit_name__ = 'BPCHAR'
class BYTEINT(sqltypes.INTEGER):
    """Netezza BYTEINT column type, treated as an INTEGER subtype."""

    __visit_name__ = 'BYTEINT'
# Extra type names gleaned from _v_datatype that the PostgreSQL dialect does
# not know about; NetezzaODBC.initialize() merges these into the dialect's
# ischema_names so reflection can resolve them.
ischema_names = dict(
    bpchar=BPCHAR,
    st_geometry=ST_GEOMETRY,
    byteint=BYTEINT,
)
class NetezzaTypeCompiler(PGTypeCompiler):
    '''Fills out unique netezza types.

    Each ``visit_*`` method accepts and ignores ``**kw`` so it keeps working
    when SQLAlchemy forwards extra keyword arguments (such as
    ``type_expression``) to type-compiler visit methods.
    '''

    def visit_BPCHAR(self, _type, **kw):
        '''Render the BPCHAR type name (no length is emitted).'''
        return 'BPCHAR'

    def visit_ST_GEOMETRY(self, type_, **kw):
        '''Render ``ST_GEOMETRY(<length>)`` using the type's ``length``.'''
        # NOTE(review): a type created without a length renders as
        # "ST_GEOMETRY(None)"; confirm whether Netezza accepts a bare
        # ST_GEOMETRY before changing this.
        return 'ST_GEOMETRY({})'.format(type_.length)

    def visit_BYTEINT(self, _type, **kw):
        '''Render the BYTEINT type name.'''
        return 'BYTEINT'
class NetezzaCompiler(PGCompiler):
    '''Handles some quirks of netezza queries'''

    def limit_clause(self, select, **kw):
        '''Render LIMIT/OFFSET with the values inlined as integer literals.

        Netezza doesn't allow sql params in the limit/offset piece, so the
        values are coerced with ``int()`` and formatted directly into the
        SQL. An OFFSET without a LIMIT is emitted as ``LIMIT ALL OFFSET n``.
        ``**kw`` absorbs compiler keyword arguments that newer SQLAlchemy
        releases pass to ``limit_clause``.
        '''
        # Named "clause" rather than "text" so it doesn't shadow the
        # module-level sqlalchemy.sql.text import.
        clause = ""
        if select._limit is not None:
            clause += " \n LIMIT {limit}".format(limit=int(select._limit))
        if select._offset is not None:
            if select._limit is None:
                clause += " \n LIMIT ALL"
            clause += " OFFSET {offset}".format(offset=int(select._offset))
        return clause
class DistributeOn(SchemaItem):
    '''Represents a Netezza ``DISTRIBUTE ON`` clause for a table.

    Use like::

        my_table_1 = Table('my_table_1', metadata,
                           Column('id_key', BIGINT),
                           Column('nbr', BIGINT),
                           DistributeOn('id_key')
                           )
        my_table_2 = Table('my_table_2', metadata,
                           Column('id_key', BIGINT),
                           Column('nbr', BIGINT),
                           DistributeOn('random')
                           )
    '''

    def __init__(self, *column_names):
        '''Store the distribution column names.

        Calling with no names is equivalent to ``DistributeOn('RANDOM')``.
        '''
        self.column_names = column_names if column_names else ('RANDOM',)

    def _set_parent(self, parent, **kw):
        '''Attach to ``parent`` and expose this item as
        ``parent.distribute_on`` for the DDL compiler to find.

        Extra keyword arguments are accepted and ignored so the hook stays
        compatible with SQLAlchemy releases that pass them when attaching
        schema items (presumably 1.4+; verify against the installed version).
        '''
        self.parent = parent
        parent.distribute_on = self
class CreateTableAs(Select):
    """A SELECT that compiles to a ``CREATE [TEMPORARY] TABLE ... AS SELECT``.

    Idea from: http://stackoverflow.com/a/19054719/180718

    ``columns`` plus any extra positional/keyword arguments are handed to
    :class:`Select` unchanged; ``new_table_name`` and ``is_temporary`` are
    only recorded for the compile hook that rewrites the statement.
    """

    def __init__(
            self, columns, new_table_name, is_temporary=False, *arg, **kw):
        super(CreateTableAs, self).__init__(columns, *arg, **kw)
        # CTAS-specific settings read back by the @compiles handler.
        self.new_table_name = new_table_name
        self.is_temporary = is_temporary
@compiles(CreateTableAs)
def s_create_table_as(element, compiler, **kwargs):
    '''Compile a CreateTableAs into a CTAS statement.

    The element is compiled as an ordinary SELECT, then its leading
    ``SELECT`` keyword is rewritten to
    ``CREATE [TEMPORARY] TABLE <new_table_name> AS SELECT``.
    Compile keyword arguments are forwarded to ``visit_select``.
    '''
    select_sql = compiler.visit_select(element, **kwargs)
    spec = ['CREATE', 'TABLE', element.new_table_name, 'AS SELECT']
    if element.is_temporary:
        spec.insert(1, 'TEMPORARY')
    # Rewrite only the first occurrence: subqueries, UNION branches and
    # string literals containing "SELECT" must be left intact.
    # NOTE(review): a statement that starts with a CTE ("WITH ...") would
    # still be rewritten inside the CTE; not handled here.
    return select_sql.replace('SELECT', ' '.join(spec), 1)
class NetezzaDDLCompiler(PGDDLCompiler):
    '''Adds Netezza specific DDL clauses'''

    def post_create_table(self, table):
        '''Append a ``DISTRIBUTE ON`` clause to CREATE TABLE statements.

        A table without a DistributeOn item, or whose first distribution
        column is ``random`` (any case), gets ``DISTRIBUTE ON RANDOM``.
        Otherwise the listed columns are emitted as a parenthesised,
        comma-separated list.
        '''
        distribute_on = getattr(table, 'distribute_on', None)
        if distribute_on is None or \
                distribute_on.column_names[0].lower() == 'random':
            return ' DISTRIBUTE ON RANDOM'
        # Quote names through the identifier preparer, the same way column
        # definitions in the CREATE TABLE body are quoted, so mixed-case or
        # reserved-word column names refer to the same column here.
        columns = ','.join(
            self.preparer.quote(name) for name in distribute_on.column_names)
        return ' DISTRIBUTE ON ({})'.format(columns)
# Maps the type id reported in _v_relation_column.atttypid to a pair of
# (SQLAlchemy type class, takes_length). When takes_length is True,
# NetezzaODBC.get_columns passes the reflected column length to the class.
oid_datatype_map = {
    # boolean and fixed char
    16: (sqltypes.Boolean, False),
    18: (sqltypes.CHAR, False),
    # integers
    20: (sqltypes.BigInteger, False),
    21: (sqltypes.SmallInteger, False),
    23: (sqltypes.Integer, False),
    # floating point
    700: (sqltypes.REAL, False),
    701: (DOUBLE_PRECISION, False),
    # character strings
    1042: (BPCHAR, True),
    1043: (sqltypes.String, True),
    # dates, times and intervals
    1082: (sqltypes.Date, False),
    1083: (TIME, False),
    1184: (TIMESTAMP, False),
    1186: (INTERVAL, False),
    # NOTE(review): in PostgreSQL oid 1266 is TIME WITH TIME ZONE rather
    # than a timestamp; confirm what Netezza reports it for.
    1266: (TIMESTAMP, False),
    # exact numeric; get_columns parses precision/scale from format_type
    1700: (sqltypes.Numeric, False),
    # ids from 2500 up (presumably Netezza extensions)
    2500: (sqltypes.SmallInteger, False),
    2522: (sqltypes.NCHAR, True),
    2530: (sqltypes.NVARCHAR, True),
    2552: (ST_GEOMETRY, True),
    2568: (sqltypes.VARBINARY, True),
}
class NetezzaODBC(PyODBCConnector, PGDialect):
    '''Netezza dialect over pyodbc.

    Attempts to reuse as much as possible from the postgresql and pyodbc
    dialects, overriding type/DDL/LIMIT compilation and the catalog queries
    used for reflection.
    '''
    name = 'netezza'
    encoding = 'latin9'
    default_paramstyle = 'qmark'
    returns_unicode_strings = False
    supports_native_enum = False
    supports_sequences = True
    sequences_optional = False
    isolation_level = 'READ COMMITTED'
    max_identifier_length = 63
    type_compiler = NetezzaTypeCompiler
    statement_compiler = NetezzaCompiler
    ddl_compiler = NetezzaDDLCompiler

    # Extracts (precision, scale) from a NUMERIC column's format_type, e.g.
    # "numeric(18,2)". Case-insensitive and whitespace-tolerant; a value
    # that still doesn't match degrades to a plain Numeric() instead of
    # crashing reflection.
    _numeric_format_re = re.compile(
        r'\s*numeric\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)', re.IGNORECASE)

    def initialize(self, connection):
        '''Run the base initialization, then force Netezza's encoding and
        unicode settings and add the Netezza-specific reflection names.'''
        super(NetezzaODBC, self).initialize(connection)
        # PyODBC connector tries to set these to true...
        self.supports_unicode_statements = False
        self.supports_unicode_binds = False
        # NOTE(review): this contradicts the class-level
        # returns_unicode_strings = False and _check_unicode_returns'
        # docstring; kept unchanged, confirm which is intended.
        self.returns_unicode_strings = True
        self.convert_unicode = True
        self.encoding = 'latin9'
        # This class does not define ischema_names, so self.ischema_names
        # resolves to an ancestor's class-level dict (PGDialect's), shared by
        # every dialect that inherits it. Merge into a per-instance copy
        # instead of mutating that shared dict in place.
        merged_names = dict(self.ischema_names)
        merged_names.update(ischema_names)
        self.ischema_names = merged_names

    def has_table(self, connection, tablename, schema=None):
        '''Return True if ``tablename`` appears in _v_table.

        ``schema`` is accepted for API compatibility but ignored.
        '''
        result = connection.execute(
            "select count(*) from _v_table "
            "where tablename = ?",
            (tablename,)
        ).scalar()
        return bool(result)

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        '''List table names from _v_table, skipping names prefixed "_t_".'''
        # Compare the literal 3-character prefix: in LIKE, "_" is a
        # single-character wildcard, so "not like '_t_%'" also hid every
        # table whose second character is "t".
        result = connection.execute(
            "select tablename as name from _v_table "
            "where substr(tablename, 1, 3) <> '_t_'")
        return [r[0] for r in result]

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        '''Reflect the columns of ``table_name`` from _v_relation_column.

        Returns a list of dicts with ``name``, ``type`` and ``nullable``
        keys, in attnum order. ``schema`` is ignored.
        '''
        SQL_COLS = """
            SELECT CAST(a.attname AS VARCHAR(64)) as name,
                   a.atttypid as typeid,
                   not a.attnotnull as nullable,
                   a.attcolleng as length,
                   a.format_type
            FROM _v_relation_column a
            WHERE a.name = :tablename
            ORDER BY a.attnum
        """
        s = text(SQL_COLS,
                 bindparams=[bindparam('tablename', type_=sqltypes.String)],
                 typemap={'name': sqltypes.String,
                          'typeid': sqltypes.Integer,
                          'nullable': sqltypes.Boolean,
                          'length': sqltypes.Integer,
                          'format_type': sqltypes.String,
                          })
        c = connection.execute(s, tablename=table_name)
        rows = c.fetchall()
        # format columns
        return [
            {
                'name': name,
                'type': self._reflect_coltype(
                    name, typeid, length, format_type),
                'nullable': nullable,
            }
            for name, typeid, nullable, length, format_type in rows
        ]

    def _reflect_coltype(self, name, typeid, length, format_type):
        '''Build the SQLAlchemy type instance for one reflected column.

        Unknown type ids produce a warning and a NullType rather than
        aborting reflection of the whole table.
        '''
        try:
            coltype_class, has_length = oid_datatype_map[typeid]
        except KeyError:
            warnings.warn(
                "Did not recognize type id %r of column %r" % (typeid, name))
            return sqltypes.NullType()
        if coltype_class is sqltypes.Numeric:
            match = self._numeric_format_re.match(format_type or '')
            if match is None:
                return coltype_class()
            precision, scale = match.groups()
            return coltype_class(int(precision), int(scale))
        if has_length:
            return coltype_class(length)
        return coltype_class()

    @reflection.cache
    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        '''Netezza doesn't have PK/unique constraints'''
        return {'constrained_columns': [], 'name': None}

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        '''Netezza doesn't have foreign keys'''
        return []

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
        '''Netezza doesn't have indexes'''
        return []

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        '''List view names from _v_view, skipping names prefixed "_v_".'''
        # The original concatenation lacked a space ("_v_viewwhere"), which
        # is invalid SQL. The prefix test avoids LIKE's "_" wildcard for the
        # same reason as in get_table_names.
        result = connection.execute(
            "select viewname as name from _v_view "
            "where substr(viewname, 1, 3) <> '_v_'")
        return [r[0] for r in result]

    def get_isolation_level(self, connection):
        '''Return the class-configured isolation level (not queried).'''
        return self.isolation_level

    def _get_default_schema_name(self, connection):
        '''Netezza doesn't use schemas'''
        # NOTE(review): presumably relies on the base initialize() treating
        # NotImplementedError as "no default schema"; confirm for the
        # SQLAlchemy version in use.
        raise NotImplementedError

    def _check_unicode_returns(self, connection):
        '''Netezza doesn't *do* unicode (except in nchar & nvarchar)'''
        pass
# Register NetezzaODBC under "netezza" and "netezza.pyodbc" so both the
# netezza:// and netezza+pyodbc:// URL forms load this dialect.
for _dialect_name in ("netezza", "netezza.pyodbc"):
    registry.register(
        _dialect_name, "cameopaas.db.netezza_dialect", "NetezzaODBC")
del _dialect_name
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment