Skip to content

Instantly share code, notes, and snippets.

@xydinesh
Forked from deontologician/netezza_dialect.py
Created March 20, 2014 20:13
Show Gist options
  • Save xydinesh/9672836 to your computer and use it in GitHub Desktop.
Save xydinesh/9672836 to your computer and use it in GitHub Desktop.
import functools
import re

import sqlalchemy.types as sqltypes
from sqlalchemy import MetaData
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.dialects import registry
from sqlalchemy.dialects.postgresql.base import (
    PGDialect, PGTypeCompiler, PGCompiler, PGDDLCompiler, DOUBLE_PRECISION,
    INTERVAL, TIME, TIMESTAMP)
from sqlalchemy.engine import reflection
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import Select, text, bindparam
# pylint:disable=R0901,W0212
class ST_GEOMETRY(sqltypes.Binary):
    '''Netezza spatial geometry column type (binary storage with a length).'''

    __visit_name__ = 'ST_GEOMETRY'
class BPCHAR(sqltypes.CHAR):
    '''Netezza's blank-padded character type, reported by the catalog as
    ``bpchar``.'''

    __visit_name__ = 'BPCHAR'
class BYTEINT(sqltypes.INTEGER):
    '''Netezza's ``BYTEINT`` integer type.'''

    __visit_name__ = 'BYTEINT'
# Netezza-specific type names (gleaned from _v_datatype) mapped to the
# SQLAlchemy type classes declared in this module. NetezzaODBC.initialize
# merges these into the dialect's ischema_names.
ischema_names = dict(
    bpchar=BPCHAR,
    st_geometry=ST_GEOMETRY,
    byteint=BYTEINT,
)
class NetezzaTypeCompiler(PGTypeCompiler):
    '''Fills out unique netezza types.

    The ``visit_*`` methods accept ``**kw`` because newer SQLAlchemy
    versions pass extra keyword arguments to type-compiler visitors.
    '''

    def visit_BPCHAR(self, _type, **kw):
        '''Render BPCHAR (any length on the type is not emitted).'''
        return 'BPCHAR'

    def visit_ST_GEOMETRY(self, type_, **kw):
        '''Render ST_GEOMETRY, with ``(length)`` only when a length is set.

        Without this check a length-less type rendered ``ST_GEOMETRY(None)``.
        '''
        length = getattr(type_, 'length', None)
        if length is None:
            return 'ST_GEOMETRY'
        return 'ST_GEOMETRY({})'.format(length)

    def visit_BYTEINT(self, _type, **kw):
        '''Render BYTEINT.'''
        return 'BYTEINT'
class NetezzaCompiler(PGCompiler):
    '''Handles some quirks of netezza queries'''

    def limit_clause(self, select, **kw):
        '''Render LIMIT/OFFSET as inline integer literals.

        Netezza doesn't allow sql params in the limit/offset piece, so the
        values are formatted directly into the SQL (after ``int()`` coercion).
        An offset without a limit is rendered as ``LIMIT ALL OFFSET n``.
        ``**kw`` is accepted because SQLAlchemy >= 1.0 passes compiler
        keyword arguments here.
        '''
        limit, offset = select._limit, select._offset
        clause = ""
        if limit is not None:
            clause += " \n LIMIT {limit}".format(limit=int(limit))
        if offset is not None:
            if limit is None:
                clause += " \n LIMIT ALL"
            clause += " OFFSET {offset}".format(offset=int(offset))
        return clause
class DistributeOn(SchemaItem):
    '''Represents a ``DISTRIBUTE ON`` clause attached to a Table.

    Use like::

        my_table_1 = Table('my_table_1', metadata,
                           Column('id_key', BIGINT),
                           Column('nbr', BIGINT),
                           DistributeOn('id_key')
                           )
        my_table_2 = Table('my_table_2', metadata,
                           Column('id_key', BIGINT),
                           Column('nbr', BIGINT),
                           DistributeOn('random')
                           )

    With no column names the distribution defaults to ``('RANDOM',)``.
    '''

    def __init__(self, *column_names):
        self.column_names = column_names if column_names else ('RANDOM',)

    def _set_parent(self, parent, **kw):
        '''Attach this clause to *parent* as ``parent.distribute_on``.

        ``**kw`` is accepted because SQLAlchemy 1.4 passes extra keyword
        arguments (e.g. ``allow_replacements``) when attaching schema items.
        '''
        self.parent = parent
        parent.distribute_on = self
