Created
September 29, 2022 18:30
-
-
Save johndoe46/40c0f993a641fe7a3eb74bd6d67af7da to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sqlite3 | |
from pkg_resources import parse_version | |
__version__ = "0.2.1" | |
class SQLiteWriter(object):
    """
    SQLite persistence backend for plain "bean" objects.

    A bean is stored in the table named after its class, one column per
    instance attribute; ``id`` is the autoincremented integer primary key.

    In frozen mode (the default), the writer will not alter db schema.
    Just add frozen=False to enable table and column creation (or just add
    False as second parameter):

        query_writer = SQLiteWriter(":memory:", False)

    The connection is put under manual transaction control
    (``isolation_level = None``) and a transaction is opened with ``BEGIN``
    in the constructor; call :meth:`commit` to persist the changes.  Note
    that :meth:`commit` does not issue a new ``BEGIN`` afterwards.

    Table names, column names and the ``sql`` WHERE fragments are
    interpolated into the statement text as-is; only the values passed in
    ``replace`` (and bean attribute values) are bound as parameters.  Do not
    feed untrusted text into identifiers or ``sql``.
    """

    # Exact Python types stored in a NUMERIC column; anything else is TEXT.
    _NUMERIC_TYPES = (int, float, complex, bool)

    def __init__(self, db_path=":memory:", frozen=True):
        """Open *db_path*, enable foreign keys and start a transaction.

        :param db_path: path handed to ``sqlite3.connect``.
        :param frozen: when true, no table, column or trigger is ever created.
        """
        self.db = sqlite3.connect(db_path)
        self.db.isolation_level = None
        self.db.row_factory = sqlite3.Row
        self.frozen = frozen
        self.cursor = self.db.cursor()
        self.cursor.execute("PRAGMA foreign_keys=ON;")
        self.cursor.execute('PRAGMA encoding = "UTF-8";')
        self.cursor.execute('BEGIN;')

    def __del__(self):
        """Close the connection, if the constructor got far enough to open one."""
        # __del__ also runs when __init__ raised before self.db was assigned.
        db = getattr(self, "db", None)
        if db is not None:
            db.close()

    def replace(self, bean):
        """Insert or overwrite *bean* in the table named after its class.

        A bean without an ``id`` attribute is inserted and receives the new
        row id as ``bean.id``; a bean that has one is written with
        ``replace into``.  Unless frozen, the table and any missing column
        are created first.

        :returns: the bean's ``id``.
        """
        table = bean.__class__.__name__
        attributes = bean.__dict__
        is_new = "id" not in attributes

        # A NULL id lets SQLite pick the next autoincrement value.
        keys = ["id"] if is_new else []
        values = [None] if is_new else []

        self.__create_table(table)
        known_columns = set(self.__get_columns(table))
        for key, value in attributes.items():
            if key not in known_columns:
                self.__create_column(table, key, type(value))
            keys.append(key)
            values.append(value)

        write_operation = "insert" if is_new else "replace"
        sql = write_operation + " into " + table + "("
        sql += ",".join(keys) + ") values ("
        sql += ",".join("?" * len(keys)) + ")"
        self.cursor.execute(sql, values)

        if is_new:
            bean.id = self.cursor.lastrowid
        return bean.id

    def __create_column(self, table, column, sqltype):
        """Add *column* to *table* (no-op when frozen).

        *sqltype* is the Python type of the value: the types listed in
        ``_NUMERIC_TYPES`` give a NUMERIC column, everything else TEXT.
        """
        if self.frozen:
            return
        affinity = "NUMERIC" if sqltype in self._NUMERIC_TYPES else "TEXT"
        self.cursor.execute(
            "alter table " + table + " add " + column + " " + affinity)

    def __get_columns(self, table):
        """Return the column names of *table*; always ``[]`` when frozen."""
        if self.frozen:
            return []
        self.cursor.execute("PRAGMA table_info(" + table + ")")
        return [row["name"] for row in self.cursor.fetchall()]

    def __create_table(self, table):
        """Create *table* with only its ``id`` column (no-op when frozen)."""
        if self.frozen:
            return
        sql = "create table if not exists " + table + "(id INTEGER PRIMARY KEY AUTOINCREMENT)"
        self.cursor.execute(sql)

    def get_rows(self, table_name, sql="1", replace=None):
        """Yield the rows of *table_name* matching the WHERE fragment *sql*.

        :param sql: raw WHERE clause body; ``"1"`` selects every row.
        :param replace: values bound to the ``?`` placeholders of *sql*.

        Yields nothing if the query raises ``sqlite3.OperationalError``
        (for instance a missing table in frozen mode).
        """
        params = [] if replace is None else replace
        self.__create_table(table_name)
        query = "SELECT * FROM " + table_name + " WHERE " + sql
        try:
            # Fetch before yielding: the cursor is shared by every method,
            # so a lazy iteration would be clobbered by any writer call the
            # consumer makes between two rows.
            rows = self.cursor.execute(query, params).fetchall()
        except sqlite3.OperationalError:
            return
        for row in rows:
            yield row

    def get_count(self, table_name, sql="1", replace=None):
        """Return ``count(*)`` of the rows of *table_name* matching *sql*.

        Returns 0 if the query raises ``sqlite3.OperationalError``, and
        ``None`` if the query produces no row at all.
        """
        params = [] if replace is None else replace
        self.__create_table(table_name)
        query = "SELECT count(*) AS cnt FROM " + table_name + " WHERE " + sql
        try:
            self.cursor.execute(query, params)
        except sqlite3.OperationalError:
            return 0
        row = self.cursor.fetchone()
        if row is None:
            return None
        return row["cnt"]

    def delete(self, bean):
        """Delete the row whose id is ``bean.id`` from the bean's table."""
        table = bean.__class__.__name__
        self.__create_table(table)
        self.cursor.execute("delete from " + table + " where id=?", [bean.id])

    def link(self, bean_a, bean_b):
        """Save both beans, then record a link between them.

        The link is a row in the association table built by
        ``__create_assoc_table``.
        """
        self.replace(bean_a)
        self.replace(bean_b)
        table_a = bean_a.__class__.__name__
        table_b = bean_b.__class__.__name__
        assoc_table = self.__create_assoc_table(table_a, table_b)
        sql = "replace into " + assoc_table + "(" + table_a + "_id," + table_b
        sql += "_id) values(?,?)"
        self.cursor.execute(sql, [bean_a.id, bean_b.id])

    def unlink(self, bean_a, bean_b):
        """Remove the link between *bean_a* and *bean_b*; the beans stay."""
        table_a = bean_a.__class__.__name__
        table_b = bean_b.__class__.__name__
        assoc_table = self.__create_assoc_table(table_a, table_b)
        sql = "delete from " + assoc_table + " where " + table_a
        sql += "_id=? and " + table_b + "_id=?"
        self.cursor.execute(sql, [bean_a.id, bean_b.id])

    def get_linked_rows(self, bean, table_name):
        """Yield the rows of *table_name* that are linked to *bean*.

        Unlike :meth:`get_rows`, errors raised by the query propagate.
        """
        bean_table = bean.__class__.__name__
        assoc_table = self.__create_assoc_table(bean_table, table_name)
        sql = "select t.* from " + table_name + " t inner join " + assoc_table
        sql += " a on a." + table_name + "_id = t.id where a."
        sql += bean_table + "_id=?"
        # Fetched before yielding for the same shared-cursor reason as get_rows.
        for row in self.cursor.execute(sql, [bean.id]).fetchall():
            yield row

    def __create_assoc_table(self, table_a, table_b):
        """Return the association table name for two tables, creating it
        unless frozen.

        The name is the two table names sorted and joined with ``_``, so the
        argument order does not matter.  The table holds ``<table>_id``
        columns referencing each side with ``ON DELETE cascade``.
        NOTE(review): with table_a == table_b both columns get the same
        name, so linking two beans of the same class looks unsupported.
        """
        assoc_table = "_".join(sorted([table_a, table_b]))
        if self.frozen:
            return assoc_table

        sql = "create table if not exists " + assoc_table + "("
        sql += table_a + "_id NOT NULL REFERENCES " + table_a + "(id) ON DELETE cascade,"
        sql += table_b + "_id NOT NULL REFERENCES " + table_b + "(id) ON DELETE cascade,"
        sql += " PRIMARY KEY (" + table_a + "_id," + table_b + "_id));"
        self.cursor.execute(sql)

        # no real support for foreign keys until sqlite3 v3.6.19
        # so here's the hack: emulate the cascade with delete triggers.
        # (sqlite_version_info is a tuple of ints, compared element-wise.)
        if sqlite3.sqlite_version_info < (3, 6, 19):
            for table in (table_a, table_b):
                sql = "create trigger if not exists fk_" + table + "_" + assoc_table
                sql += " before delete on " + table
                sql += " for each row begin delete from " + assoc_table + " where " + table + "_id = OLD.id;end;"
                self.cursor.execute(sql)
        return assoc_table

    def delete_all(self, table_name, sql="1", replace=None):
        """Delete the rows of *table_name* matching the WHERE fragment *sql*.

        :returns: True when the statement ran, False if it raised
            ``sqlite3.OperationalError``.
        """
        params = [] if replace is None else replace
        self.__create_table(table_name)
        statement = "DELETE FROM " + table_name + " WHERE " + sql
        try:
            self.cursor.execute(statement, params)
        except sqlite3.OperationalError:
            return False
        return True

    def commit(self):
        """Commit the pending changes on the underlying connection."""
        self.db.commit()
class Store(object):
    """
    Bean-level facade over a SQL writer: rows come back as objects whose
    class is named after the table and whose attributes are the columns.

    A SQL writer should be passed to the constructor:

        beans_save = Store(SQLiteWriter(":memory:", frozen=False))
    """

    def __init__(self, SQLWriter):
        """Keep *SQLWriter* as the backend every operation is delegated to."""
        self.writer = SQLWriter

    def new(self, table_name):
        """Return an empty bean: an instance of a fresh class named *table_name*."""
        bean_class = type(table_name, (object,), {})
        return bean_class()

    def save(self, bean):
        """Write *bean* through the writer and return what ``replace`` returns."""
        return self.writer.replace(bean)

    def load(self, table_name, id):
        """Return the bean of *table_name* with the given id, or None."""
        for row in self.writer.get_rows(table_name, "id=?", [id]):
            return self.row_to_object(table_name, row)
        return None

    def count(self, table_name, sql="1", replace=None):
        """Return the writer's count of rows matching the WHERE fragment *sql*."""
        params = [] if replace is None else replace
        return self.writer.get_count(table_name, sql, params)

    def find(self, table_name, sql="1", replace=None):
        """Yield a bean for every row matching the WHERE fragment *sql*.

        *replace* holds the values bound to the ``?`` placeholders of *sql*.
        """
        params = [] if replace is None else replace
        for row in self.writer.get_rows(table_name, sql, params):
            yield self.row_to_object(table_name, row)

    def find_one(self, table_name, sql="1", replace=None):
        """Return the first bean :meth:`find` would yield, or None."""
        return next(self.find(table_name, sql, replace), None)

    def delete(self, bean):
        """Delete *bean* through the writer."""
        self.writer.delete(bean)

    def link(self, bean_a, bean_b):
        """Link the two beans through the writer."""
        self.writer.link(bean_a, bean_b)

    def unlink(self, bean_a, bean_b):
        """Remove the link between the two beans through the writer."""
        self.writer.unlink(bean_a, bean_b)

    def get_linked(self, bean, table_name):
        """Yield the beans of *table_name* that are linked to *bean*."""
        for row in self.writer.get_linked_rows(bean, table_name):
            yield self.row_to_object(table_name, row)

    def delete_all(self, table_name, sql="1", replace=None):
        """Delete the rows matching *sql*; returns the writer's result."""
        params = [] if replace is None else replace
        return self.writer.delete_all(table_name, sql, params)

    def row_to_object(self, table_name, row):
        """Build a bean of *table_name* from *row*.

        *row* must offer ``keys()`` and lookup by key; every key becomes an
        attribute of the returned bean.
        """
        bean = self.new(table_name)
        bean.__dict__.update((key, row[key]) for key in row.keys())
        return bean

    def commit(self):
        """Commit pending changes through the writer."""
        self.writer.commit()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment