Last active: January 7, 2024 15:49
Save plockaby/6fcdf51148f8b7035b8f451ebe583a82 to your computer and use it in GitHub Desktop.
Python Flask Connection Pool
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import psycopg2 | |
from psycopg2.extras import RealDictCursor | |
from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN, TRANSACTION_STATUS_IDLE | |
from flask import g | |
import threading | |
import tenacity | |
import uuid | |
import pwd | |
import os | |
# Dedicated module-level logger, named after this module, so pool and client
# diagnostics can be filtered separately from the application's own logs.
logger = logging.getLogger(name=__name__)
class PoolError(psycopg2.Error):
    """Pool-level failure: pool exhausted, unknown key, or a request's
    connection was lost."""
class ConnectionPool:
    """
    Pool of psycopg2 connections, each handed out under a caller-chosen key.

    A key (any hashable value) owns at most one connection at a time: calling
    getconn() again with the same key returns the same connection until it is
    released with putconn(). getconn() and putconn() serialize on a lock.
    """

    def __init__(self, minconn, maxconn, *args, **kwargs):
        """
        :param minconn: maximum number of idle connections kept for reuse;
            released connections beyond this are closed. No connections are
            opened up front.
        :param maxconn: maximum number of connections handed out at once.
        :param args: positional arguments for psycopg2.connect().
        :param kwargs: keyword arguments for psycopg2.connect().
        """
        self.minconn = int(minconn)
        self.maxconn = int(maxconn)
        self._args = args
        self._kwargs = kwargs
        self._pool = []  # idle connections available for reuse
        self._used = {}  # key -> connection currently handed out
        # guards _pool and _used across threads
        self._lock = threading.RLock()

    def getconn(self, key):
        """
        Return the connection associated with ``key``, assigning one if needed.

        A key that already holds a connection gets that same connection back.
        A new key takes an idle connection from the pool, or a freshly opened
        one (see _connect) when none is idle.

        :raises PoolError: if ``maxconn`` connections are already handed out.
        """
        with self._lock:
            # this key already has a connection so return it
            conn = self._used.get(key)
            if conn is not None:
                return conn

            if self._pool:
                conn = self._pool.pop()
            else:
                # we've given out all of the connections that we want to
                if len(self._used) >= self.maxconn:
                    raise PoolError("connection pool exhausted")
                # NOTE(review): _connect() retries with no stop condition and
                # runs while this lock is held, so other threads block here
                # until the database becomes reachable again.
                conn = self._connect()

            self._used[key] = conn
            return conn

    def putconn(self, key, close=False):
        """
        Release the connection held under ``key``.

        The connection goes back to the idle pool when fewer than ``minconn``
        connections are idle and ``close`` is false; otherwise it is closed.
        Before pooling, a connection left in a transaction or in an error
        state is rolled back. A connection whose server link is lost, that is
        already closed, or that fails to roll back is discarded.

        :param key: the key the connection was obtained with.
        :param close: if true, always close the connection instead of pooling.
        :raises PoolError: if no connection is held under ``key``.
        """
        with self._lock:
            # Look the key up directly. Going through getconn() here would
            # hand out a brand-new connection for an unknown key instead of
            # reporting the error.
            conn = self._used.get(key)
            if conn is None:
                raise PoolError("no connection with that key")

            try:
                if close or len(self._pool) >= self.minconn:
                    conn.close()
                elif not conn.closed:
                    # Return the connection into a consistent state before
                    # putting it back into the pool.
                    status = conn.info.transaction_status
                    if status == TRANSACTION_STATUS_UNKNOWN:
                        # server connection lost
                        conn.close()
                    else:
                        if status != TRANSACTION_STATUS_IDLE:
                            # connection in error or in transaction
                            conn.rollback()
                        self._pool.append(conn)
                # an already-closed connection is simply dropped
            except psycopg2.Error as e:
                # e.g. rollback() on a connection whose server went away
                logger.warning("discarding connection %s that could not be reset: %s", key, e)
                conn.close()
            finally:
                # Always release the key so a failure above cannot leave a
                # dead connection counted against maxconn forever.
                self._used.pop(key, None)

    # Wait 0.5 s plus a random 0-1.5 s (0.5-2.0 s in total) between attempts
    # and log before each attempt at DEBUG. No stop condition is configured,
    # so this keeps retrying until a connection succeeds.
    @tenacity.retry(
        wait=tenacity.wait_fixed(0.5) + tenacity.wait_random(0, 1.5),
        before=tenacity.before_log(logger, logging.DEBUG),
    )
    def _connect(self):
        """
        Open a new connection using the arguments given to the pool.

        Autocommit is enabled so every pooled connection starts in the same
        mode.
        """
        conn = psycopg2.connect(*self._args, **self._kwargs)
        conn.autocommit = True
        return conn