class CreateTableAs(Select):
    """A Select that compiles to ``CREATE [TEMPORARY] TABLE <name> AS SELECT``.

    *columns* and any extra arguments are forwarded to ``Select``;
    *new_table_name* and *is_temporary* are read by the compiler hook.
    """

    def __init__(
            self, columns, new_table_name, is_temporary=False, *arg, **kw):
        '''Idea from: http://stackoverflow.com/a/19054719/180718'''
        super(CreateTableAs, self).__init__(columns, *arg, **kw)
        # Consumed by the @compiles hook for this class.
        self.new_table_name = new_table_name
        self.is_temporary = is_temporary
@compiles(CreateTableAs)
def s_create_table_as(element, compiler, **kwargs):
    '''Compile a CreateTableAs into ``CREATE [TEMPORARY] TABLE name AS SELECT``.

    The element is compiled as an ordinary SELECT (forwarding the compiler
    keyword arguments such as ``literal_binds``). Then only the leading
    ``SELECT`` keyword is rewritten. Rewriting every occurrence would also
    corrupt subqueries.
    '''
    select_sql = compiler.visit_select(element, **kwargs)
    spec = ['CREATE', 'TABLE', element.new_table_name, 'AS SELECT']
    if element.is_temporary:
        spec.insert(1, 'TEMPORARY')
    return select_sql.replace('SELECT', ' '.join(spec), 1)
class NetezzaDDLCompiler(PGDDLCompiler):
    '''Adds Netezza specific DDL clauses'''

    def post_create_table(self, table):
        '''Return the ``DISTRIBUTE ON`` suffix for a CREATE TABLE statement.

        Uses the column names of the table's DistributeOn item. Falls back
        to ``RANDOM`` when the table has none or its first name is "random"
        (case-insensitive).
        '''
        template = ' DISTRIBUTE ON {columns}'
        names = (table.distribute_on.column_names
                 if hasattr(table, 'distribute_on') else None)
        use_random = names is None or names[0].lower() == 'random'
        columns = 'RANDOM' if use_random else '({})'.format(','.join(names))
        return template.format(columns=columns)
