Python pandas: load and save DataFrames to SQLite, MySQL, Oracle, PostgreSQL
# -*- coding: utf-8 -*-
"""
example use of pandas with oracle, mysql, postgresql and sqlite
lightly tested.
to do:
  save/restore index (how to check table existence? just do select count(*)?),
  finish odbc,
  add booleans?,
  sql_server?
"""
from datetime import datetime
import cStringIO  # for file-like objects
import numpy as np
from pandas import *
import pandas.io.sql as psql

# for building CREATE TABLE schemas with appropriate type names per backend
dbtypes = {
    'mysql':      {'DATE': 'DATE',      'DATETIME': 'DATETIME',  'INT': 'INT',     'FLOAT': 'FLOAT',  'VARCHAR': 'VARCHAR'},
    'oracle':     {'DATE': 'DATE',      'DATETIME': 'DATE',      'INT': 'NUMBER',  'FLOAT': 'NUMBER', 'VARCHAR': 'VARCHAR2'},
    'sqlite':     {'DATE': 'TIMESTAMP', 'DATETIME': 'TIMESTAMP', 'INT': 'NUMBER',  'FLOAT': 'NUMBER', 'VARCHAR': 'VARCHAR2'},
    'postgresql': {'DATE': 'TIMESTAMP', 'DATETIME': 'TIMESTAMP', 'INT': 'INTEGER', 'FLOAT': 'REAL',   'VARCHAR': 'TEXT'},
}
def get_schema(frame, name, flavor):
    '''build a CREATE TABLE string from a dataframe for the specified flavor of dbms'''
    types = dbtypes[flavor]  # deal with datatype differences
    column_types = []
    dtypes = frame.dtypes
    for i, k in enumerate(dtypes.index):
        dt = dtypes[k]
        #print 'dtype', dt, dt.itemsize
        if str(dt.type) == "<type 'numpy.datetime64'>":
            sqltype = types['DATETIME']
        elif issubclass(dt.type, np.datetime64):
            sqltype = types['DATETIME']
        elif issubclass(dt.type, (np.integer, np.bool_)):
            sqltype = types['INT']
        elif issubclass(dt.type, np.floating):
            sqltype = types['FLOAT']
        else:
            sampl = frame[frame.columns[i]][0]
            #print 'other', type(sampl)
            if str(type(sampl)) == "<type 'datetime.datetime'>":
                sqltype = types['DATETIME']
            elif str(type(sampl)) == "<type 'datetime.date'>":
                sqltype = types['DATE']
            else:
                if flavor in ('mysql', 'oracle'):
                    size = 2 + max(len(str(a)) for a in frame[k])
                    print k, 'varchar sz', size
                    sqltype = '%s(%d)' % (types['VARCHAR'], size)
                else:
                    sqltype = types['VARCHAR']
        column_types.append((k, sqltype))
    columns = ',\n '.join('%s %s' % x for x in column_types)
    template_create = """CREATE TABLE %(name)s (
 %(columns)s
);"""
    return template_create % {'name': name, 'columns': columns}
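
# Illustrative output (not part of the original gist): for the test frame
# defined in __main__ below, get_schema(df, 'test_df', 'sqlite') would return
# something like
#
#   CREATE TABLE test_df (
#    erank NUMBER,
#    hire_date TIMESTAMP,
#    name VARCHAR2,
#    score NUMBER
#   );
#
# Column order follows frame.dtypes.index, so it may differ from the order of
# the dict literal used to build the frame.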
def read_db(sql, con):
    '''send a SELECT to the server and return the result as a dataframe'''
    return psql.frame_query(sql, con)
def table_exists(name=None, con=None, flavor='sqlite'):
    '''check whether this table already exists on the server. how to do in ODBC?'''
    if flavor == 'sqlite':
        sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='%s';" % name
    elif flavor == 'mysql':
        sql = "show tables like '%s';" % name
    elif flavor == 'postgresql':
        sql = "SELECT * FROM pg_tables WHERE tablename='%s';" % name
    elif flavor == 'oracle':
        sql = "select table_name from user_tables where table_name='%s'" % name.upper()
    else:
        raise NotImplementedError
    df = read_db(sql, con)
    print sql, df
    print 'table_exists?', len(df)
    return len(df) > 0
def write_frame(frame, name=None, con=None, flavor='sqlite', if_exists='fail'):
    """
    Write records stored in a DataFrame to the specified dbms.

    if_exists:
        'fail'    - a CREATE TABLE is attempted and fails if the table already exists
        'replace' - if a table with 'name' exists, it is dropped and recreated
        'append'  - add the data to the table; if the table does not exist it is
                    created, and bad data makes the insert fail
    """
    if if_exists == 'replace' and table_exists(name, con, flavor):
        cur = con.cursor()
        cur.execute("drop table " + name)
        cur.close()
    # create the cursor up front so the insert below also works for a plain
    # append into an existing table
    cur = con.cursor()
    if if_exists in ('fail', 'replace') or (if_exists == 'append' and not table_exists(name, con, flavor)):
        # create table
        schema = get_schema(frame, name, flavor)
        if flavor == 'oracle':
            schema = schema.replace(';', '')
        print 'schema\n', schema
        cur.execute(schema)
        print 'created table'
    if flavor == 'sqlite':
        wildcards = ','.join(['?'] * len(frame.columns))
        insert_sql = 'INSERT INTO %s VALUES (%s)' % (name, wildcards)
        print 'insert_sql', insert_sql
        data = [tuple(x) for x in frame.values]
        print 'data', data
        cur.executemany(insert_sql, data)
    elif flavor == 'oracle':
        cols = [k for k in frame.dtypes.index]
        colnames = ','.join(cols)
        colpos = ', '.join([':' + str(i + 1) for i, f in enumerate(cols)])
        insert_sql = 'INSERT INTO %s (%s) VALUES (%s)' % (name, colnames, colpos)
        print 'insert_sql', insert_sql
        data = [convertSequenceToDict(rec) for rec in frame.values]
        print data
        cur.executemany(insert_sql, data)
    elif flavor == 'mysql':
        wildcards = ','.join(['%s'] * len(frame.columns))
        cols = [k for k in frame.dtypes.index]
        colnames = ','.join(cols)
        insert_sql = 'INSERT INTO %s (%s) VALUES (%s)' % (name, colnames, wildcards)
        data = [tuple(x) for x in frame.values]
        cur.executemany(insert_sql, data)
    elif flavor == 'postgresql':
        postgresql_copy_from(frame, name, con)
    else:
        raise NotImplementedError
    con.commit()
    cur.close()
    return
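
# Illustrative round trip (not part of the original gist; assumes an open
# sqlite3 connection `conn` and a DataFrame `df` with plain column names):
#
#   write_frame(df, 'mytable', con=conn, flavor='sqlite', if_exists='replace')
#   df2 = read_db('select * from mytable', con=conn)
#
# The same call works for the other flavors once a DB-API connection for that
# backend is supplied.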
def postgresql_copy_from(df, name, con):
    '''append data to an existing postgresql table using COPY FROM for speed'''
    # 1. dump the frame to an in-memory tab-separated file with no header;
    #    the index is skipped because get_schema() does not create a column for it
    output = cStringIO.StringIO()
    # work around a to_csv() issue with datetime64 columns in this pandas version
    have_datetime64 = False
    dtypes = df.dtypes
    for i, k in enumerate(dtypes.index):
        dt = dtypes[k]
        print 'dtype', dt, dt.itemsize
        if str(dt.type) == "<type 'numpy.datetime64'>":
            have_datetime64 = True
    if have_datetime64:
        d2 = df.copy()
        for i, k in enumerate(dtypes.index):
            dt = dtypes[k]
            if str(dt.type) == "<type 'numpy.datetime64'>":
                # convert datetime64 values to plain datetime before writing
                d2[k] = [v.to_pydatetime() for v in d2[k]]
        d2.to_csv(output, sep='\t', header=False, index=False)
    else:
        df.to_csv(output, sep='\t', header=False, index=False)
    contents = output.getvalue()
    print 'contents\n', contents
    # 2. COPY the buffer into the table; rewind first so copy_from reads from the start
    output.seek(0)
    cur = con.cursor()
    cur.copy_from(output, name)
    con.commit()
    cur.close()
    return
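
# Design note: COPY FROM streams the whole frame to the server in one call,
# which is typically much faster than executemany() of INSERTs for large
# frames. Illustrative call (assumes a psycopg2 connection `pgcn` and that
# write_frame() has already created the table):
#
#   postgresql_copy_from(df, 'mytable', pgcn)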
# source: http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/
def convertSequenceToDict(seq):
    """for cx_Oracle:
    For each element in the sequence, create a dictionary item equal
    to the element and keyed by the position of the item in the sequence.

    >>> convertSequenceToDict(("Matt", 1))
    {'1': 'Matt', '2': 1}
    """
    d = {}
    argList = range(1, len(seq) + 1)
    for k, v in zip(argList, seq):
        d[str(k)] = v
    return d
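
# The string keys line up with the numbered bind variables (:1, :2, ...) that
# write_frame() builds for Oracle, so cx_Oracle's executemany() can map each
# row tuple onto the INSERT statement.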
###############################################################################
def test_sqlite(name, testdf):
    print '\nsqlite, using detect_types=sqlite3.PARSE_DECLTYPES for datetimes'
    import sqlite3
    with sqlite3.connect('test.db', detect_types=sqlite3.PARSE_DECLTYPES) as conn:
        #conn.row_factory = sqlite3.Row
        write_frame(testdf, name, con=conn, flavor='sqlite', if_exists='replace')
        df_sqlite = read_db('select * from ' + name, con=conn)
        print 'loaded dataframe from sqlite', len(df_sqlite)
    print 'done with sqlite'
def test_oracle(name, testdf):
    print '\nOracle'
    import cx_Oracle
    with cx_Oracle.connect('YOURCONNECTION') as ora_conn:
        testdf['d64'] = np.datetime64(testdf['hire_date'])
        write_frame(testdf, name, con=ora_conn, flavor='oracle', if_exists='replace')
        df_ora2 = read_db('select * from ' + name, con=ora_conn)
    print 'done with oracle'
    return df_ora2
def test_postgresql(name, testdf):
    #from pg8000 import DBAPI as pg
    import psycopg2 as pg
    print '\nPostgreSQL, Greenplum'
    pgcn = pg.connect('YOURCONNECTION')  # fill in your own DSN / connection string
    print 'df frame_query'
    try:
        write_frame(testdf, name, con=pgcn, flavor='postgresql', if_exists='replace')
        print 'pg copy_from'
        postgresql_copy_from(testdf, name, con=pgcn)
        df_gp = read_db('select * from ' + name, con=pgcn)
        print 'loaded dataframe from greenplum', len(df_gp)
    finally:
        pgcn.commit()
        pgcn.close()
    print 'done with greenplum'
def test_mysql(name, testdf):
    import MySQLdb
    print '\nmysql'
    cn = MySQLdb.connect('YOURCONNECTION')  # fill in your own host/user/passwd/db
    try:
        write_frame(testdf, name=name, con=cn, flavor='mysql', if_exists='replace')
        df_mysql = read_db('select * from ' + name, con=cn)
        print 'loaded dataframe from mysql', len(df_mysql)
    finally:
        cn.close()
    print 'mysql done'
##############################################################################
if __name__ == '__main__':
    print """Aside from sqlite, you'll need to install the driver and set a valid
connection string for each test routine."""
    test_data = {
        "name":      ['Joe', 'Bob', 'Jim', 'Suzy', 'Cathy', 'Sarah'],
        "hire_date": [datetime(2012, 1, 1), datetime(2012, 2, 1), datetime(2012, 3, 1),
                      datetime(2012, 4, 1), datetime(2012, 5, 1), datetime(2012, 6, 1)],
        "erank":     [1, 2, 3, 4, 5, 6],
        "score":     [1.1, 2.2, 3.1, 2.5, 3.6, 1.8],
    }
    df = DataFrame(test_data)
    name = 'test_df'
    test_sqlite(name, df)
    #test_oracle(name, df)
    #test_postgresql(name, df)
    #test_mysql(name, df)
    print 'done'