class DatabaseClient:
    """
    Flask extension giving each request a single pooled PostgreSQL connection.

    Views call conn() to obtain the request's connection. A teardown_request
    handler registered by init_app() returns it to the pool when the request
    ends.
    """

    def __init__(self, app=None, **kwargs):
        """
        Bind to ``app`` immediately when it is given, forwarding ``kwargs`` to
        init_app(). Otherwise call init_app() later; ``kwargs`` passed here
        are ignored in that case.
        """
        if app is not None:
            self.init_app(app, **kwargs)
        else:
            self.app = None

    def init_app(self, app, key="default", minconn=2, maxconn=32, **kwargs):
        """
        Create the connection pool and register the request teardown handler.

        The key is a name for the connection. It allows you to build pools for
        multiple databases by creating one DatabaseClient per key. Two clients
        with the same key would store their request identifier under the same
        attribute of ``g``, so use a distinct key for each client.

        :param app: the Flask application to attach to.
        :param key: name distinguishing this client's per-request state.
        :param minconn: forwarded to ConnectionPool.
        :param maxconn: forwarded to ConnectionPool.
        :param kwargs: forwarded to psycopg2.connect() through the pool.
            ``cursor_factory`` defaults to RealDictCursor unless supplied.
        """
        self.app = app

        # This is how we find the database connection client identifier for
        # this request. It lets the library hand out the same connection for
        # the duration of the request.
        self.key = "db_client_key[{}]".format(key)

        # Let callers choose their own cursor_factory; passing it alongside a
        # hard-coded one would raise a duplicate-keyword TypeError.
        kwargs.setdefault("cursor_factory", RealDictCursor)
        self.pool = ConnectionPool(minconn=minconn, maxconn=maxconn, **kwargs)

        # this will clean up the connection when the request is done
        self.app.teardown_request(self.close)

    def conn(self):
        """
        Return this request's database connection, acquiring one if needed.

        Your Flask views should use this to get a connection. Every call within
        one request returns the same connection. The connection is health
        checked with a trivial query before it is returned.

        :raises PoolError: if the pool is exhausted, or if the connection
            already assigned to this request no longer works. The request
            keeps one connection throughout, so it is not silently replaced.
        """
        if hasattr(g, self.key):
            # This request already has a connection: re-test it, and fail
            # rather than switch connections mid-request.
            db_client_id = str(getattr(g, self.key))
            db_client = self._get_connection(db_client_id)
            if db_client is None:
                delattr(g, self.key)  # remove client identifier
                raise PoolError("request connection lost")
            return db_client

        # No connection yet: keep drawing connections under fresh identifiers
        # until one passes the health check. Dead ones are closed by
        # _get_connection.
        while True:
            db_client_id = str(uuid.uuid4())
            db_client = self._get_connection(db_client_id)
            if db_client is not None:
                # Hook: initialise a brand-new connection here if needed
                # (e.g. session settings), before attaching it to the request.
                setattr(g, self.key, db_client_id)
                return db_client

    def close(self, exception):
        """
        teardown_request handler: return this request's connection, if any.

        It runs when a request finishes, whatever its outcome. Errors are
        logged rather than raised, because Flask teardown functions must not
        raise.

        :param exception: the unhandled exception passed in by Flask (unused).
        """
        if not hasattr(g, self.key):
            return

        db_client_id = getattr(g, self.key)
        # Forget the identifier so a reused ``g`` never points at a connection
        # that has already been released.
        delattr(g, self.key)
        try:
            self.pool.putconn(db_client_id)
            logger.debug("returned connection %s to pool named %s", db_client_id, self.key)
        except (psycopg2.Error, KeyError) as e:
            # PoolError is a psycopg2.Error subclass, so it is covered here too.
            logger.error("could not return connection to pool: %r", e)

    def _get_connection(self, db_client_id):
        """
        Fetch the connection for ``db_client_id`` and check that it works.

        :returns: the connection if ``SELECT pg_backend_pid()`` succeeds;
            otherwise None, after the connection has been released with
            ``close=True`` so the caller can try again.
        :raises PoolError: propagated from the pool (e.g. exhaustion).
        """
        db_client = self.pool.getconn(db_client_id)
        try:
            logger.debug("testing connection %s from pool named %s", db_client_id, self.key)
            with db_client.cursor() as cur:
                cur.execute("SELECT pg_backend_pid()")
        except Exception as e:
            # Deliberately broad: any failure means this connection is unusable.
            logger.warning("connection %s from pool named %s failed: %s", db_client_id, self.key, e)
            self.pool.putconn(db_client_id, close=True)
            return None

        logger.debug("using connection %s from pool named %s", db_client_id, self.key)
        return db_client
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
I like the idea of this code, but since I have little Flask experience outside SQLAlchemy-based applications, I'm having trouble seeing where I would get a connection from the pool (in a @before_request handler?) and how to use it in a blueprint. Can you give us a hint? Thanks!