# Maps Netezza type ids (the catalog's atttypid) to a pair of
# (sqlalchemy type factory, has_length). NetezzaODBC.get_columns calls the
# factory with (precision, scale) for Numeric, with the column length when
# has_length is true, and with no arguments otherwise.
oid_datatype_map = {
    16: (sqltypes.Boolean, False),
    18: (sqltypes.CHAR, False),
    20: (sqltypes.BigInteger, False),
    21: (sqltypes.SmallInteger, False),
    23: (sqltypes.Integer, False),
    700: (sqltypes.REAL, False),
    701: (DOUBLE_PRECISION, False),
    1042: (BPCHAR, True),
    1043: (sqltypes.String, True),
    1082: (sqltypes.Date, False),
    1083: (TIME, False),
    1184: (TIMESTAMP, False),
    1186: (INTERVAL, False),
    # 1266 is timetz (time with time zone), which was previously
    # mis-reflected as TIMESTAMP.
    1266: (functools.partial(TIME, timezone=True), False),
    1700: (sqltypes.Numeric, False),
    # NOTE(review): 2500 may be Netezza's BYTEINT; confirm before mapping it
    # to the BYTEINT type defined in this module.
    2500: (sqltypes.SmallInteger, False),
    2522: (sqltypes.NCHAR, True),
    2530: (sqltypes.NVARCHAR, True),
    2552: (ST_GEOMETRY, True),
    2568: (sqltypes.VARBINARY, True),
}
class NetezzaODBC(PyODBCConnector, PGDialect):
    '''Attempts to reuse as much as possible from the postgresql and pyodbc
    dialects.

    Overrides catalog queries (Netezza's _v_* views), column type
    reflection, unicode/encoding handling, and the absence of schemas,
    keys and indexes.
    '''
    name = 'netezza'
    encoding = 'latin9'
    default_paramstyle = 'qmark'
    returns_unicode_strings = False
    supports_native_enum = False
    supports_sequences = True
    sequences_optional = False
    isolation_level = 'READ COMMITTED'
    max_identifier_length = 63
    type_compiler = NetezzaTypeCompiler
    statement_compiler = NetezzaCompiler
    ddl_compiler = NetezzaDDLCompiler

    # Tolerates case differences and whitespace in format_type,
    # e.g. "numeric(18,0)" or "NUMERIC(18, 0)".
    _numeric_format_re = re.compile(
        r'numeric\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)', re.IGNORECASE)

    def initialize(self, connection):
        '''Run base initialization, then force Netezza unicode/encoding
        settings and register the Netezza-specific type names.'''
        super(NetezzaODBC, self).initialize(connection)
        # PyODBC connector tries to set these to true...
        self.supports_unicode_statements = False
        self.supports_unicode_binds = False
        self.returns_unicode_strings = True
        self.convert_unicode = True
        self.encoding = 'latin9'
        # self.ischema_names is otherwise PGDialect's class-level dict,
        # which is shared by every PostgreSQL dialect. Copy it so adding the
        # Netezza names does not leak into the plain postgresql dialect.
        self.ischema_names = dict(self.ischema_names)
        self.ischema_names.update(ischema_names)

    def has_table(self, connection, tablename, schema=None, **kw):
        '''Return True if *tablename* appears in _v_table (schema ignored).'''
        result = connection.execute(
            "select count(*) from _v_table "
            "where tablename = ?",
            (tablename,)
        ).scalar()
        return bool(result)

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        '''Return table names from _v_table, skipping those prefixed "_t_".

        The prefix is matched literally in Python. In a SQL LIKE pattern
        "_" is a single-character wildcard, so ``not like '_t_%'`` also
        hid every table whose second character is "t".
        '''
        result = connection.execute(
            "select tablename as name from _v_table")
        return [r[0] for r in result if not r[0].startswith('_t_')]

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        '''Reflect the columns of *table_name* from _v_relation_column.

        Returns dicts with "name", "type", "nullable" and "default" keys,
        ordered by attnum; *schema* is ignored. Type ids missing from
        oid_datatype_map are reflected as NullType with a warning instead
        of aborting reflection with KeyError.
        '''
        import warnings
        SQL_COLS = """
            SELECT CAST(a.attname AS VARCHAR(64)) as name,
                   a.atttypid as typeid,
                   not a.attnotnull as nullable,
                   a.attcolleng as length,
                   a.format_type
            FROM _v_relation_column a
            WHERE a.name = :tablename
            ORDER BY a.attnum
        """
        s = text(SQL_COLS,
                 bindparams=[bindparam('tablename', type_=sqltypes.String)],
                 typemap={'name': sqltypes.String,
                          'typeid': sqltypes.Integer,
                          'nullable': sqltypes.Boolean,
                          'length': sqltypes.Integer,
                          'format_type': sqltypes.String,
                          })
        c = connection.execute(s, tablename=table_name)
        rows = c.fetchall()
        columns = []
        for name, typeid, nullable, length, format_type in rows:
            try:
                coltype_class, has_length = oid_datatype_map[typeid]
            except KeyError:
                warnings.warn(
                    "Did not recognize type id %r of column %r"
                    % (typeid, name))
                coltype = sqltypes.NullType()
            else:
                coltype = self._reflected_coltype(
                    coltype_class, has_length, length, format_type)
            columns.append({
                'name': name,
                'type': coltype,
                'nullable': nullable,
                'default': None,
            })
        return columns

    def _reflected_coltype(self, coltype_class, has_length, length,
                           format_type):
        '''Instantiate one reflected column type from catalog values.'''
        if coltype_class is sqltypes.Numeric:
            match = self._numeric_format_re.match(format_type or '')
            if match is None:
                # Unparseable precision/scale: fall back to a plain Numeric.
                return coltype_class()
            precision, scale = match.groups()
            return coltype_class(int(precision), int(scale))
        if has_length:
            return coltype_class(length)
        return coltype_class()

    @reflection.cache
    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        '''Netezza doesn't have PK/unique constraints'''
        return {'constrained_columns': [], 'name': None}

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        '''Netezza doesn't have foreign keys'''
        return []

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kw):
        '''Netezza doesn't have indexes'''
        return []

    @reflection.cache
    def get_view_names(self, connection, schema=None, **kw):
        '''Return view names from _v_view, skipping those prefixed "_v_".

        The original query concatenated "_v_view" and "where" with no
        space, so it never ran. The prefix is matched literally because "_"
        is a wildcard in SQL LIKE.
        '''
        result = connection.execute(
            "select viewname as name from _v_view")
        return [r[0] for r in result if not r[0].startswith('_v_')]

    def get_isolation_level(self, connection):
        '''Return the dialect's fixed isolation level.'''
        return self.isolation_level

    def _get_default_schema_name(self, connection):
        '''Netezza doesn't use schemas'''
        raise NotImplementedError

    def _check_unicode_returns(self, connection):
        '''Netezza doesn't *do* unicode (except in nchar & nvarchar)'''
        pass
# Register the dialect for both "netezza://" and "netezza+pyodbc://" URLs.
# __name__ replaces the hard-coded "cameopaas.db.netezza_dialect" path so
# the registry can import this module from wherever it actually lives.
registry.register("netezza", __name__, "NetezzaODBC")
registry.register("netezza.pyodbc", __name__, "NetezzaODBC")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment