Last active
October 28, 2024 08:28
-
-
Save emmanuelnk/db62507184125ddfe24844bb552fc26d to your computer and use it in GitHub Desktop.
Python SQLAlchemy Basic Model, Session, DB Connection Classes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import event | |
import os | |
import logging | |
import sqlalchemy | |
import boto3 | |
import base64 | |
import json | |
from botocore.exceptions import ClientError | |
# This is the *root* logger: setting INFO on it affects every logger in the
# process that has no level of its own.
logger = logging.getLogger(None)
logger.setLevel("INFO")
class DB:
    """Process-wide holder of one SQLAlchemy engine for PostgreSQL.

    Use ``DB.create()`` to obtain the shared instance; calling ``DB()``
    directly once an instance exists raises. Credentials come from AWS
    Secrets Manager when ``SECRETSMANAGER_RDS_PG_ID`` is set, otherwise from
    ``POSTGRESQL_*`` environment variables (see ``get_credentials``).
    """

    # The single shared instance (name-mangled to ``_DB__instance``).
    __instance = None

    def __init__(self):
        """Virtually private constructor; call ``DB.create()`` instead.

        Raises:
            Exception: if a ``DB`` instance has already been registered.
        """
        if DB.__instance is not None:
            raise Exception(
                "This class is a singleton, use DB.create()")
        # Build the engine *before* registering the singleton. If this fails
        # (secret lookup error, unknown dialect, missing driver, ...), no
        # half-initialised instance lacking ``engine`` is left behind for
        # later ``create()`` calls to hand out.
        self.engine = self.create_engine()
        DB.__instance = self

    @staticmethod
    def create():
        """Return the shared ``DB`` instance, creating it on first call."""
        if DB.__instance is None:
            DB.__instance = DB()
        return DB.__instance

    @staticmethod
    def get_secret(secret_name):
        """Fetch a secret from AWS Secrets Manager and JSON-decode it.

        Args:
            secret_name: the ``SecretId`` to retrieve.

        Returns:
            The secret payload parsed with ``json.loads``.

        Raises:
            botocore.exceptions.ClientError: for any Secrets Manager error,
                whatever its error code.
        """
        client = boto3.client('secretsmanager')
        # Let every ClientError propagate unchanged: swallowing unlisted
        # codes only hides the real cause behind a later, unrelated failure.
        response = client.get_secret_value(SecretId=secret_name)
        if 'SecretString' in response:
            secret = response['SecretString']
        else:
            # Binary secrets are base64-decoded before JSON parsing.
            secret = base64.b64decode(response['SecretBinary'])
        return json.loads(secret)

    @staticmethod
    def get_credentials():
        """Return the connection settings as a dict.

        When ``SECRETSMANAGER_RDS_PG_ID`` is unset (e.g. for local testing),
        values come from ``POSTGRESQL_USER``, ``POSTGRESQL_PASSWORD``,
        ``POSTGRESQL_HOST``, ``POSTGRESQL_PORT`` and ``POSTGRESQL_DATABASE``,
        with development defaults. Otherwise they are read from that secret,
        whose ``dbname`` key is mapped to ``database``.

        Returns:
            dict with keys ``username``, ``password``, ``host``, ``port``
            and ``database``. ``port`` may be a str or an int.
        """
        secret_id = os.getenv('SECRETSMANAGER_RDS_PG_ID')
        if secret_id is None:
            return {
                'username': os.getenv('POSTGRESQL_USER', 'postgres'),
                'password': os.getenv('POSTGRESQL_PASSWORD', 'some_password'),
                'host': os.getenv('POSTGRESQL_HOST', 'localhost'),
                'port': os.getenv('POSTGRESQL_PORT', 5432),
                'database': os.getenv('POSTGRESQL_DATABASE', 'user_database'),
            }
        # Get all access credentials from Secrets Manager.
        credentials = DB.get_secret(secret_id)
        return {
            'username': credentials['username'],
            'password': credentials['password'],
            'host': credentials['host'],
            'port': credentials['port'],
            'database': credentials['dbname'],
        }

    @staticmethod
    def _env_flag(name, default=False):
        """Interpret environment variable *name* as a boolean flag.

        Unset -> *default*. "1", "true", "yes" or "on" (case-insensitive,
        surrounding whitespace ignored) -> True. Anything else, including
        "0", "false" and "", -> False.
        """
        value = os.getenv(name)
        if value is None:
            return default
        return value.strip().lower() in ('1', 'true', 'yes', 'on')

    @staticmethod
    def _build_url(credentials, drivername='postgresql+psycopg2'):
        """Build a SQLAlchemy database URL from a ``get_credentials()`` dict.

        The username and password are percent-encoded, so characters such as
        ``@``, ``:`` or ``/`` cannot corrupt the URL. The default dialect
        name is ``postgresql+psycopg2``; the bare ``postgres`` alias is
        rejected by SQLAlchemy 1.4 and later.
        """
        from urllib.parse import quote

        return '{engine}://{user}:{password}@{host}:{port}/{database}'.format(
            engine=drivername,
            user=quote(str(credentials['username']), safe=''),
            password=quote(str(credentials['password']), safe=''),
            host=credentials['host'],
            port=int(credentials['port']),
            database=credentials['database'],
        )

    def create_engine(self):
        """Create the pooled SQLAlchemy engine from ``get_credentials()``.

        The pool holds up to 200 connections with no overflow. SQL echo is
        enabled when ``POSTGRESQL_DEBUG`` is a truthy string such as "1" or
        "true".
        """
        credentials = DB.get_credentials()
        return sqlalchemy.create_engine(
            DB._build_url(credentials),
            pool_size=200,
            max_overflow=0,
            echo=DB._env_flag('POSTGRESQL_DEBUG'),
        )

    def connect(self):
        """Return a new connection from ``self.engine.connect()``."""
        return self.engine.connect()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from sqlalchemy import Column, Integer, String, DateTime, UniqueConstraint, func | |
from sqlalchemy.ext.declarative import declarative_base | |
from .db import DB | |
# Declarative base that the models in this module register with, plus the
# shared engine taken from the DB singleton.
Base = declarative_base()
db = DB.create()
engine = db.engine
@dataclass
class User(Base):
    """Declarative model for the "user" table.

    Only the annotated columns (first_name, last_name, id_card_no, email)
    become dataclass fields, so ``dataclasses.asdict()`` on a User omits
    ``id``, ``created_at`` and ``updated_at``.
    """
    __tablename__ = 'user'
    # The (email, id_card_no) *pair* must be unique. The same email may still
    # appear with a different id card, and vice versa.
    __table_args__ = (UniqueConstraint('email', 'id_card_no'),)
    # Unannotated on purpose: not a dataclass field (see class docstring).
    id = Column(Integer, primary_key=True)
    first_name: str = Column(String)
    last_name: str = Column(String)
    id_card_no: str = Column(String)
    email: str = Column(String)
    # Both timestamps default to the database's now(); updated_at is also
    # refreshed on every ORM-issued UPDATE via onupdate.
    created_at = Column(DateTime(timezone=True), default=func.now())
    updated_at = Column(DateTime(timezone=True),
                        default=func.now(), onupdate=func.now())
# Create any missing tables at import time. create_all() only creates tables
# that do not exist yet; it never alters an existing table. After changing a
# model, you must drop the table (or migrate it) before the change reaches
# the database.
Base.metadata.create_all(engine)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy.orm import sessionmaker | |
from .session import SessionHandler | |
from .models import User | |
from .db import DB | |
# Example: insert one user inside a single transaction.
db = DB.create()
engine = db.engine
Session = sessionmaker(bind=engine)
session = Session()
try:
    user_session = SessionHandler.create(session, User)
    # Add a new record. example.com is a domain reserved for documentation,
    # so this placeholder address is valid but never deliverable.
    user_session.add({
        "first_name": "john",
        "last_name": "doe",
        "id_card_no": "1234598765",
        "email": "john.doe@example.com",
    })
    session.commit()
except Exception:
    # Undo the partial transaction, then re-raise with the original traceback.
    session.rollback()
    raise
finally:
    session.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import datetime | |
import time | |
from dataclasses import asdict | |
from sqlalchemy.dialects.postgresql import insert as pg_insert | |
from sqlalchemy import UniqueConstraint | |
from . import DB | |
class SchemaEncoder(json.JSONEncoder):
    """JSON encoder for model dicts that may contain date/time values.

    * ``datetime.datetime`` -> ``"YYYY-MM-DDTHH:MM:SSZ"`` via
      ``utctimetuple()``: aware values are converted to UTC, and naive values
      are formatted as-is (assumed to already be UTC).
    * ``datetime.date`` -> ISO ``"YYYY-MM-DD"``.

    Any other unsupported type is delegated to the base class, which raises
    ``TypeError``.
    """

    def default(self, obj):
        # datetime is a subclass of date, so it must be checked first.
        if isinstance(obj, datetime.datetime):
            return time.strftime('%Y-%m-%dT%H:%M:%SZ', obj.utctimetuple())
        if isinstance(obj, datetime.date):
            # Plain dates have no utctimetuple(), so they cannot take the
            # datetime path above.
            return obj.isoformat()
        return json.JSONEncoder.default(self, obj)
