Created August 21, 2020 13:47
Save Daenyth/ac970990f6410045282aaa507a58a643 to your computer and use it in GitHub Desktop.
Pytest postgres fixtures
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# in myapp/ | |
import os | |
from pathlib import Path | |
from typing import Mapping, NamedTuple, Optional | |
from datetime import timedelta | |
def _conninfo_value(value) -> str:
    """Render *value* for a libpq ``key=value`` connection string.

    Plain values are emitted unchanged. Empty values, and values containing
    whitespace, single quotes or backslashes, are wrapped in single quotes with
    ``'`` and ``\\`` backslash-escaped, as libpq requires. Without quoting, an
    empty value would make libpq read the *next* ``key=value`` pair as its value.
    """
    text = str(value)
    if text and not any(ch.isspace() or ch in "'\\" for ch in text):
        return text
    escaped = text.replace("\\", "\\\\").replace("'", "\\'")
    return "'" + escaped + "'"


class DbConfig(NamedTuple):
    """Connection settings for a Postgres database plus the schema to use."""

    dbname: str
    user: str
    password: str
    host: str
    port: int
    schema: str  # not part of dsn/postgres_url; consumed separately by callers

    @property
    def dsn(self) -> str:
        """libpq key/value connection string.

        The ``password`` key is omitted entirely when ``password`` is empty.
        Values are quoted/escaped where libpq requires it.
        """
        fields = [("dbname", self.dbname), ("user", self.user)]
        if self.password:
            fields.append(("password", self.password))
        fields += [("host", self.host), ("port", self.port)]
        return " ".join("{}={}".format(key, _conninfo_value(value)) for key, value in fields)

    @property
    def postgres_url(self) -> str:
        """``postgresql://`` URL for the same database.

        User, password and database name are percent-encoded so reserved URL
        characters (``@``, ``:``, ``/``, ...) cannot corrupt the URL. The
        ``:password`` part is omitted when ``password`` is empty.
        """
        from urllib.parse import quote

        credentials = quote(self.user, safe="")
        if self.password:
            credentials += ":" + quote(self.password, safe="")
        return "postgresql://{}@{}:{}/{}".format(
            credentials, self.host, self.port, quote(self.dbname, safe=""))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# in tests/ | |
import asyncio | |
import aiopg | |
import pytest | |
from aiopg import Connection, Cursor, Pool | |
from psycopg2.extras import NamedTupleCursor | |
from myapp.config import DbConfig | |
from myapp.db.schema import create_full_schema | |
from myapp.db.util import set_search_path | |
from .postgres_fixture import PostgresServer | |
@pytest.fixture(scope='session')
def postgres_server_sess(request):
    """Session-wide throwaway Postgres server.

    The server is started right away, and its ``teardown`` is registered as a
    pytest finalizer so it runs when the session finishes.
    """
    pg = PostgresServer()
    pg.start()
    finalize = pg.teardown
    request.addfinalizer(finalize)
    return pg
@pytest.fixture(scope='session')
def db_config_sess(postgres_server_sess: PostgresServer) -> DbConfig:
    """Return a DbConfig pointing at a fully-created db schema.

    The schema ``integration_sess`` is created once per session on the server
    provided by ``postgres_server_sess``.
    """
    server_cfg = postgres_server_sess.connection_config
    schema = 'integration_sess'
    db_cfg = DbConfig(
        dbname=server_cfg['database'],
        user=server_cfg['user'],
        password='',  # presumably the local initdb cluster needs no password — confirm pg_hba
        host=server_cfg['host'],
        port=server_cfg['port'],
        schema=schema)
    # Run on a private loop: the previous code closed asyncio.get_event_loop(),
    # i.e. the shared default loop, breaking any later code that reuses it.
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(create_full_schema(db_cfg))
    finally:
        loop.close()
    return db_cfg
@pytest.fixture(scope="function")
async def db_pool(db_config_sess: DbConfig) -> Pool:
    """Per-test aiopg pool for the session database.

    ``set_search_path(schema)`` is installed as the pool's ``on_connect`` hook
    (presumably it points each new connection at the test schema — see
    myapp.db.util). The pool is torn down when its ``async with`` exits after
    the test.
    """
    dsn = db_config_sess.dsn
    search_path_hook = set_search_path(db_config_sess.schema)
    pool_ctx = aiopg.create_pool(dsn, timeout=1, on_connect=search_path_hook)
    async with pool_ctx as pool:
        yield pool
@pytest.fixture(scope="function")
async def db_conn(db_pool: Pool) -> Connection:
    """A connection acquired from ``db_pool`` for the duration of one test.

    No ``schema`` attribute is attached to the connection. The test schema is
    applied by the pool's ``on_connect`` hook. The connection is handed back
    to the pool when the ``async with`` exits after the test.
    """
    async with db_pool.acquire() as conn:
        yield conn
@pytest.fixture(scope="function")
async def db_cursor(db_conn: Connection) -> Cursor:
    """Default-factory cursor on ``db_conn`` (which uses the test schema).

    The cursor is closed when its ``async with`` exits after the test.
    """
    cursor_ctx = db_conn.cursor()
    async with cursor_ctx as cur:
        yield cur
@pytest.fixture(scope="function")
async def db_named_tuple_cursor(db_conn: Connection) -> Cursor:
    """Cursor on ``db_conn`` built with psycopg2's ``NamedTupleCursor`` factory.

    It shares the test-schema connection from ``db_conn`` and is closed when
    its ``async with`` exits after the test.
    """
    cursor_ctx = db_conn.cursor(cursor_factory=NamedTupleCursor)
    async with cursor_ctx as cur:
        yield cur
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# in tests/ | |
import errno | |
import subprocess | |
import os | |
import psycopg2 | |
import pytest | |
from psycopg2 import OperationalError | |
from pytest_server_fixtures.base import TestServer | |
from six import text_type | |
class PostgresServer(TestServer):
    """
    Throwaway Postgres server for tests, built on pytest-server-fixtures' TestServer.

    Exposes a server.connect() method returning a raw psycopg2 connection.
    Also exposes a server.connection_config returning a dict with connection parameters.
    """
    random_port = True

    def __init__(self, database_name="integration", skip_on_missing_postgres=False, **kwargs):
        """
        :param database_name: database created on the server once it accepts connections.
        :param skip_on_missing_postgres: if True, setup failures call pytest.skip;
            otherwise they call pytest.exit.
        :param kwargs: forwarded to TestServer.__init__.
        """
        self.database_name = database_name
        # TODO make skip configurable with a pytest flag
        self._fail = pytest.skip if skip_on_missing_postgres else pytest.exit
        super().__init__(workspace=None, delete=True, preserve_sys_path=False, **kwargs)

    def kill(self, retries=5):
        """
        Send ``self.kill_signal`` to the postmaster if its pid has been recorded.

        A process that no longer exists (ESRCH) is ignored; other OSErrors propagate.
        ``retries`` is unused here (presumably kept to match TestServer.kill).
        """
        if not hasattr(self, 'pid'):
            return
        try:
            os.kill(self.pid, self.kill_signal)
        except OSError as e:
            if e.errno != errno.ESRCH:  # ESRCH: "No such process" - already gone
                raise

    def _report_failure(self, msg):
        """Print *msg* and hand it to ``self._fail`` (pytest.skip or pytest.exit, both raise)."""
        print(msg)
        self._fail(msg)

    def pre_setup(self):
        """
        Find the postgres server binaries via ``pg_config --bindir`` and run
        initdb into ``<workspace>/db``.

        Missing binaries, or a failing pg_config/initdb, are reported through
        ``self._fail``.
        """
        data_dir = self.workspace / 'db'
        data_dir.mkdir()  # pylint: disable=no-value-for-parameter
        try:
            self.pg_bin = subprocess.check_output(["pg_config", "--bindir"]).decode('utf-8').rstrip()
        except (OSError, subprocess.CalledProcessError) as e:
            # CalledProcessError: pg_config exists but exited non-zero.
            self._report_failure("Failed to get pg_config --bindir: " + text_type(e))
        initdb_path = self.pg_bin + '/initdb'
        if not os.path.exists(initdb_path):
            self._report_failure(
                "Unable to find pg binary specified by pg_config: {} is not a file".format(initdb_path))
        try:
            subprocess.check_call([initdb_path, str(data_dir)])
        except (OSError, subprocess.CalledProcessError) as e:
            self._report_failure("Failed to initialize postgres data directory with initdb: " + text_type(e))

    @property
    def connection_config(self):
        """
        Keyword arguments for psycopg2.connect() targeting this server's test database.

        A new dict is built on every access, so callers may mutate the result.
        """
        import getpass

        return {
            u'host': u'localhost',
            # initdb (run in pre_setup) names the superuser after the invoking OS
            # user; fall back to getpass when $USER is unset (e.g. containers/CI).
            u'user': os.environ.get(u'USER') or getpass.getuser(),
            u'port': self.port,
            u'database': self.database_name
        }

    @property
    def run_cmd(self):
        """Command line that starts postgres on the initdb'd data dir and this server's port."""
        cmd = [
            self.pg_bin + '/postgres',
            '-F',
            '-k', str(self.workspace / 'db'),
            '-D', str(self.workspace / 'db'),
            '-p', str(self.port),
            '-c', "log_min_messages=FATAL"
        ]  # yapf: disable
        return cmd

    def check_server_up(self):
        """
        Return True once the server accepts connections and the test database exists.

        Creates ``self.database_name`` if it is missing, opens ``self.connection``
        to it and records the postmaster pid. Returns False on OperationalError so
        the caller can retry. The existence check keeps a retry after partial
        success from failing with "database already exists".
        """
        from contextlib import closing

        try:
            print("Connecting to Postgres at localhost:{}".format(self.port))
            # closing(): psycopg2's own ``with conn`` only ends a transaction and
            # would leave this admin connection open.
            with closing(self.connect('postgres')) as admin_conn:
                # CREATE DATABASE cannot run inside a transaction block.
                admin_conn.set_session(autocommit=True)
                with admin_conn.cursor() as cursor:
                    cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (self.database_name,))
                    if cursor.fetchone() is None:
                        # Quote the identifier so the created name matches the one
                        # connect() uses exactly (no case folding).
                        quoted_name = '"{}"'.format(self.database_name.replace('"', '""'))
                        cursor.execute("CREATE DATABASE " + quoted_name)
            self.connection = self.connect(self.database_name)
            with open(self.workspace / 'db' / 'postmaster.pid', 'r') as f:
                self.pid = int(f.readline().rstrip())
            return True
        except OperationalError as e:
            print("Could not connect to test postgres: {}".format(e))
            return False

    def connect(self, database=None):
        """Open a raw psycopg2 connection, to *database* if given, else to the test database."""
        cfg = self.connection_config
        if database is not None:
            cfg[u'database'] = database
        return psycopg2.connect(**cfg)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# in myapp/db/ | |
import sqlalchemy as sa # type: ignore | |
from sqlalchemy import Column, DDL, Index, Table | |
from sqlalchemy.dialects.postgresql import JSONB # type: ignore | |
from sqlalchemy.sql.schema import MetaData # type: ignore | |
from myapp.config import DbConfig | |
def EventTable(metadata: MetaData) -> Table:
    """Define the ``event`` table on *metadata* and return it.

    Indexes:
    - ``event_aggregate_id_data_md5``: unique on (instance_id, aggregate_id, data_md5)
    - ``event_aggregate_id_instance_data_type``: on (instance_id, aggregate_id, type)
    """
    columns = [
        Column("sequence_id", sa.BIGINT, autoincrement=True, primary_key=True),
        Column("instance_id", sa.TEXT, nullable=False),
        Column("aggregate_id", sa.TEXT, nullable=False),
        Column("created", sa.TIMESTAMP(timezone=True), nullable=False),
        Column("type", sa.TEXT, nullable=False),
        Column("type_version", sa.INT, nullable=False),
        Column("occurred_at", sa.TIMESTAMP(timezone=True), nullable=False),
        Column("data", JSONB, nullable=False),
        Column("data_md5", sa.TEXT, nullable=False),
        # Python-side default applied by SQLAlchemy inserts.
        Column("synced", sa.BOOLEAN, nullable=False, default=False),
    ]
    indexes = [
        Index("event_aggregate_id_data_md5", "instance_id", "aggregate_id", "data_md5", unique=True),
        Index("event_aggregate_id_instance_data_type", "instance_id", "aggregate_id", "type"),
    ]
    return Table("event", metadata, *columns, *indexes)
def AggregateSnapshotTable(metadata: MetaData) -> Table:
    """Define the ``aggregate_snapshot`` table on *metadata* and return it.

    Indexes:
    - ``snapshot_aggregate_id``: unique on (instance_id, aggregate_id, aggregate_type)
    - ``snapshot_related_aggregate_id``: on (instance_id, related_aggregate_id, aggregate_id)
    """
    return Table(
        "aggregate_snapshot",
        metadata,
        Column("sequence_id", sa.BIGINT, autoincrement=True, primary_key=True),
        Column("instance_id", sa.TEXT, nullable=False),
        Column("aggregate_id", sa.TEXT, nullable=False),
        Column("aggregate_type", sa.TEXT, nullable=False),
        Column("aggregate_type_version", sa.INT, nullable=False),
        Column("aggregate_version", sa.INT, nullable=False),
        # Server-side DEFAULT now(): the DDL carries the default, so raw-SQL inserts
        # are timestamped too. The previous default='now()' was a Python-side
        # string literal, not a SQL function call.
        Column("created", sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.func.now()),
        Column("data", JSONB, nullable=False),
        # NOTE(review): this column was missing even though the
        # "snapshot_related_aggregate_id" index below references it, which makes
        # Table construction fail. TEXT/nullable is a guess mirroring aggregate_id
        # — confirm against the intended schema.
        Column("related_aggregate_id", sa.TEXT, nullable=True),
        Index("snapshot_aggregate_id", "instance_id", 'aggregate_id', 'aggregate_type', unique=True),
        Index("snapshot_related_aggregate_id", "instance_id", "related_aggregate_id", "aggregate_id"))
def event_trigger(schema_name: str) -> DDL:
    """Get the DDL for the insert trigger and notify channel on ``<schema>.event``.

    (Re)creates ``<schema>.notify_event()``, which issues ``NOTIFY event``. It
    then drops and recreates the AFTER INSERT trigger
    ``notify_event_on_event_insert`` on ``<schema>.event``, which calls that
    function. There is no FOR EACH ROW clause, so PostgreSQL's statement-level
    default applies.

    The function is schema-qualified so it lives next to the table, instead of
    in whatever schema leads the connection's search_path (typically ``public``,
    where newer PostgreSQL versions no longer grant CREATE to ordinary roles).

    ``schema_name`` is interpolated verbatim into the DDL and must be a trusted
    identifier.
    """
    return DDL("""
CREATE OR REPLACE FUNCTION {schema}.notify_event() RETURNS trigger AS $$
DECLARE
BEGIN
NOTIFY event;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
DROP TRIGGER IF EXISTS notify_event_on_event_insert ON {schema}.event;
CREATE TRIGGER notify_event_on_event_insert
AFTER INSERT ON {schema}.event
EXECUTE PROCEDURE {schema}.notify_event();
""".format(schema=schema_name))
async def create_full_schema(db: DbConfig) -> None:
    """
    Create ``db.schema`` with the ``event`` and ``aggregate_snapshot`` tables and
    the event insert trigger.

    Every statement is guarded (CREATE SCHEMA IF NOT EXISTS, ``checkfirst=True``,
    CREATE OR REPLACE FUNCTION / DROP TRIGGER IF EXISTS in the trigger DDL), so
    re-running against an existing schema should be safe. All DDL runs in one
    transaction that commits on success and rolls back on error.

    NOTE: despite being ``async`` this uses a synchronous SQLAlchemy engine and
    blocks the event loop while it runs. It stays ``async`` so existing callers
    can keep awaiting it.
    """
    metadata = sa.MetaData(schema=db.schema)
    event_tbl = EventTable(metadata)
    snap_tbl = AggregateSnapshotTable(metadata)
    engine = sa.create_engine(db.postgres_url)
    try:
        # engine.begin() replaces Engine.execute(), which is deprecated in
        # SQLAlchemy 1.4 and removed in 2.0.
        with engine.begin() as conn:
            conn.execute(sa.text("CREATE SCHEMA IF NOT EXISTS " + db.schema))
            metadata.create_all(conn, tables=[event_tbl, snap_tbl], checkfirst=True)
            conn.execute(event_trigger(db.schema))
    finally:
        # Release the engine's pooled connections rather than leaking them.
        engine.dispose()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Dependency: pytest-server-fixtures (provides the TestServer base class used by PostgresServer) — https://pypi.org/project/pytest-server-fixtures/