Created October 4, 2018 04:55
-
-
Save the Y4Kman/c608945242cca819f298ccf2ce63f522 gist to your computer and use it in GitHub Desktop.
A Django ORM Expression for Postgres ARRAY literals
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.contrib.postgres.fields import ArrayField | |
from django.db.models import Expression, Value | |
from netfields import InetAddressField | |
class Array(Expression):
    """A Postgres ``ARRAY[...]`` expression.

    Each positional item becomes one element of the array. Items that are
    already expressions (anything exposing ``resolve_expression``) are used
    as-is; every other value -- strings, numbers, ``None``, IP objects, ... --
    is wrapped in :class:`Value` so it is sent as a query parameter.

    :param items: the array elements, as expressions or plain values.
    :param output_field: a Django field *instance* describing the type of a
        single element; the expression's own output field is an
        ``ArrayField`` of it.
    :raises ValueError: if ``output_field`` is not given.
    """

    def __init__(self, *items, output_field=None):
        if output_field is None:
            raise ValueError('Please specify output_field as the type of the items')
        # Wrap every plain value, not only strings: a bare int/None/IP object
        # has no as_sql() and would otherwise break compiler.compile().
        self.items = tuple(
            item if hasattr(item, 'resolve_expression') else Value(item)
            for item in items
        )
        super().__init__(output_field=ArrayField(output_field))

    def __repr__(self):
        return f'<{self.__class__.__name__}: {self}>'

    def __str__(self):
        return f'{{{", ".join(repr(o) for o in self.items)}}}'

    def get_source_expressions(self):
        return self.items

    def set_source_expressions(self, exprs):
        # Django passes a list here; keep self.items a tuple so
        # get_source_expressions() always returns the same type.
        self.items = tuple(exprs)

    def as_sql(self, compiler, connection):
        """Compile to ``(ARRAY[item, ...])`` plus the items' parameters."""
        item_sqls = []
        params = []
        for expression in self.items:
            item_sql, item_params = compiler.compile(expression)
            item_sqls.append(item_sql)
            params.extend(item_params)
        if not item_sqls:
            # Postgres cannot determine the element type of a bare ARRAY[],
            # so cast the empty literal to the declared array type.
            return '(ARRAY[]::%s)' % self.output_field.db_type(connection), params
        return '(ARRAY[%s])' % ', '.join(item_sqls), params
class InetAddressArray(Array):
    """An :class:`Array` whose elements are Postgres ``inet`` values.

    Convenience subclass: ``output_field`` is always an
    :class:`InetAddressField`, so callers pass only the addresses.
    """

    def __init__(self, *ips):
        element_field = InetAddressField()
        super().__init__(*ips, output_field=element_field)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.