class SessionHandler():
    """CRUD helper that binds a SQLAlchemy session to one mapped model.

    Read methods return plain dicts built with ``dataclasses.asdict`` (so the
    model must be a dataclass). When ``to_json`` is truthy they return JSON
    strings produced by ``to_json`` instead. No method commits or rolls back;
    transaction control stays with the caller that owns ``session``.
    """

    # Most recently created handler. Every create() call replaces it, so this
    # is a "last instance" reference rather than a true singleton.
    __instance = None

    def __init__(self, session, model):
        """Bind *session* to *model* (prefer ``SessionHandler.create``)."""
        SessionHandler.__instance = self
        self.model = model
        self.session = session

    @staticmethod
    def create(session, model):
        """Return a new handler for *model* on *session*."""
        SessionHandler.__instance = SessionHandler(session, model)
        return SessionHandler.__instance

    def _serialize(self, record_obj, to_json):
        """Return *record_obj* as a dict, or as JSON when *to_json* is truthy.

        ``None`` (no matching row) is passed through unchanged instead of
        crashing inside ``asdict``.
        """
        if record_obj is None:
            return None
        return self.to_json(record_obj) if to_json else asdict(record_obj)

    def add(self, record_dict):
        """Stage a new model instance built from *record_dict* as kwargs."""
        record_model = self.model(**record_dict)
        self.session.add(record_model)

    def insert_many(self, record_list):
        """INSERT each dict using PostgreSQL ``ON CONFLICT DO NOTHING``.

        Returns the list of per-statement results from ``session.execute``.
        """
        statements = [
            pg_insert(self.model).values(record_dict).on_conflict_do_nothing()
            for record_dict in record_list
        ]
        return [self.session.execute(statement) for statement in statements]

    def add_many(self, record_list):
        """Stage one new model instance per dict in *record_list*."""
        return self.session.add_all([self.model(**record_dict) for record_dict in record_list])

    def update(self, query_dict, update_dict):
        """Bulk-update rows matching *query_dict* (``filter_by`` kwargs)."""
        return self.session.query(self.model).filter_by(**query_dict).update(update_dict)

    def upsert(self, record_dict, set_dict, constraint):
        """INSERT *record_dict*; on conflict with *constraint*, apply *set_dict*."""
        statement = pg_insert(self.model).values(record_dict).on_conflict_do_update(
            constraint=constraint,
            set_=set_dict,
        )
        return self.session.execute(statement)

    # ``id`` shadows the builtin, but is kept because callers may pass it by keyword.
    def get(self, id, to_json=None):
        """Return the row with primary key *id*, or ``None`` if there is none."""
        result = self.session.query(self.model).get(id)
        return self._serialize(result, to_json)

    def get_one(self, query_dict, to_json=None):
        """Return the first row matching *query_dict*, or ``None``."""
        result = self.session.query(self.model).filter_by(**query_dict).first()
        return self._serialize(result, to_json)

    def get_latest(self, query_dict, to_json=None):
        """Return the matching row with the newest ``updated_at``, or ``None``.

        Requires the model to have an ``updated_at`` column.
        """
        result = self.session.query(self.model).filter_by(**query_dict).order_by(self.model.updated_at.desc()).first()
        return self._serialize(result, to_json)

    def get_count(self, query_dict, to_json=None):
        """Return the number of rows matching *query_dict*.

        ``to_json`` is accepted only for signature symmetry and is ignored.
        """
        return self.session.query(self.model).filter_by(**query_dict).count()

    def get_all(self, query_dict, to_json=None):
        """Return every row matching *query_dict* as a list of dicts (or JSON strings)."""
        results = self.session.query(self.model).filter_by(**query_dict).all()
        return [self._serialize(result, to_json) for result in results]

    def delete(self, query_dict):
        """Bulk-delete rows matching *query_dict*; returns the query's delete() result."""
        return self.session.query(self.model).filter_by(**query_dict).delete()

    def to_json(self, record_obj):
        """Serialize one model instance to a JSON string using ``SchemaEncoder``."""
        return json.dumps(asdict(record_obj), cls=SchemaEncoder, ensure_ascii=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I think this is a very nice implementation.
I need to auto-generate models from existing database tables, like Entity Framework does.