Last active
July 8, 2024 12:46
-
-
Save danielrichman/6046307 to your computer and use it in GitHub Desktop.
A nicer psycopg2 connection class, a Flask PostgreSQL extension, and their tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class PostgreSQLConnection(psycopg2.extensions.connection):
    """
    A custom `connection_factory` for :func:`psycopg2.connect`.

    This

    * puts the connection into unicode mode (for text) by registering the
      ``UNICODE`` and ``UNICODEARRAY`` typecasters on this connection only
    * modifies the :meth:`cursor` method of a :class:`psycopg2.connection`,
      facilitating easy acquiring of cursors made from
      :class:`psycopg2.extras.RealDictCursor`.
    """

    # this may be omitted in py3k
    def __init__(self, *args, **kwargs):
        super(PostgreSQLConnection, self).__init__(*args, **kwargs)
        # Register per-connection (second argument), not globally.
        for typecaster in (psycopg2.extensions.UNICODE,
                           psycopg2.extensions.UNICODEARRAY):
            psycopg2.extensions.register_type(typecaster, self)

    def cursor(self, real_dict_cursor=False, **kwargs):
        """
        Get a new cursor.

        :param real_dict_cursor: if set, a
            :class:`psycopg2.extras.RealDictCursor` is returned.
        :param kwargs: any other keyword arguments (e.g. ``name``) are passed
            through unchanged to the base class ``cursor()`` method.
        :raises TypeError: if `real_dict_cursor` is set and an explicit
            ``cursor_factory`` is also given, since the two conflict.
        """
        if real_dict_cursor:
            if "cursor_factory" in kwargs:
                raise TypeError("real_dict_cursor and cursor_factory are "
                                "mutually exclusive")
            kwargs["cursor_factory"] = psycopg2.extras.RealDictCursor
        return super(PostgreSQLConnection, self).cursor(**kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import unicode_literals | |
import logging | |
import threading | |
import flask | |
from werkzeug.local import LocalProxy | |
import psycopg2 | |
import psycopg2.extras | |
import psycopg2.extensions | |
# Proxy to the PostgreSQL extension of the current Flask app.
# PostgreSQL.init_app stores the extension as ``app.postgresql``, so that is
# the attribute read here (``app.postgres`` is never set).
postgres = LocalProxy(lambda: flask.current_app.postgresql)
class PostgreSQL(object):
    """
    A PostgreSQL helper extension for Flask apps

    On initialisation it adds an after_request function that commits the
    transaction (so that if the transaction rolls back the request will
    fail) and an app context teardown function that returns any active
    connection to the pool (or closes it if the pool is full).

    You can of course (and indeed should) use :meth:`commit` if you need to
    ensure some changes have made it to the database before performing
    some other action. :meth:`teardown` is also available to be called
    directly.

    Connections are created by ``psycopg2.connect(**app.config["POSTGRES"])``
    (e.g., ``app.config["POSTGRES"] = {"database": "mydb"}``),
    are pooled (you can adjust the pool size with `pool_size`) and are tested
    for server shutdown (via ``reset()``) before being given to the request.
    """

    def __init__(self, app=None, pool_size=2):
        self.app = app
        self._pool = []
        self.pool_size = pool_size
        # Guards self._pool, which is shared by every request of the process.
        self._lock = threading.RLock()
        self.logger = logging.getLogger(__name__ + ".PostgreSQL")
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """
        Initialises the app by adding hooks

        * Hook: ``app.after_request(self.commit)``
        * Hook: ``app.teardown_appcontext(self.teardown)``

        and stores this extension as ``app.postgresql``.
        """
        app.after_request(self.commit)
        app.teardown_appcontext(self.teardown)
        app.postgresql = self

    def _close_quietly(self, c):
        """Close `c`, ignoring the OperationalError a dead connection raises."""
        try:
            c.close()
        except psycopg2.OperationalError:
            self.logger.debug("ignoring error while closing", exc_info=True)

    def _connect(self):
        """
        Returns a connection to the database.

        A pooled connection is preferred; it is checked with ``reset()``.
        If that check fails, the entire pool is assumed dead: every pooled
        connection is closed and a fresh connection is created instead.
        """
        with self._lock:
            c = None
            if self._pool:
                c = self._pool.pop()
                try:
                    # This tests if the connection is still alive.
                    c.reset()
                except psycopg2.OperationalError:
                    self.logger.debug("assuming pool dead", exc_info=True)
                    # assume that the entire pool is dead
                    self._close_quietly(c)
                    for pooled in self._pool:
                        self._close_quietly(pooled)
                    self._pool = []
                    c = None
                else:
                    self.logger.debug("got connection from pool")
            if c is None:
                c = self._new_connection()
            return c

    def _new_connection(self):
        """Create a new connection from ``current_app.config["POSTGRES"]``."""
        s = flask.current_app.config["POSTGRES"]
        if self.logger.isEnabledFor(logging.DEBUG):
            # Format with %s so non-string values (e.g. an integer port) work,
            # and never write the password to the log.
            summary = " ".join(
                "%s=%s" % (k, "***" if k == "password" else v)
                for k, v in sorted(s.items()))
            self.logger.debug("connecting (%s)", summary)
        return psycopg2.connect(connection_factory=PostgreSQLConnection, **s)

    @property
    def connection(self):
        """
        Gets the PostgreSQL connection for this Flask request

        If no connection has been used in this request, it connects to the
        database. Further use of this property will reference the same
        connection.

        The connection is committed after the request and, when the app
        context is torn down, returned to the pool (or closed).
        """
        g = flask.g
        if not hasattr(g, '_postgresql'):
            g._postgresql = self._connect()
        return g._postgresql

    def cursor(self, real_dict_cursor=False):
        """
        Get a new postgres cursor for immediate use during a request

        If a cursor has not yet been used in this request, it connects to the
        database. Further cursors re-use the per-request connection.

        The connection is committed after the request and, when the app
        context is torn down, returned to the pool (or closed).

        If real_dict_cursor is set, a RealDictCursor is returned
        """
        return self.connection.cursor(real_dict_cursor)

    def commit(self, response=None):
        """
        (Almost an) alias for self.connection.commit()

        ... except if self.connection has never been used this is a noop
        (i.e., it does nothing)

        Returns `response` unmodified, so that this may be used as an
        :meth:`flask.after_request` function.
        """
        g = flask.g
        if hasattr(g, '_postgresql'):
            self.logger.debug("committing")
            g._postgresql.commit()
        return response

    def teardown(self, exception):
        """
        Either return the connection to the pool or close it.

        Does nothing if this request never used a connection. The connection
        is pooled only if the pool has room and ``reset()`` succeeds. A
        connection whose reset fails is closed rather than leaked, and the
        error is not raised out of the teardown hook.
        """
        g = flask.g
        if not hasattr(g, '_postgresql'):
            return
        c = g._postgresql
        del g._postgresql
        with self._lock:
            s = len(self._pool)
            if s >= self.pool_size:
                self.logger.debug("teardown: pool size %i - closing", s)
                self._close_quietly(c)
                return
            try:
                c.reset()
            except psycopg2.OperationalError:
                self.logger.debug("teardown: reset failed - closing",
                                  exc_info=True)
                self._close_quietly(c)
            else:
                self.logger.debug("teardown: adding to pool, new size %i",
                                  s + 1)
                self._pool.append(c)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import psycopg2 | |
import psycopg2.extras | |
import psycopg2.extensions | |
import flask | |
class MockConnectionBase(object):
    """Fake base for a psycopg2 connection that performs no real setup."""

    def __init__(self):
        # Deliberately skip psycopg2's connection setup; only keep a list in
        # which fake register_type calls can record typecasters.
        self.registered_types = []

    def cursor(self, cursor_factory=None):
        """Return a marker string naming the kind of cursor requested."""
        if cursor_factory is None:
            return "stubbed cursor"
        # Only RealDictCursor is expected as an explicit factory.
        assert cursor_factory == psycopg2.extras.RealDictCursor
        return "stubbed dict cursor"
class ConnectionRebaser(type):
    """
    Metaclass that inserts MockConnectionBase into the MRO directly after
    utils.PostgreSQLConnection.

    As a result, ``super(PostgreSQLConnection, self)`` calls resolve to the
    mock rather than to the real psycopg2 connection class.
    """

    def __new__(mcs, name, bases, dict):
        extended_bases = bases + (MockConnectionBase, )
        return super(ConnectionRebaser, mcs).__new__(mcs, name,
                                                     extended_bases, dict)

    def mro(cls):
        real = utils.PostgreSQLConnection
        # cls, the class under test, the mock, then the real class's bases.
        head = (cls, real, MockConnectionBase)
        return head + real.__mro__[1:]
# PostgreSQLConnection with MockConnectionBase spliced into its MRO (see
# ConnectionRebaser), so it can be instantiated without a database.
# The class is built by calling the metaclass directly. A ``__metaclass__``
# class attribute only takes effect on Python 2 and is ignored on Python 3,
# where the mock would silently not be inserted. Calling ConnectionRebaser
# works the same on both.
RebasedPostgreSQLConnection = ConnectionRebaser(
    "RebasedPostgreSQLConnection", (utils.PostgreSQLConnection, ), {})
class FakeExtensions(object):
    """
    Stand-in for the ``psycopg2.extensions`` module.

    ``register_type`` records the typecaster on the connection's
    ``registered_types`` list instead of registering it.
    """

    # Keep the real typecaster objects so assertions can compare against them.
    UNICODE = psycopg2.extensions.UNICODE
    UNICODEARRAY = psycopg2.extensions.UNICODEARRAY

    @classmethod
    def register_type(cls, what, connection):
        recorded = connection.registered_types
        recorded.append(what)
class TestPostgreSQLConnection(object):
    """Tests for utils.PostgreSQLConnection, run against mocks only."""

    # pytest xunit-style per-test hooks. Nose-style ``setup``/``teardown``
    # methods are no longer invoked by pytest 8+, which would leave the
    # fakes uninstalled.
    def setup_method(self, method=None):
        # Swap psycopg2.extensions for a fake so the register_type calls made
        # in PostgreSQLConnection.__init__ are recorded, not executed.
        assert psycopg2.extensions.__name__ == "psycopg2.extensions"
        self.fakes = FakeExtensions()
        self.original_extensions = psycopg2.extensions
        psycopg2.extensions = self.fakes

    def teardown_method(self, method=None):
        assert isinstance(psycopg2.extensions, FakeExtensions)
        psycopg2.extensions = self.original_extensions

    def test_only_affects_cursor(self):
        assert [x for x in utils.PostgreSQLConnection.__dict__
                if not x.startswith("__")] == ["cursor"]

    def test_cursor(self):
        c = RebasedPostgreSQLConnection()
        assert c.cursor() == "stubbed cursor"
        assert c.cursor(False) == "stubbed cursor"
        assert c.cursor(True) == "stubbed dict cursor"

    def test_register_types(self):
        c = RebasedPostgreSQLConnection()
        assert c.registered_types == [psycopg2.extensions.UNICODE,
                                      psycopg2.extensions.UNICODEARRAY]
class FakePsycopg2(object):
    """
    Stand-in for the ``psycopg2`` module as used by utils.PostgreSQL.

    Counts ``connect`` calls and hands out fake connections. Each fake
    connection records its commit/reset/close calls and the cursors made
    from it.
    """

    class connection(object):
        """Fake connection; ``calls`` counts commit/reset/close invocations."""

        class _cursor(object):
            """Fake cursor that records executed queries."""

            def __init__(self, connection, real_dict_cursor):
                self.queries = []
                self.connection = connection
                self.real_dict_cursor = real_dict_cursor

            def __enter__(self):
                return self

            def __exit__(self, *args):
                return None

            def execute(self, query, args=None):
                # After the owning connection is closed, fail the way a dead
                # server connection would.
                if self.connection.calls["close"]:
                    raise psycopg2.OperationalError
                self.queries.append((query, args))

        # Class-level defaults; tests override close_error per instance.
        close_error = False
        autocommit = False

        def __init__(self, **settings):
            self.settings = settings
            self.types = []
            self.cursors = []
            self.calls = dict(commit=0, reset=0, close=0)

        def cursor(self, real_dict_cursor=False):
            new_cursor = self._cursor(self, real_dict_cursor)
            self.cursors.append(new_cursor)
            return new_cursor

        def commit(self):
            assert self.calls["close"] == 0
            self.calls["commit"] += 1

        def reset(self):
            # The attempt is counted even when it then fails.
            self.calls["reset"] += 1
            if self.calls["close"]:
                raise psycopg2.OperationalError

        def close(self):
            # The attempt is counted even when it then fails.
            self.calls["close"] += 1
            if self.close_error:
                raise psycopg2.OperationalError

    # Class-level default; connect() shadows it with a per-instance counter.
    connections = 0

    def connect(self, **settings):
        self.connections += 1
        return self.connection(**settings)

    OperationalError = psycopg2.OperationalError
    extras = psycopg2.extras
class TestPostgreSQL(object):
    """Tests for utils.PostgreSQL, with utils.psycopg2 swapped for a fake."""

    # pytest xunit-style per-test hooks. Nose-style ``setup``/``teardown``
    # methods are no longer invoked by pytest 8+, which would run these tests
    # against the real psycopg2.
    def setup_method(self, method=None):
        assert utils.psycopg2 is psycopg2
        self.fakes = utils.psycopg2 = FakePsycopg2()
        self.app = flask.Flask(__name__)
        self.app.config["POSTGRES"] = {"database": "mydb", "user": "steve"}
        self.postgres = utils.PostgreSQL(self.app)

    def teardown_method(self, method=None):
        assert isinstance(utils.psycopg2, FakePsycopg2)
        utils.psycopg2 = psycopg2

    def test_adds_hooks(self):
        assert self.app.after_request_funcs == {None: [self.postgres.commit]}
        assert self.app.teardown_appcontext_funcs == [self.postgres.teardown]

    def test_connect_new(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            assert isinstance(c, self.fakes.connection)
            assert c.settings == \
                {"database": "mydb", "user": "steve",
                 "connection_factory": utils.PostgreSQLConnection}
            assert c.calls == {"commit": 0, "reset": 0, "close": 0}
            assert c.autocommit is False

    def test_connect_once(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            d = self.postgres.connection
            assert c is d
            assert self.fakes.connections == 1
            assert c.calls == {"commit": 0, "reset": 0, "close": 0}

    def test_teardown_resets_before_store(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
        assert c.calls == {"commit": 0, "reset": 1, "close": 0}

    def test_connect_from_pool(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
        assert c.calls == {"commit": 0, "reset": 1, "close": 0}
        with self.app.test_request_context("/"):
            d = self.postgres.connection
            assert d is c
            assert d.calls == {"commit": 0, "reset": 2, "close": 0}
        assert d.calls == {"commit": 0, "reset": 3, "close": 0}

    def test_removes_from_pool(self):
        # put a connection in the pool
        with self.app.test_request_context("/"):
            c = self.postgres.connection
        # now get two connections from the pool
        # must explicitly create two app contexts.
        # in normal usage, flask promises to never share an app context
        # between requests. When testing, it will only create an app context
        # when test_request_context is __enter__'d and there is no existing
        # app context
        with self.app.app_context(), self.app.test_request_context("/1"):
            d = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                e = self.postgres.connection
        assert d is c
        assert e is not d
        assert d.calls == {"commit": 0, "reset": 3, "close": 0}
        assert e.calls == {"commit": 0, "reset": 1, "close": 0}

    def test_connect_from_pool_bad(self):
        # put two distinct connections in the pool
        with self.app.app_context(), self.app.test_request_context("/1"):
            c = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                d = self.postgres.connection
        assert c is not d
        assert c.calls == d.calls == {"commit": 0, "reset": 1, "close": 0}
        c.close()
        with self.app.test_request_context("/"):
            e = self.postgres.connection
            # it should try c.reset, which will fail, and then destroy the
            # pool by closing d as well
            # one close call by uut, one close call from above
            assert c.calls == {"commit": 0, "reset": 2, "close": 1 + 1}
            assert d.calls == {"commit": 0, "reset": 1, "close": 1}
            # e should be a new connection
            assert e is not c and e is not d
            assert e.calls == {"commit": 0, "reset": 0, "close": 0}
        assert e.calls == {"commit": 0, "reset": 1, "close": 0}

    def test_absorbs_close_errors(self):
        with self.app.app_context(), self.app.test_request_context("/1"):
            c = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                d = self.postgres.connection
        c.close()
        d.close_error = True
        with self.app.test_request_context("/"):
            e = self.postgres.connection

    def test_teardown_closes_if_pool_full(self):
        # default pool size is 2
        with self.app.app_context(), self.app.test_request_context("/1"):
            c = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                d = self.postgres.connection
                with self.app.app_context(), \
                        self.app.test_request_context("/3"):
                    e = self.postgres.connection
        assert len(set([c, d, e])) == 3
        assert c.calls == {"commit": 0, "reset": 0, "close": 1}
        assert d.calls == {"commit": 0, "reset": 1, "close": 0}
        assert e.calls == {"commit": 0, "reset": 1, "close": 0}
        with self.app.app_context(), self.app.test_request_context("/1"):
            f = self.postgres.connection
            with self.app.app_context(), self.app.test_request_context("/2"):
                g = self.postgres.connection
                with self.app.app_context(), \
                        self.app.test_request_context("/3"):
                    h = self.postgres.connection
        assert f is d
        assert g is e
        assert len(set([d, e, h])) == 3

    def test_cursor(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            x = self.postgres.cursor()
            assert isinstance(x, c._cursor)
            assert x.connection is c
            assert len(c.cursors) == 1
            y = self.postgres.cursor()
            assert y.connection is c
            assert len(c.cursors) == 2
        assert c.calls == {"commit": 0, "reset": 1, "close": 0}
        with self.app.test_request_context("/"):
            # cursor first without asking for connection explicitly
            x = self.postgres.cursor()
            assert isinstance(x, utils.psycopg2.connection._cursor)
            c = x.connection
            assert isinstance(c, utils.psycopg2.connection)
            assert c.calls == {"commit": 0, "reset": 2, "close": 0}
            assert c.settings == \
                {"database": "mydb", "user": "steve",
                 "connection_factory": utils.PostgreSQLConnection}
            assert c is self.postgres.connection

    def test_dict_cursor(self):
        with self.app.test_request_context("/"):
            c = self.postgres.cursor(True)
            assert len(self.postgres.connection.cursors) == 1
            assert c.real_dict_cursor
            c = self.postgres.cursor()
            assert not c.real_dict_cursor

    def test_commit(self):
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            self.postgres.commit()
            assert c.calls == {"commit": 1, "reset": 0, "close": 0}

    def test_commit_as_hook(self):
        # as an after request hook, commit must return the response object
        # it is passed
        response = object()
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            assert self.postgres.commit(response) is response
        # now check it works as a hook
        with self.app.test_request_context("/"):
            d = self.postgres.connection
            assert d is c
            assert self.app.process_response(response) is response
        assert c.calls == {"commit": 2, "reset": 3, "close": 0}

    def test_commit_nop_if_no_connection(self):
        with self.app.test_request_context("/"):
            self.postgres.commit()
        assert utils.psycopg2.connections == 0
        with self.app.test_request_context("/"):
            self.app.process_response(None)
        assert utils.psycopg2.connections == 0
        # should nop if teardown puts the connection in the pool
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            self.postgres.teardown(None)
            self.postgres.commit()
            assert c.calls == {"commit": 0, "reset": 1, "close": 0}

    def test_teardown_nop_if_no_connection(self):
        with self.app.test_request_context("/"):
            self.postgres.teardown(None)
        assert utils.psycopg2.connections == 0
        with self.app.test_request_context("/"):
            c = self.postgres.connection
            self.postgres.teardown(None)
            assert c.calls == {"commit": 0, "reset": 1, "close": 0}
        assert c.calls == {"commit": 0, "reset": 1, "close": 0}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment