Created February 12, 2023 21:53
Save alkemann/61b9bcfab5a98504dc678d1c04b71c51 to your computer and use it in GitHub Desktop.
Dataclass with database access in BaseModel
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sqlite3 | |
from contextlib import contextmanager | |
from util import BaseModel | |
from dataclasses import dataclass | |
from util import get_db | |
import dataclasses | |
from typing import Iterable | |
import logging | |
# Module-level logger named after the module (the logging convention), rather
# than after the file path, so the name is stable across machines and fits
# into the dotted logger hierarchy.
log = logging.getLogger(__name__)
@dataclass
class Blog(BaseModel):
    """A blog record stored in the ``blogs`` table.

    The declaration order of the dataclass fields (id, name, body) is the
    column order BaseModel uses when it builds its queries.
    """

    # Plain (unannotated) class attribute, so it is not a dataclass field.
    __table__ = 'blogs'

    id: int
    name: str
    body: str
@contextmanager
def get_db(file_name: str = "database.db"):
    """Yield a cursor on a SQLite connection to ``file_name``.

    The transaction is committed only when the ``with`` body finishes
    normally. If the body raises, the transaction is rolled back so that
    partial writes are not persisted, and the exception is re-raised. The
    connection is closed in every case.
    """
    connection = sqlite3.connect(file_name)
    try:
        yield connection.cursor()
    except BaseException:
        # BaseException (not just Exception) so KeyboardInterrupt and an
        # abandoned generator (GeneratorExit) also discard uncommitted work.
        connection.rollback()
        raise
    else:
        connection.commit()
    finally:
        connection.close()
class BaseModel:
    """Mixin that adds simple SQLite CRUD helpers to a ``@dataclass`` model.

    Subclasses must be dataclasses whose field declaration order matches the
    table's column order, and must define ``__table__`` (the table name).
    ``__pk__`` names the primary-key column and defaults to ``'id'``.

    Table and column names are interpolated directly into SQL text, so they
    must come from trusted class definitions, never from user input. Row
    values are always passed as ``?`` parameters.
    """

    __pk__ = 'id'

    @classmethod
    def fields(cls) -> str:
        """Return the dataclass field names as a comma-separated column list."""
        return ", ".join(f.name for f in dataclasses.fields(cls))

    @classmethod
    def fields_to_update(cls) -> str:
        """Return an UPDATE ``SET`` body such as ``"id = ?, name = ?"``."""
        return ", ".join(f"{f.name} = ?" for f in dataclasses.fields(cls))

    @classmethod
    def _placeholders(cls) -> str:
        """Return one ``?`` placeholder per dataclass field, comma-separated."""
        return ", ".join("?" for _ in dataclasses.fields(cls))

    def _field_values(self) -> list:
        """Return this instance's field values in dataclass declaration order.

        Reading the declared fields explicitly (rather than ``self.__dict__``)
        keeps the values aligned with ``fields()`` even if extra, non-column
        attributes have been set on the instance.
        """
        return [getattr(self, f.name) for f in dataclasses.fields(self)]

    @classmethod
    def getById(cls, id: int):
        """Return the one instance whose primary key equals ``id``.

        Raises:
            LookupError: if no row matches. This is a subclass of
                ``Exception``, so existing ``except Exception`` handlers
                still catch it.
        """
        query = f"SELECT {cls.fields()} FROM {cls.__table__} WHERE {cls.__pk__} = ?"
        params = [id]
        log.debug(query)
        log.debug(params)
        with get_db() as db:
            row = db.execute(query, params).fetchone()
        if row is None:
            raise LookupError(f"Could not find {id} of type {cls.__name__}")
        return cls(*row)

    @classmethod
    def list(cls) -> Iterable:
        """Yield one instance for every row in the table.

        All rows are fetched before the first ``yield``, so the connection is
        released immediately. This lets callers write to the database (e.g.
        ``obj.update()``) while iterating, and an abandoned iteration does not
        keep a connection open.
        """
        query = f"SELECT {cls.fields()} FROM {cls.__table__}"
        log.debug(query)
        with get_db() as db:
            rows = db.execute(query).fetchall()
        for row in rows:
            yield cls(*row)

    def update(self) -> bool:
        """Write every field of this instance to its row, matched by ``__pk__``.

        Returns:
            True if exactly one row was affected.
        """
        params = [*self._field_values(), getattr(self, self.__pk__)]
        query = f"UPDATE {self.__table__} SET {self.fields_to_update()} WHERE {self.__pk__} = ?"
        log.debug(query)
        log.debug(params)
        with get_db() as db:
            result = db.execute(query, params)
            return result.rowcount == 1

    def insert(self) -> bool:
        """Insert this instance as a new row, supplying every field (including the pk).

        Returns:
            True if exactly one row was inserted.
        """
        params = self._field_values()
        # One placeholder per field, so models of any width work.
        query = f"INSERT INTO {self.__table__} ({self.fields()}) VALUES ({self._placeholders()})"
        log.debug(query)
        log.debug(params)
        with get_db() as db:
            result = db.execute(query, params)
            return result.rowcount == 1

    def getHasMany(self, model, join_table: str, other_foreign_key: str, foreign_key: str) -> Iterable:
        """Yield ``model`` instances linked to this record through ``join_table``.

        Args:
            model: BaseModel subclass; each related key is loaded with
                ``model.getById``.
            join_table: name of the link table.
            other_foreign_key: column in ``join_table`` holding the related
                ``model`` primary keys.
            foreign_key: column in ``join_table`` compared against this
                record's primary-key value.

        The three names are interpolated into SQL and must be trusted
        identifiers. The related keys are collected before any ``getById``
        call, so no nested connection is opened while a read is still in
        progress.
        """
        query = f"SELECT {other_foreign_key} FROM {join_table} WHERE {foreign_key} = ?"
        params = [getattr(self, self.__pk__)]
        log.debug(query)
        log.debug(params)
        with get_db() as db:
            related_ids = [row[0] for row in db.execute(query, params)]
        for related_id in related_ids:
            yield model.getById(related_id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.