Created
August 8, 2016 19:10
-
-
Save wmantly/62f55a024b996487e68dcf209ace42b9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import sqlite3 | |
# Module-level SQLite connection to the local 'babyorm.db' database file.
conn = sqlite3.connect( 'babyorm.db' )
# Single cursor shared by every Model query in this module.
c = conn.cursor()
class Model(dict):
    """Tiny active-record style base class stored in SQLite.

    A subclass maps to the table whose name equals the subclass name
    (``class Users(Model)`` reads and writes the ``Users`` table).  An
    instance is a ``dict`` holding one row; only keys that are real table
    columns are accepted into the dict.

    All statements run through the module-level cursor ``c`` and are
    committed on the module-level connection ``conn``.

    NOTE: the ``get`` and ``update`` methods defined here shadow
    ``dict.get`` and ``dict.update``; use ``dict.get(instance, key)`` when
    the plain mapping behaviour is needed.
    """

    def __init__(self, **kwargs):
        """Build a row object from ``column=value`` keyword arguments."""
        self.columns = self.__columns()
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setitem__(self, key, value):
        """Store ``value`` only if ``key`` is a table column.

        Returns ``False`` (and stores nothing) for unknown keys.
        """
        if key in self.columns:
            return super().__setitem__(key, value)
        return False

    def __setattr__(self, key, value):
        """Mirror attribute assignment into the dict (columns only)."""
        self.__setitem__(key, value)
        return super().__setattr__(key, value)

    @staticmethod
    def __build_clause(obj, on_join=" AND "):
        """Turn a mapping into a parameterized SQL fragment.

        Returns ``(sql, params)`` where ``sql`` looks like
        ``"a = ? AND b = ?"`` and ``params`` is the matching tuple of
        values.  Values are bound by sqlite3 instead of being formatted
        into the statement, so quotes in values are safe and ``None`` is
        stored as NULL.  Column names are still interpolated verbatim and
        must therefore be trusted identifiers.
        """
        names = list(obj)
        sql = on_join.join("{} = ?".format(name) for name in names)
        return sql, tuple(obj[name] for name in names)

    @classmethod
    def __row_to_dict(cls, row):
        """Pair a result row with the table's column names."""
        return dict(zip(cls.__columns(), row))

    @classmethod
    def __columns(cls, force_update=False):
        """Return the table's column names as a tuple, cached per class.

        The cache lives in the class's own namespace so subclasses never
        share it.  An empty result is not treated as cached, so a table
        created after the first lookup is still picked up.
        """
        if force_update or not vars(cls).get('columns'):
            c.execute('pragma table_info("{}")'.format(cls.__name__))
            # Column 1 of each table_info row is the column name.
            cls.columns = tuple(d[1] for d in c)
        return cls.columns

    @classmethod
    def __select(cls, criteria, suffix=""):
        """Execute ``SELECT *`` on the table, filtered by ``criteria``."""
        where, params = cls.__build_clause(criteria)
        sql = "SELECT * FROM {}".format(cls.__name__)
        if where:
            # Skip the WHERE keyword entirely when there is nothing to
            # filter on; a dangling WHERE is a syntax error.
            sql += " WHERE " + where
        c.execute(sql + suffix, params)

    @classmethod
    def all(cls):
        """Return every row of the table as a list of instances."""
        c.execute("SELECT * FROM {}".format(cls.__name__))
        return [cls(**cls.__row_to_dict(row)) for row in c.fetchall()]

    @classmethod
    def get(cls, *args, **kwargs):
        """Return the first matching row, or ``None`` if there is none.

        ``get(5)`` looks up ``id = 5``; otherwise the keyword arguments
        are combined with AND as equality filters.
        """
        criteria = {'id': args[0]} if args else kwargs
        cls.__select(criteria, " LIMIT 1")
        row = c.fetchone()
        if row is None:
            return None
        return cls(**cls.__row_to_dict(row))

    @classmethod
    def filter(cls, **kwargs):
        """Return all rows matching the keyword equality filters.

        With no filters this returns every row.
        """
        cls.__select(kwargs)
        return [cls(**cls.__row_to_dict(row)) for row in c.fetchall()]

    def save(self):
        """Insert the row if it has no id yet, otherwise update it."""
        # dict.get is called explicitly because Model.get shadows it.
        if dict.get(self, 'id') is not None:
            return self.update()
        return self.create()

    def create(self):
        """INSERT this row, record the new ``id`` on it, and commit."""
        name = self.__class__.__name__
        keys = list(self.keys())
        if keys:
            sql_string = "INSERT INTO {name} ({keys}) VALUES ({values})".format(
                name=name,
                keys=','.join(keys),
                values=','.join('?' for _ in keys),
            )
        else:
            # "INSERT INTO t () VALUES ()" is not valid SQLite syntax.
            sql_string = "INSERT INTO {name} DEFAULT VALUES".format(name=name)
        c.execute(sql_string, tuple(self[key] for key in keys))
        setattr(self, 'id', c.lastrowid)
        conn.commit()
        return self

    def update(self):
        """UPDATE the row whose id equals ``self['id']`` and commit."""
        assignments, params = self.__build_clause(self, ', ')
        c.execute(
            "UPDATE {name} SET {values} WHERE id = ?".format(
                name=self.__class__.__name__,
                values=assignments,
            ),
            params + (self['id'],),
        )
        conn.commit()
        return self
| ###don't touch the code for these | |
# Model for the "Users" table; Model derives the table name from the class name.
class Users(Model):
    pass
# Model for the "Stocks" table; Model derives the table name from the class name.
class Stocks(Model):
    pass
# Script entry point: intentionally a no-op placeholder.
if __name__ == '__main__':
    pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment