# Created March 24, 2017 18:18
# Gist lozhn/3229cfc771c2ab388486c3c70f21eb92
# DMD course. Base ActiveRecord-like Python model.
import datetime
import functools
from collections import namedtuple

import psycopg2
# Usage notes for the Model base class.  (Kept as comments: a string literal
# placed after the imports is not the module docstring.)
#
# For a subclass ``class User(Model)`` rows are read from the table named
# ``<ClassName>s`` (here ``Users``) unless ``table_name`` is set explicitly:
#
#   User.all()                   -> list of User objects, one per table row
#   User.get(exact_id)           -> the User whose id equals exact_id, or None
#   User.getBy(**kwargs)         -> list of Users whose columns equal kwargs
#   User.fetchBy(page, **kwargs) -> one page (``size`` rows) of getBy results
#   User.fetch_page(num)         -> one page (``size`` rows) of all rows
#   user.save()                  -> update the row with user.id if it exists,
#                                   otherwise insert it
#   User.new(obj)                -> insert obj (or keyword values) as a new row
#   user.drop()                  -> delete the row with user.id
#   User.deleteBy(**kwargs)      -> delete every row whose columns equal kwargs

# Lists a table's columns and their data types.  ORDER BY ordinal_position
# keeps the names in the same order as the columns returned by ``SELECT *``,
# which Model.getObject relies on when it zips column names with row values.
query_column_name = """
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_name = '{}'
ORDER BY ordinal_position
"""

# Outcome of a write operation: success flag plus the database error (or None).
Result = namedtuple("Result", ['success', 'error'])

# Number of rows per page returned by Model.fetchBy and Model.fetch_page.
size = 10
def db_connect(query):
    """Decorator that loads a model's column metadata before the first query.

    The wrapped callable receives the model class (or, when used on
    ``__init__``, an instance) as its first argument.  If that object's
    ``names`` is still empty, its ``connect()`` is called first.

    ``functools.wraps`` keeps the wrapped function's name and docstring, so
    ``help()`` and tracebacks show the real method instead of ``wrapper``.
    """
    @functools.wraps(query)
    def wrapper(cls, *args, **kwargs):
        if not cls.names:
            cls.connect()
        return query(cls, *args, **kwargs)
    return wrapper
