Patched version of pandas.io.sql to support PostgreSQL
""" | |
Patched version to support PostgreSQL | |
(original version: https://github.com/pydata/pandas/blob/v0.13.1/pandas/io/sql.py) | |
Adapted functions are: | |
- added _write_postgresql | |
- updated table_exist | |
- updated get_sqltype | |
- updated get_schema | |
Collection of query wrappers / abstractions to both facilitate data | |
retrieval and to reduce dependency on DB-specific API. | |
""" | |
from __future__ import print_function
from datetime import datetime, date

from pandas.compat import range, lzip, map, zip
import pandas.compat as compat
import numpy as np
import traceback

from pandas.core.datetools import format as date_format
from pandas.core.api import DataFrame, isnull

#------------------------------------------------------------------------------
# Helper execution function

def execute(sql, con, retry=True, cur=None, params=None):
    """
    Execute the given SQL query using the provided connection object.

    Parameters
    ----------
    sql: string
        Query to be executed
    con: database connection instance
        Database connection. Must implement PEP249 (Database API v2.0).
    retry: bool
        Not currently implemented
    cur: database cursor, optional
        Must implement PEP249 (Database API v2.0). If a cursor is not provided,
        one will be obtained from the database connection.
    params: list or tuple, optional
        List of parameters to pass to execute method.

    Returns
    -------
    Cursor object
    """
    try:
        if cur is None:
            cur = con.cursor()

        if params is None:
            cur.execute(sql)
        else:
            cur.execute(sql, params)
        return cur
    except Exception:
        try:
            con.rollback()
        except Exception:  # pragma: no cover
            pass

        print('Error on sql %s' % sql)
        raise
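
# Illustrative usage (not part of the original module). Assumes a PEP 249
# connection `con` obtained from psycopg2 and a hypothetical table `events`
# using the '%s' paramstyle:
#
#   cur = execute("SELECT * FROM events WHERE id = %s", con, params=(42,))
#   rows = cur.fetchall()
#   cur.close()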

def _safe_fetch(cur):
    try:
        result = cur.fetchall()
        if not isinstance(result, list):
            result = list(result)
        return result
    except Exception as e:  # pragma: no cover
        excName = e.__class__.__name__
        if excName == 'OperationalError':
            return []
        # any other fetch error falls through and returns None implicitly

def tquery(sql, con=None, cur=None, retry=True):
    """
    Returns list of tuples corresponding to each row in given sql
    query.

    If only one column selected, then plain list is returned.

    Parameters
    ----------
    sql: string
        SQL query to be executed
    con: SQLConnection or DB API 2.0-compliant connection
    cur: DB API 2.0 cursor

    Provide a specific connection or a specific cursor if you are executing a
    lot of sequential statements and want to commit outside.
    """
    cur = execute(sql, con, cur=cur)
    result = _safe_fetch(cur)

    if con is not None:
        try:
            cur.close()
            con.commit()
        except Exception as e:
            excName = e.__class__.__name__
            if excName == 'OperationalError':  # pragma: no cover
                print('Failed to commit, may need to restart interpreter')
            else:
                raise

            traceback.print_exc()
            if retry:
                return tquery(sql, con=con, retry=False)

    if result and len(result[0]) == 1:
        # python 3 compat
        result = list(lzip(*result)[0])
    elif result is None:  # pragma: no cover
        result = []

    return result
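
# Illustrative usage (assumes a PEP 249 connection `con` and a hypothetical
# table `events`):
#
#   names = tquery("SELECT name FROM events", con=con)      # one column -> plain list
#   pairs = tquery("SELECT id, name FROM events", con=con)  # list of tuples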

def uquery(sql, con=None, cur=None, retry=True, params=None):
    """
    Does the same thing as tquery, but instead of returning results, it
    returns the number of rows affected. Good for update queries.
    """
    cur = execute(sql, con, cur=cur, retry=retry, params=params)

    result = cur.rowcount
    try:
        con.commit()
    except Exception as e:
        excName = e.__class__.__name__
        if excName != 'OperationalError':
            raise

        traceback.print_exc()
        if retry:
            print('Looks like your connection failed, reconnecting...')
            return uquery(sql, con, retry=False)
    return result
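
# Illustrative usage (assumes psycopg2-style '%s' placeholders and a
# hypothetical table `events`):
#
#   n = uquery("UPDATE events SET seen = TRUE WHERE id = %s", con=con,
#              params=(42,))
#   print('%d row(s) affected' % n)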

def read_frame(sql, con, index_col=None, coerce_float=True, params=None):
    """
    Returns a DataFrame corresponding to the result set of the query
    string.

    Optionally provide an index_col parameter to use one of the
    columns as the index. Otherwise will be 0 to len(results) - 1.

    Parameters
    ----------
    sql: string
        SQL query to be executed
    con: DB connection object, optional
    index_col: string, optional
        column name to use for the returned DataFrame object.
    coerce_float : boolean, default True
        Attempt to convert values of non-string, non-numeric objects (like
        decimal.Decimal) to floating point, useful for SQL result sets
    params: list or tuple, optional
        List of parameters to pass to execute method.
    """
    cur = execute(sql, con, params=params)
    rows = _safe_fetch(cur)
    columns = [col_desc[0] for col_desc in cur.description]

    cur.close()
    con.commit()

    result = DataFrame.from_records(rows, columns=columns,
                                    coerce_float=coerce_float)

    if index_col is not None:
        result = result.set_index(index_col)

    return result

frame_query = read_frame
read_sql = read_frame
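
# Illustrative usage (assumes psycopg2 and a hypothetical table
# `public.measurements` with an `id` column):
#
#   import psycopg2
#   con = psycopg2.connect(dbname='mydb', user='me')
#   df = read_frame("SELECT * FROM public.measurements", con, index_col='id')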

def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs):
    """
    Write records stored in a DataFrame to a SQL database.

    Parameters
    ----------
    frame: DataFrame
    name: name of SQL table
    con: an open SQL database connection object
    flavor: {'sqlite', 'mysql', 'postgresql'}, default 'sqlite'
    if_exists: {'fail', 'replace', 'append'}, default 'fail'
        fail: If table exists, do nothing.
        replace: If table exists, drop it, recreate it, and insert data.
        append: If table exists, insert data. Create if does not exist.
    """
    if 'append' in kwargs:
        import warnings
        warnings.warn("append is deprecated, use if_exists instead",
                      FutureWarning)
        if kwargs['append']:
            if_exists = 'append'
        else:
            if_exists = 'fail'

    if if_exists not in ('fail', 'replace', 'append'):
        raise ValueError("'%s' is not valid for if_exists" % if_exists)

    exists = table_exists(name, con, flavor)
    if if_exists == 'fail' and exists:
        raise ValueError("Table '%s' already exists." % name)

    # creation/replacement dependent on the table existing and if_exists criteria
    create = None
    if exists:
        if if_exists == 'fail':
            raise ValueError("Table '%s' already exists." % name)
        elif if_exists == 'replace':
            cur = con.cursor()
            cur.execute("DROP TABLE %s;" % name)
            cur.close()
            create = get_schema(frame, name, flavor)
    else:
        create = get_schema(frame, name, flavor)

    if create is not None:
        cur = con.cursor()
        cur.execute(create)
        cur.close()

    cur = con.cursor()
    # Replace spaces in DataFrame column names with _.
    safe_names = [s.replace(' ', '_').strip() for s in frame.columns]
    flavor_picker = {'sqlite': _write_sqlite,
                     'mysql': _write_mysql,
                     'postgresql': _write_postgresql}

    func = flavor_picker.get(flavor, None)
    if func is None:
        raise NotImplementedError
    func(frame, name, safe_names, cur)
    cur.close()
    con.commit()
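
# Illustrative usage (assumes `df` is a DataFrame and `con` is a psycopg2
# connection; the table name is hypothetical):
#
#   write_frame(df, 'measurements_copy', con, flavor='postgresql',
#               if_exists='replace')
#
# With flavor='postgresql' the table is created from get_schema() and rows are
# inserted by _write_postgresql(), which targets the `public` schema.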

def _write_sqlite(frame, table, names, cur):
    bracketed_names = ['[' + column + ']' for column in names]
    col_names = ','.join(bracketed_names)
    wildcards = ','.join(['?'] * len(names))
    insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % (
        table, col_names, wildcards)
    # pandas types are badly handled if there is only 1 column ( Issue #3628 )
    if not len(frame.columns) == 1:
        data = [tuple(x) for x in frame.values]
    else:
        data = [tuple(x) for x in frame.values.tolist()]
    cur.executemany(insert_query, data)

def _write_mysql(frame, table, names, cur):
    bracketed_names = ['`' + column + '`' for column in names]
    col_names = ','.join(bracketed_names)
    wildcards = ','.join([r'%s'] * len(names))
    insert_query = "INSERT INTO %s (%s) VALUES (%s)" % (
        table, col_names, wildcards)
    data = [tuple(x) for x in frame.values]
    cur.executemany(insert_query, data)

def _write_postgresql(frame, table, names, cur):
    bracketed_names = ['"' + column + '"' for column in names]
    col_names = ','.join(bracketed_names)
    wildcards = ','.join([r'%s'] * len(names))
    insert_query = 'INSERT INTO public.%s (%s) VALUES (%s)' % (
        table, col_names, wildcards)
    data = [tuple(x) for x in frame.values]
    # debug output of the generated statement and payload
    print(insert_query)
    print(data)
    cur.executemany(insert_query, data)
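
# For a frame with columns ['id', 'name'] written to table 'events', the
# statement built above is:
#
#   INSERT INTO public.events ("id","name") VALUES (%s,%s)
#
# with `data` being a list of row tuples passed to cur.executemany().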

def table_exists(name, con, flavor):
    flavor_map = {
        'sqlite': ("SELECT name FROM sqlite_master "
                   "WHERE type='table' AND name='%s';") % name,
        'mysql': "SHOW TABLES LIKE '%s'" % name,
        'postgresql': "SELECT * FROM pg_catalog.pg_tables WHERE tablename = '%s'" % name}
    query = flavor_map.get(flavor, None)
    if query is None:
        raise NotImplementedError
    return len(tquery(query, con)) > 0
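
# Illustrative check (assumes a PostgreSQL connection `con`; the table name is
# hypothetical):
#
#   table_exists('measurements', con, 'postgresql')   # -> True or False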

def get_sqltype(pytype, flavor):
    sqltype = {'mysql': 'VARCHAR (63)',
               'sqlite': 'TEXT',
               'postgresql': 'VARCHAR (63)'}

    if issubclass(pytype, np.floating):
        sqltype['mysql'] = 'FLOAT'
        sqltype['sqlite'] = 'REAL'
        sqltype['postgresql'] = 'double precision'

    if issubclass(pytype, np.integer):
        # TODO: Refine integer size.
        sqltype['mysql'] = 'BIGINT'
        sqltype['sqlite'] = 'INTEGER'
        sqltype['postgresql'] = 'integer'

    if issubclass(pytype, np.datetime64) or pytype is datetime:
        # Caution: np.datetime64 is also a subclass of np.number.
        sqltype['mysql'] = 'DATETIME'
        sqltype['sqlite'] = 'TIMESTAMP'
        sqltype['postgresql'] = 'timestamp'

    # `date` is the class imported from datetime; comparing against the
    # `datetime.date` method would never match.
    if pytype is date:
        sqltype['mysql'] = 'DATE'
        sqltype['sqlite'] = 'TIMESTAMP'
        sqltype['postgresql'] = 'date'

    if issubclass(pytype, np.bool_):
        sqltype['sqlite'] = 'INTEGER'
        sqltype['postgresql'] = 'boolean'

    return sqltype[flavor]
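
# Example mappings (illustrative):
#
#   get_sqltype(np.dtype('float64').type, 'postgresql')  # 'double precision'
#   get_sqltype(np.dtype('int64').type, 'mysql')         # 'BIGINT'
#   get_sqltype(np.dtype('bool').type, 'sqlite')         # 'INTEGER'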

def get_schema(frame, name, flavor, keys=None):
    "Return a CREATE TABLE statement to suit the contents of a DataFrame."
    lookup_type = lambda dtype: get_sqltype(dtype.type, flavor)
    # Replace spaces in DataFrame column names with _.
    safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index]
    column_types = lzip(safe_columns, map(lookup_type, frame.dtypes))
    if flavor == 'sqlite':
        columns = ',\n  '.join('[%s] %s' % x for x in column_types)
    elif flavor == 'postgresql':
        columns = ',\n  '.join('"%s" %s' % x for x in column_types)
    else:
        columns = ',\n  '.join('`%s` %s' % x for x in column_types)

    keystr = ''
    if keys is not None:
        if isinstance(keys, compat.string_types):
            keys = (keys,)
        keystr = ', PRIMARY KEY (%s)' % ','.join(keys)
    template = """CREATE TABLE %(name)s (
  %(columns)s
  %(keystr)s
);"""
    create_statement = template % {'name': name, 'columns': columns,
                                   'keystr': keystr}
    return create_statement
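
# For a DataFrame with an int64 column 'id' and a float64 column 'value',
# get_schema(frame, 'measurements', 'postgresql') produces roughly:
#
#   CREATE TABLE measurements (
#     "id" integer,
#     "value" double precision
#   );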

def sequence2dict(seq):
    """Helper function for cx_Oracle.

    For each element in the sequence, creates a dictionary item equal
    to the element and keyed by the position of the item in the list.
    >>> sequence2dict(("Matt", 1))
    {'1': 'Matt', '2': 1}

    Source:
    http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/
    """
    d = {}
    for k, v in zip(range(1, 1 + len(seq)), seq):
        d[str(k)] = v
    return d