Created
February 28, 2022 11:26
-
-
Save jerch/fd0fae0107ce7b153b7540111b2e89ab to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re
from decimal import Decimal
from io import StringIO
# TODO:
# - investigate: are late column casts better than early ones? (make temp table all text?)
# - encoder/decoder for all django field types
# - sanitize method interfaces of CopyConverter / make them more django-like
# - Do we need a temp file shim instead of StringIO for very big data?
# - Better with BytesIO instead of StringIO?
# - parse_copydata as generator
# - temp model abstraction needed?
# Escape sequence (as it appears in COPY text data) -> the character it stands for.
TXT_REPL = {
    '\\\\': '\\',
    '\\b': '\b',
    '\\f': '\f',
    '\\n': '\n',
    '\\r': '\r',
    '\\t': '\t',
    '\\v': '\v',
}

# Matches exactly the escape sequences that are keys of TXT_REPL.
REX_DECODE_TEXT = re.compile(r'\\\\|\\[bfnrtv]')

# Reverse of TXT_REPL as a translation table, so text is escaped in a single
# pass instead of one str.replace() call per special character.
_TXT_ESCAPES = str.maketrans({plain: escaped for escaped, plain in TXT_REPL.items()})

# Textual boolean markers -> Python booleans.
BOOLEAN_REPL = {
    't': True,
    'f': False,
}

# Type name -> callable turning a Python value into its COPY text form.
# None is not handled here; CopyConverter.encode deals with it beforehand.
ENCODERS = {
    # hex form with an escaped leading backslash
    'bytea': lambda v: '\\\\x' + v.hex(),
    'text': lambda v: v.translate(_TXT_ESCAPES),
    'int': str,
    'decimal': str,
    'float': lambda v: str(float(v)),
    'boolean': lambda v: 't' if v else 'f',
}

# Type name -> callable turning a COPY text value back into a Python value.
DECODERS = {
    # drops the first three characters (the escaped backslash plus 'x')
    'bytea': lambda v: bytes.fromhex(v[3:]),
    'text': lambda v: REX_DECODE_TEXT.sub(lambda m: TXT_REPL[m.group(0)], v),
    'int': int,
    'decimal': Decimal,
    'float': float,
    'boolean': lambda v: BOOLEAN_REPL[v],
}
class CopyConverter:
    """Convert rows between Python values and separator-delimited COPY text.

    Values are converted through the module-level ENCODERS / DECODERS tables,
    looked up by type name. None is represented by the ``null`` marker.

    NOTE(review): the text encoder only escapes backslash and the control
    characters b/f/n/r/t/v, so a custom ``sep`` occurring inside a text value
    is not escaped - confirm before using a separator other than tab.
    """

    def __init__(self, sep='\t', null='\\N'):
        """Store the column separator and the marker used for None."""
        self.sep = sep
        self.null = null

    def encode(self, v, typename):
        """Return the textual form of ``v``; None becomes the null marker.

        Raises KeyError if ``typename`` has no entry in ENCODERS.
        """
        if v is None:
            return self.null
        return ENCODERS[typename](v)

    def decode(self, v, typename):
        """Return the Python value for text ``v``; the null marker becomes None.

        Raises KeyError if ``typename`` has no entry in DECODERS.
        """
        if v == self.null:
            return None
        return DECODERS[typename](v)

    def create_copydata(self, f, column_types, data, fields=None):
        """Write ``data`` to ``f`` as one encoded line per row, then rewind ``f``.

        ``column_types`` holds one type name per column. When ``fields`` is
        None every item of ``data`` is iterated as a sequence of values;
        otherwise the values are read from each item via ``getattr`` using the
        names in ``fields``.

        Raises IndexError when a row does not have exactly
        ``len(column_types)`` values.
        """
        tmpl = self.sep.join(['{}'] * len(column_types))
        if fields is None:
            rows = data
        else:
            rows = ([getattr(o, fname) for fname in fields] for o in data)
        for row in rows:
            line = [self.encode(dp, column_types[i]) for i, dp in enumerate(row)]
            f.write(tmpl.format(*line) + '\n')
        # rewind so the consumer reads from the start
        f.seek(0)

    def copy_from(self, table, columns, cursor, data, column_types, fields=None):
        """Encode ``data`` into a StringIO and load it via ``cursor.copy_from``.

        ``cursor`` must provide a ``copy_from(file, table, sep=, null=, size=,
        columns=)`` method (presumably a psycopg2 cursor - confirm).
        """
        f = StringIO()
        self.create_copydata(f, column_types, data, fields)
        cursor.copy_from(f, table, sep=self.sep, null=self.null, size=8192, columns=columns)

    def parse_copydata(self, f, column_types):
        """Read ``f`` from its start and return a list of decoded row tuples.

        Each line is split on the separator and column ``i`` is decoded with
        ``column_types[i]``.
        """
        parsed = []
        f.seek(0)
        for line in f:
            cols = line.rstrip('\n').split(self.sep)
            parsed.append(
                tuple(self.decode(col, column_types[i]) for i, col in enumerate(cols))
            )
        return parsed

    def copy_to(self, table, columns, cursor, decoders):
        """Dump ``columns`` of ``table`` via ``cursor.copy_to`` and decode the rows.

        ``decoders`` is the sequence of type names (one per column) handed to
        ``parse_copydata``.
        """
        f = StringIO()
        cursor.copy_to(f, table, sep=self.sep, null=self.null, columns=columns)
        return self.parse_copydata(f, decoders)
def copy_update(qs, objs, fieldnames, batch_size: int = 1000):
    """Update ``fieldnames`` of ``objs`` by copying the values into a temp table.

    Fields that are not local to ``qs.model`` cannot go through the temp
    table. If such fields are present and fewer than two local fields remain,
    or the model has no pk field, the call is delegated to
    ``model.objects.bulk_update`` and its result is returned.

    Otherwise the pk and the local field values of ``objs`` are loaded into
    the temporary table ``my_temp`` and applied with a single
    ``UPDATE ... FROM`` statement; nothing is returned in that case.

    NOTE: draft state - the pk and every updated column are created as
    ``integer`` in the temp table and encoded with the 'int' encoder, so only
    integer-valued columns are supported. Non-local fields are not applied
    when the temp-table path is taken (see TODO at the end).
    """
    model = qs.model
    meta = model._meta

    # filter all non model local fields --> still handled by bulk_update
    non_local_fieldnames = []
    local_fieldnames = []
    for name in fieldnames:
        if meta.get_field(name) not in meta.local_fields:
            non_local_fieldnames.append(name)
        else:
            local_fieldnames.append(name)

    # avoid more expensive doubled updates
    if non_local_fieldnames and len(local_fieldnames) < 2:
        return model.objects.bulk_update(objs, fieldnames, batch_size)

    if local_fieldnames:
        from django.db import connections

        tablename = meta.db_table
        pk_field = meta.pk
        if not pk_field:
            return model.objects.bulk_update(objs, fieldnames, batch_size)

        fields = [meta.get_field(name) for name in local_fieldnames]
        # The temp table mirrors the db column names of the updated fields so
        # that the statement built by update_from_table can reference them.
        temp_columns = ('pk', *(f.column for f in fields))
        column_defs = ', '.join(f'"{f.column}" integer' for f in fields)
        # Read values through attname so relation fields yield the raw column
        # value instead of the related object.
        attrs = ['pk', *(f.attname for f in fields)]

        connection = connections[qs.db]
        with connection.cursor() as cur:
            cur.execute(f'CREATE TEMPORARY TABLE my_temp (pk integer UNIQUE, {column_defs})')
            cc = CopyConverter()
            cc.copy_from('my_temp', temp_columns, cur, objs, ('int',) * len(temp_columns), attrs)
            cur.execute('CREATE INDEX some_idx_on_temp ON my_temp (pk)')
            cur.execute(update_from_table(tablename, pk_field.column, fields))
            cur.execute('DROP TABLE my_temp')
    # TODO: apply left overs
def update_from_table(tname, pkname, fields):
    """Build the ``UPDATE ... FROM my_temp`` statement for ``fields``.

    ``tname`` is the target table, ``pkname`` its pk column and ``fields`` an
    iterable of objects with a ``column`` attribute. Every column is set from
    the equally named column of ``my_temp``; rows are matched on
    ``my_temp.pk``.

    Column names are quoted on both sides of each assignment so that a name
    resolves the same way in the target table and in ``my_temp``.
    """
    cols = ','.join(f'"{f.column}"=my_temp."{f.column}"' for f in fields)
    where = f'"{tname}"."{pkname}"=my_temp.pk'
    return f'UPDATE "{tname}" SET {cols} FROM my_temp WHERE {where}'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
If someone is interested, I've started packaging things more seriously in https://github.com/netzkolchose/django-fast-update. It still lacks all sorts of tests; from manual testing, fast_update should mostly work though. I also added a copy_update draft from my playground tests, but that's only half done (it still misses several encoders and array support; I have no priority on this, as psycopg3 will make it obsolete soon).