class Model:
    """Base class for ActiveRecord-style models stored in PostgreSQL.

    A subclass such as ``User(Model)`` maps onto the table ``table_name``,
    which defaults to ``<ClassName>s`` the first time the class is used.
    ``connection`` must be assigned a psycopg2 connection before any query
    method runs.  Column names are loaded lazily by ``connect()`` and become
    the instance attributes of every object.

    All values are sent to the driver as query parameters.  Table and column
    names (including the keyword names passed to ``getBy``, ``fetchBy`` and
    ``deleteBy``) are still interpolated into the SQL text, so they must come
    from trusted code, never from end users.
    """

    names = []         # column names of the table, filled by connect()
    data_types = {}    # column name -> information_schema data_type
    table_name = None  # set to cls.__name__ + 's' by connect() when unset
    connection = None  # psycopg2 connection, assigned by the application

    @classmethod
    def connect(cls):
        """Load the column names and data types of the model's table.

        The results are cached on the class in ``names`` and ``data_types``.
        """
        if not cls.table_name:
            cls.table_name = cls.__name__ + 's'
        # information_schema is searched with the lower-cased name because
        # PostgreSQL folds unquoted identifiers to lower case.
        rows = cls._query(query_column_name.format(cls.table_name.lower()))
        cls.names = [row[0] for row in rows]
        cls.data_types = dict(rows)

    # -- private query helpers ---------------------------------------------

    @classmethod
    def _query(cls, query, params=None):
        """Execute a read query and return all of its rows.

        On any database error the transaction is rolled back before the error
        is re-raised; otherwise the shared connection would stay in an
        aborted transaction and reject every later query.
        """
        cur = cls.connection.cursor()
        try:
            # None (not an empty list) means "no parameters", so literal '%'
            # characters in parameterless queries are not treated as markers.
            cur.execute(query, params or None)
            return cur.fetchall()
        except psycopg2.Error:
            cls.connection.rollback()
            raise
        finally:
            cur.close()

    @classmethod
    def _write(cls, query, params=None):
        """Execute a data-modifying statement and commit it.

        Returns ``Result(True, None)`` on success and ``Result(False, err)``
        when the database reports an integrity violation.  Other database
        errors are re-raised after a rollback.
        """
        cur = cls.connection.cursor()  # created outside try: always bound below
        try:
            cur.execute(query, params or None)
            cls.connection.commit()
            return Result(True, None)
        except psycopg2.IntegrityError as e:
            cls.connection.rollback()
            return Result(False, e)
        except psycopg2.Error:
            cls.connection.rollback()
            raise
        finally:
            cur.close()

    @classmethod
    def _fetch_objects(cls, query, params=None):
        """Run a SELECT and wrap each returned row in an instance of cls."""
        return [cls.getObject(row) for row in cls._query(query, params)]

    @classmethod
    def _where_clause(cls, **kwargs):
        """Build a parameterized filter requiring every column to equal its value.

        Returns ``(sql, params)``.  None values compile to ``IS NULL`` (since
        ``= NULL`` never matches), and an empty filter compiles to ``TRUE``.
        """
        conditions = []
        params = []
        for key, value in kwargs.items():
            if value is None:
                conditions.append("{} IS NULL".format(key))
            else:
                conditions.append("{} = %s".format(key))
                params.append(value)
        return " AND ".join(conditions) or "TRUE", params

    @classmethod
    def _page_clause(cls, page):
        """Return ``(sql, params)`` selecting 1-based page *page* of ``size`` rows.

        Pages below 1 are treated as page 1.  Rows are ordered by ``id`` when
        the table has that column, so consecutive pages do not overlap or skip
        rows.
        """
        page = max(page, 1)
        order = " ORDER BY id" if "id" in cls.names else ""
        return order + " LIMIT %s OFFSET %s", [size, size * (page - 1)]

    # -- reading -------------------------------------------------------------

    @classmethod
    @db_connect
    def get(cls, id):
        """Return the object whose ``id`` equals *id*, or None if there is none."""
        found = cls._fetch_objects(
            "SELECT * FROM {} WHERE id = %s LIMIT 1".format(cls.table_name), [id])
        return found[0] if found else None

    @classmethod
    @db_connect
    def getObject(cls, row):
        """Build an instance from a ``SELECT *`` row by pairing it with ``names``."""
        return cls(**dict(zip(cls.names, row)))

    @classmethod
    @db_connect
    def getBy(cls, **kwargs):
        """Return all objects whose columns equal the given keyword values."""
        where, params = cls._where_clause(**kwargs)
        return cls._fetch_objects(
            "SELECT * FROM {} WHERE {}".format(cls.table_name, where), params)

    @classmethod
    @db_connect
    def fetchBy(cls, page=1, **kwargs):
        """Return page *page* (``size`` rows each) of the ``getBy(**kwargs)`` results.

        An empty list is returned for pages past the end.
        """
        where, params = cls._where_clause(**kwargs)
        paging, page_params = cls._page_clause(page)
        return cls._fetch_objects(
            "SELECT * FROM {} WHERE {}{}".format(cls.table_name, where, paging),
            params + page_params)

    @classmethod
    @db_connect
    def all(cls):
        """Return one object per row of the table."""
        return cls._fetch_objects("SELECT * FROM {}".format(cls.table_name))

    @classmethod
    @db_connect
    def fetch_page(cls, num=1):
        """Return page *num* (``size`` rows each) of all rows; [] past the end."""
        paging, page_params = cls._page_clause(num)
        return cls._fetch_objects(
            "SELECT * FROM {}{}".format(cls.table_name, paging), page_params)

    # -- legacy literal-SQL renderers (kept for backward compatibility) ------

    @classmethod
    def where_str(cls, **kwargs):
        """Render kwargs as a literal ``a = 1 and b = 'x'`` SQL fragment.

        The query methods of this class use parameterized placeholders
        instead; this string form is kept for existing callers.
        """
        return " and ".join(
            "{} = {}".format(key, sql_str_value(value))
            for key, value in kwargs.items())

    @classmethod
    def insert_str(cls, **kwargs):
        """Render ``("(col, ...)", "(value, ...)")`` for the non-None kwargs."""
        present = [key for key in kwargs if kwargs[key] is not None]
        names = "(" + ", ".join(present) + ")"
        values = "(" + ", ".join(sql_str_value(kwargs[key]) for key in present) + ")"
        return names, values

    @classmethod
    def update_str(cls, **kwargs):
        """Render ``("(col, ...)", "(value, ...)")`` for kwargs other than id/None.

        Falsy values such as 0, False and '' are included, so they can be saved.
        """
        present = [key for key in kwargs if key != 'id' and kwargs[key] is not None]
        names = "(" + ", ".join(present) + ")"
        values = "(" + ", ".join(sql_str_value(kwargs[key]) for key in present) + ")"
        return names, values

    # -- writing -------------------------------------------------------------

    @classmethod
    @db_connect
    def new(cls, obj=None, **kwargs):
        """Insert *obj* (an instance of cls) or the keyword values as a new row.

        None values are omitted so the column's database default applies.
        Returns a ``Result``, or None when there is nothing to insert.
        """
        if isinstance(obj, cls):
            names_values = obj.__dict__
        else:
            names_values = kwargs
        if not names_values:
            return
        columns = [key for key in names_values if names_values[key] is not None]
        if columns:
            query = "INSERT INTO {} ({}) VALUES ({})".format(
                cls.table_name, ", ".join(columns), ", ".join(["%s"] * len(columns)))
        else:
            # Every value is None: insert a row made only of column defaults.
            query = "INSERT INTO {} DEFAULT VALUES".format(cls.table_name)
        return cls._write(query, [names_values[key] for key in columns])

    @classmethod
    @db_connect
    def already_exists(cls, obj):
        """Return True if a row with ``obj.id`` exists in the table."""
        obj_id = getattr(obj, 'id', None)
        if obj_id is None:
            return False
        # Ask the database for this one id instead of loading the whole table.
        rows = cls._query(
            "SELECT 1 FROM {} WHERE id = %s LIMIT 1".format(cls.table_name), [obj_id])
        return bool(rows)

    @classmethod
    @db_connect
    def update(cls, obj):
        """Write obj's non-None attributes (except id) to the row with ``obj.id``.

        Returns a ``Result``; an object with nothing to update succeeds
        without touching the database.
        """
        values = obj.__dict__
        columns = [key for key in values if key != 'id' and values[key] is not None]
        if not columns:
            return Result(True, None)
        # "col = %s, ..." works for any number of columns, unlike
        # "(col) = (value)", which PostgreSQL rejects for a single column.
        assignments = ", ".join("{} = %s".format(key) for key in columns)
        query = "UPDATE {} SET {} WHERE id = %s".format(cls.table_name, assignments)
        return cls._write(query, [values[key] for key in columns] + [obj.id])

    @classmethod
    @db_connect
    def deleteBy(cls, **kwargs):
        """Delete every row whose columns equal the given keyword values.

        Returns a ``Result``.  Calling it without any filter is refused
        (``Result(False, ValueError)``) rather than deleting the whole table.
        """
        if not kwargs:
            return Result(False, ValueError("deleteBy() requires at least one filter"))
        where, params = cls._where_clause(**kwargs)
        return cls._write(
            "DELETE FROM {} WHERE {};".format(cls.table_name, where), params)

    # -- instance API --------------------------------------------------------

    @db_connect
    def __init__(self, **kwargs):
        """Create an object with one attribute per table column.

        Columns missing from *kwargs* are set to None, and keyword arguments
        that are not column names are ignored.  Values are stored as given;
        no conversion based on ``data_types`` is applied.
        """
        for name in self.names:
            self.__dict__[name] = kwargs.get(name)

    def save(self):
        """Update this object's row if its id already exists, otherwise insert it."""
        if self.already_exists(obj=self):
            return self.update(obj=self)
        else:
            return self.new(obj=self)

    def drop(self):
        """Delete the row with this object's id and return a ``Result``."""
        if getattr(self, 'id', None) is None:
            return Result(False, ValueError("cannot drop an object without an id"))
        return self.deleteBy(id=self.id)

    def __repr__(self):
        """Return e.g. ``{User id = 1, name = 'x'}``."""
        fields = ", ".join(
            "{} = {}".format(name, sql_str_value(getattr(self, name, None)))
            for name in self.names)
        return "{" + " ".join(part for part in (type(self).__name__, fields) if part) + "}"
def sql_str_value(value):
    """Render *value* as a PostgreSQL literal for hand-built SQL text.

    - ints (including bools, since bool subclasses int) are emitted bare;
    - None becomes NULL;
    - everything else is ``str()``-ed and single-quoted, with embedded single
      quotes doubled so the literal cannot be closed early.

    Doubling quotes is sufficient only while standard_conforming_strings is
    on (the PostgreSQL default). For values coming from users, prefer
    driver-side parameters (``cursor.execute(sql, params)``).
    """
    if value is None:
        return "NULL"
    if isinstance(value, int):
        return str(value)
    return "'{}'".format(str(value).replace("'", "''"))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment