Last active
August 4, 2024 23:33
-
-
Save MichaelCurrie/b5ab978c0c0c1860bb5e75676775b43b to your computer and use it in GitHub Desktop.
Fast pandas DataFrame read/write to mariadb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""
Drop-in replacements for the pandas methods that read and write database tables:

    pandas.DataFrame.to_sql
    pandas.read_sql_table

For some reason the current (Jan 2023) implementation in pandas
using SQLAlchemy is really slow. These methods are ~300x faster.

NOTE: Works only with MariaDB's connector for now (pip3 install mariadb)
https://mariadb.com/resources/blog/how-to-connect-python-programs-to-mariadb/
"""
import pandas as pd | |
import warnings | |
class DataFrameFast(pd.DataFrame): | |
def to_sql(self, name, con, if_exists='append', index=False, *args, **kwargs): | |
if not if_exists in ['replace', 'append']: | |
raise AssertionError( | |
"not if_exists in ['replace', 'append'] is not yet impemented") | |
# Truncate database table | |
# NOTE: Users may have to perform to_sql in the correct | |
# sequence to avoid causing foreign key errors with this step | |
if if_exists == 'replace': | |
with con.cursor() as cursor: | |
r = cursor.execute(f"TRUNCATE TABLE {name}") | |
# Prepare an INSERT which will populate the real mariadb table with df's data | |
# INSERT INTO table(c1,c2,...) VALUES (v11,v12,...), ... (vnn,vn2,...); | |
# If index, then we also want the index inserted | |
cols = [self.index.name] * index + list(self.columns) | |
cmd = (f"INSERT INTO {name} ({', '.join(cols)})" | |
f" VALUES ({', '.join(['?']*len(cols))})") | |
table_data = list(self.itertuples(index=index)) | |
# Replace nan with None for SQL to accept it. | |
table_data = [ | |
[None if pd.isnull(value) else value for value in sublist] | |
for sublist in table_data] | |
if len(table_data) == 0: | |
pass | |
else: | |
with con.cursor() as cursor: | |
cursor.executemany(cmd, table_data) | |
def column_info(self): | |
""" Returns the column information. | |
Parameters: | |
table_name: string. If None, returns column info for ALL tables. | |
""" | |
clauses = [f"TABLE_SCHEMA = '{self.database}'"] | |
if not table_name is None: | |
clauses.append(f"TABLE_NAME = '{table_name}'") | |
with con.cursor() as cursor: | |
cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS " | |
f"WHERE {'AND '.join(clauses)};") | |
records0 = custor.fetchall() | |
return pd.DataFrame(r) | |
def get_data(con, query): | |
with con.cursor() as cursor: | |
cursor.execute(query) | |
records0 = cursor.fetchall() | |
# Get the field names | |
fields = cursor.description | |
# Get a list of dicts with proper field names | |
# (i.e. records in the pandas sense) | |
return [ | |
{fields[i][0]:field_value for i, field_value in enumerate(v)} | |
for v in records0] | |
def read_sql_table(name, con, *args, **kwargs): | |
""" A drop-in replacement for pd.read_sql_table | |
""" | |
records = get_data(con, f"SELECT * FROM {name};") | |
if len(records) > 0: | |
df = DataFrameFast.from_records(records) | |
else: | |
if name.count('.') == 1: | |
table_schema, table_name = name.split('.') | |
query = (f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS " | |
f"WHERE TABLE_SCHEMA = '{table_schema}' AND " | |
f"TABLE_NAME = '{table_name}';") | |
else: | |
warnings.warn("Note: this assumes the table name " | |
f"{name} is unique across all databases") | |
query = (f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS " | |
f"WHERE TABLE_NAME = '{name}';") | |
# Make an empty dataframe with the right column names | |
column_info = get_data(con, query) | |
columns = list(pd.DataFrame(column_info)['COLUMN_NAME']) | |
df = DataFrameFast(columns=columns) | |
return df |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment