Created
November 23, 2024 11:57
-
-
Save Rydgel/9176b3c659d29fa7e9619019c6839ff7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import os
from urllib.parse import quote_plus

import jwt
from flask import Flask, request
from sqlalchemy import Column, UUID, String
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker, declarative_base
app = Flask(__name__)

# Supabase connection configuration
SUPABASE_URL = "your_supabase_url"
# The database password is read from the SUPABASE_KEY environment variable.
# It is percent-encoded before being placed in the URL because characters
# such as '@', '/' or ':' in a raw password would otherwise be parsed as URL
# delimiters. An unset variable yields an empty password rather than the
# literal string "None".
SUPABASE_DB_URL = (
    f"postgresql://postgres:{quote_plus(os.getenv('SUPABASE_KEY', ''))}"
    f"@db.{SUPABASE_URL}:5432/postgres"
)

# SQLAlchemy setup
Base = declarative_base()
class User(Base):
    """ORM mapping for rows of the ``users`` table."""

    __tablename__ = "users"

    # Primary key, typed as a UUID.
    id = Column(UUID, primary_key=True)
    first_name = Column(String)
    # Declared with a uniqueness constraint.
    last_name = Column(String, unique=True)
# Session subclass that carries the JWT claims to apply as request.jwt.claims
class JWTSession(Session):
    """SQLAlchemy session that remembers the caller's JWT claims.

    The claims are kept on the instance as a JSON string in ``jwt_claims``.
    """

    def __init__(self, bind, jwt_token, **kwargs):
        """Create the session and store the serialized claims.

        Args:
            bind: Engine or connection passed through to ``Session``.
            jwt_token: JSON-serializable object holding the JWT claims.
            **kwargs: Any further ``Session`` keyword arguments.
        """
        super().__init__(bind=bind, **kwargs)
        serialized_claims = json.dumps(jwt_token)
        self.jwt_claims = serialized_claims
# Hook to reapply SET LOCAL on transaction start
def set_jwt_claims(session, connection, claims):
    """Switch the current transaction to the ``authenticated`` role and
    publish the JWT claims through the ``request.jwt.claims`` setting.

    Both settings are transaction-local (``SET LOCAL`` and the ``true``
    third argument of ``set_config``).

    Args:
        session: The owning session. Unused here; kept so the signature
            stays compatible with existing callers.
        connection: Connection on which the statements are executed.
        claims: JSON string containing the JWT claims.
    """
    # The role name is a fixed literal, so no parameter is needed here.
    connection.execute(text("SET LOCAL ROLE authenticated"))
    # Claims are passed as a bound parameter instead of being spliced into
    # the SQL text: a quote character inside any claim value would otherwise
    # break the statement or allow SQL injection.
    connection.execute(
        text("SELECT set_config('request.jwt.claims', :claims, true)"),
        {"claims": claims},
    )
# Configure the session factory
_ENGINE = None


def _get_engine():
    """Return the shared engine, creating it on first use."""
    global _ENGINE
    if _ENGINE is None:
        _ENGINE = create_engine(SUPABASE_DB_URL, echo=True)
    return _ENGINE


def get_jwt_session(jwt_token):
    """Build a ``JWTSession`` that applies the JWT claims to every
    transaction it begins.

    The engine (and its connection pool) is created once and reused;
    creating a fresh engine per call would open a new pool each time and
    never dispose of it. Because the engine is shared, the hook is attached
    to this session instance rather than to the engine, so one caller's
    claims are never applied to another caller's connection.

    Args:
        jwt_token: Decoded JWT claims (JSON-serializable).

    Returns:
        A ``JWTSession`` bound to the shared engine.
    """
    session_factory = sessionmaker(bind=_get_engine(), class_=JWTSession)
    session = session_factory(jwt_token=jwt_token)

    # "after_begin" fires each time the session begins a transaction on a
    # connection, including the new transaction started after a commit.
    @event.listens_for(session, "after_begin")
    def on_begin_transaction(sess, transaction, conn):
        set_jwt_claims(sess, conn, sess.jwt_claims)

    return session
@app.route('/users', methods=['GET'])
def get_users():
    """Rename the users visible to the caller and return them.

    Expects an ``Authorization: Bearer <jwt>`` header. The token is verified
    with HS256 and the ``authenticated`` audience, and its claims are
    attached to the database session.

    Returns:
        A list of ``{'id', 'first_name', 'last_name'}`` dicts on success, or
        an ``({'error': ...}, 401)`` response when the header is missing,
        malformed, or the token fails verification.
    """
    # A missing header becomes '' so that it is rejected below instead of
    # raising TypeError on slicing None.
    auth_header = request.headers.get('Authorization', '')
    scheme, _, token = auth_header.partition(' ')
    token = token.strip()
    if scheme.lower() != 'bearer' or not token:
        return {'error': 'Missing or malformed Authorization header'}, 401

    try:
        # PyJWT expects a list of allowed algorithms.
        jwt_token = jwt.decode(
            token,
            "your_jwt_token_secret",
            audience="authenticated",
            algorithms=["HS256"],
        )
    except jwt.InvalidTokenError:
        return {'error': 'Invalid or expired token'}, 401

    with get_jwt_session(jwt_token) as session:
        # should only be 1 user
        users = session.scalars(select(User)).all()
        for user in users:
            user.first_name = "Jerome"
        session.commit()

        # we still get the auth applied after a commit
        users = session.scalars(select(User)).all()
        return [
            {'id': user.id, 'first_name': user.first_name, 'last_name': user.last_name}
            for user in users
        ]
# Start the Flask server (debug mode on) only when this file is run directly.
if __name__ == "__main__":
    app.run(debug=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment