@mzpqnxow
Created February 9, 2021 14:35
PostgreSQL with SQLAlchemy: INET/CIDR masklen function
#!/usr/bin/env python3
"""Small function to use PostgreSQL MASKLEN() with SQLAlchemy

If you work regularly with SQLAlchemy and know a cleaner way to do this, please
leave a comment on this Gist! Thanks

(C) 2019, mzpqnxow, BSD 3-Clause
"""
import operator
from typing import Union

from sqlalchemy import Column, and_, tuple_
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.expression import literal

# Your own declarative models module; only referenced by the docstring usage example
# from mymodels import MyTable
def masklen(value: int, operator_: str, col: Union[Column, InstrumentedAttribute]) -> BinaryExpression:
    """Wrapper for using MASKLEN() in queries on PostgreSQL INET/CIDR columns

    A huge benefit of PostgreSQL for certain use-cases is its native datatypes for
    IPv4 and IPv6 networks: INET and CIDR. The two types are nearly identical, except
    that CIDR does stricter checking and rejects any value with host bits (the bits to
    the right of the netmask) set. This is similar to Python's ipaddress.ip_network()
    with strict=True. MASKLEN() accepts both types, so this should work fine on both
    CIDR and INET columns.

    One of the features provided by PostgreSQL is the ability to query by prefix
    length, or MASKLEN in PostgreSQL terms. It doesn't seem to be exposed directly in
    SQLAlchemy, so this is a small wrapper to make such queries more readable. Given a
    table like:

        CREATE TABLE MyTable (
            network INET,
            active BOOL,
            ...);

    Usage
    =====

    >>> query = session.query(MyTable).filter(and_(MyTable.active == True, masklen(24, '=', MyTable.network)))

    The comparison reads left to right as `value OPERATOR MASKLEN(col)`, so
    masklen(24, '<', MyTable.network) matches networks with a prefix longer than /24.

    Notes
    =====

    I don't use ORMs very often, so I'm not too sure about the "correct" way to
    implement this. Ideally it would be integrated more into SQLAlchemy, but this
    method is a lot better than typing raw SQL into a query.

    Security
    ========

    The operator is checked against a fixed allowlist before it is placed into the
    SQL text, and the mask length must be an int and is passed through literal() as
    a bound parameter rather than interpolated. That should make this safe from
    injection even when it is called from a web application or some other context
    where there may be hostile user input.

    - AG
    """
    operator_set = {'=', '<', '>', '<=', '>=', '!='}
    if operator_ not in operator_set:
        # sorted() keeps the error message stable, since set iteration order is not
        raise ValueError('Invalid operator, must be one of {}'.format(', '.join(sorted(operator_set))))
    if not isinstance(value, int):
        raise TypeError('Expected integer for mask length, got {}'.format(type(value)))
    # Renders as "<value> <operator> MASKLEN (<col>)"; tuple_() supplies the parentheses
    expr = literal(value).op(operator_ + ' MASKLEN')(tuple_(col))
    return expr
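The docstring compares CIDR's strictness to `ipaddress.ip_network(strict=True)`. That comparison can be tried with the standard library alone. This is a minimal sketch, and a few of its names are assumptions rather than part of the original script:

* `accepts_as_cidr` is a hypothetical helper.
* `ip_interface` stands in for an INET value only as an analogy.
* `.network.prefixlen` plays the role that MASKLEN() plays in PostgreSQL.

Nothing here talks to a database.

```python
import ipaddress


def accepts_as_cidr(text: str) -> bool:
    """Hypothetical helper: True if `text` has no host bits set (CIDR-style strictness)."""
    try:
        ipaddress.ip_network(text, strict=True)
    except ValueError:
        # strict=True raises when bits to the right of the mask are set
        return False
    return True


# INET-like: an interface keeps the host address together with the prefix length
iface = ipaddress.ip_interface("192.168.1.5/24")
print(iface.ip, iface.network.prefixlen)  # 192.168.1.5 24

# CIDR-like: only values with every host bit zeroed are accepted
print(accepts_as_cidr("192.168.1.0/24"))  # True
print(accepts_as_cidr("192.168.1.5/24"))  # False

# strict=False zeroes the host bits instead of raising
print(ipaddress.ip_network("192.168.1.5/24", strict=False))  # 192.168.1.0/24
```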
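The script asks for a cleaner way to do this. One possibility, offered as a sketch and not tested against a live database, uses two pieces:

* SQLAlchemy's generic `func` namespace, which renders an arbitrary SQL function call such as `masklen(...)`.
* Python's `operator` module, which maps the allowed operator strings onto comparisons that SQLAlchemy overloads.

`masklen_func` and `_OPS` are hypothetical names. The argument order keeps the `value OPERATOR MASKLEN(col)` reading.

```python
import operator
from typing import Any, Callable, Dict

from sqlalchemy import column, func, literal
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.elements import ColumnElement

# Hypothetical allowlist: each permitted operator string maps to a Python comparison.
# SQLAlchemy overloads these comparisons on SQL expressions, so the user-supplied
# string only selects a function and never becomes SQL text itself.
_OPS: Dict[str, Callable[[Any, Any], Any]] = {
    "=": operator.eq,
    "!=": operator.ne,
    "<": operator.lt,
    "<=": operator.le,
    ">": operator.gt,
    ">=": operator.ge,
}


def masklen_func(value: int, operator_: str, col: Any) -> ColumnElement:
    """Build `value OP masklen(col)` with SQLAlchemy's generic func namespace."""
    if operator_ not in _OPS:
        raise ValueError("Invalid operator, must be one of {}".format(", ".join(sorted(_OPS))))
    # bool is a subclass of int, so reject it explicitly
    if not isinstance(value, int) or isinstance(value, bool):
        raise TypeError("Expected integer for mask length, got {}".format(type(value)))
    # func.masklen(col) renders as masklen(<col>); literal() makes value a bound parameter
    return _OPS[operator_](literal(value), func.masklen(col))


if __name__ == "__main__":
    expr = masklen_func(24, "<", column("network"))
    compiled = expr.compile(dialect=postgresql.dialect())
    # Default compilation: the mask length stays a bound parameter
    print(compiled, compiled.params)
    # literal_binds inlines the value, which is convenient for reading the SQL
    print(expr.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
```

With an ORM model it would be used the same way as the original, for example `session.query(MyTable).filter(masklen_func(24, '=', MyTable.network))`. The SQL operator comes only from the allowlist keys, and the mask length is always sent as a bound parameter.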