Created February 7, 2022 12:37
Save ant1fact/810e7e2dee3b362e4a08267dcc3ae69d to your computer and use it in GitHub Desktop.
Easily get a list of the required fields of a Flask-SQLAlchemy model, i.e. the columns declared with `db.Column(..., nullable=False)`.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from flask_sqlalchemy import Model, SQLAlchemy | |
from sqlalchemy import inspect | |
# Extend the base Model so every model class gains extra helper methods.
# For more details see:
# https://flask-sqlalchemy.palletsprojects.com/en/2.x/customizing/#model-class
class ExtendedModel(Model):
    """Custom base model class that adds introspection helpers to all models."""

    @classmethod
    def required_fields(cls, *, exclude=('id',)) -> list:
        """Return the column names of this model declared with ``nullable=False``.

        Args:
            exclude: Column names to leave out of the result, even if they are
                non-nullable. Defaults to ``('id',)``, which preserves the
                original behaviour of skipping the ``id`` column.

        Returns:
            A list of ``Column.name`` values in the mapper's column order.
        """
        # A bare string would otherwise be split into single characters by set().
        if isinstance(exclude, str):
            exclude = (exclude,)
        skip = set(exclude)
        mapper = inspect(cls)
        return [
            column.name
            for column in mapper.columns
            if not column.nullable and column.name not in skip
        ]
db = SQLAlchemy(model_class=ExtendedModel)


class Item(db.Model):
    """Example model: ``name`` must be provided, ``description`` is optional."""

    # Primary key columns are implicitly NOT NULL.
    id = db.Column(db.Integer, primary_key=True)
    # Explicitly NOT NULL, so it is a required field.
    name = db.Column(db.String(50), nullable=False)
    # Nullable (the default for non-primary-key columns), so not required.
    description = db.Column(db.String(250), nullable=True)
if __name__ == "__main__":
    # Expected output: ['name'] ('id' is excluded and 'description' is nullable).
    required = Item.required_fields()
    print(required)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.