Last active
August 7, 2021 11:16
-
-
Save blink1073/796e4eb01d43ebcec62a to your computer and use it in GitHub Desktop.
Avro Serializer/Deserializer with Numpy Support
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ast import literal_eval | |
from avro.io import DatumReader, DatumWriter, BinaryEncoder, BinaryDecoder | |
from avro.schema import Names, SchemaFromJSONData | |
import yaml | |
import numpy as np | |
class BinaryDatumWriter(object):
    """Write binary-encoded Avro datums to a buffer, converting any
    numpy.ndarray values into a schema-appropriate plain form first.
    """

    def __init__(self, schema, buf):
        """
        *schema* an avro schema object, or a dict / JSON / YAML string
        *buf* a writable binary file-like object
        """
        if isinstance(schema, (dict, str)):
            schema = load_schema(schema)
        self.schema = schema
        self._writer = DatumWriter(schema)
        self._encoder = BinaryEncoder(buf)
        self.buf = buf

    def write(self, datum):
        """Serialize *datum* (a dict) to the buffer.

        ndarray values go out as raw bytes for binary-typed fields and
        as nested lists otherwise.  BUG FIX: conversion now happens on
        a shallow copy, so the caller's dict is no longer mutated.
        """
        converted = dict(datum)
        for (key, value) in datum.items():
            if isinstance(value, np.ndarray):
                field = self.schema.field_map[key]
                # 'bytes' is the actual Avro type name; 'binary' is kept
                # for backward compatibility with the original check.
                if field.type.props['type'] in ('fixed', 'bytes', 'binary'):
                    converted[key] = value.tobytes()
                else:
                    converted[key] = value.tolist()
        self._writer.write(converted, self._encoder)
class BinaryDatumReader(object):
    """Read binary-encoded Avro datums back out of a buffer."""

    def __init__(self, writer_schema, buf, reader_schema=None):
        """
        *writer_schema* schema the data was written with (object, dict,
            or JSON/YAML text)
        *buf* a readable binary file-like object
        *reader_schema* optional schema to resolve the data against
        """
        if isinstance(writer_schema, (dict, str)):
            writer_schema = load_schema(writer_schema)
        if isinstance(reader_schema, (dict, str)):
            reader_schema = load_schema(reader_schema)
        # Prefer the reader schema when one was supplied.
        self.schema = reader_schema if reader_schema else writer_schema
        self._reader = DatumReader(writer_schema, reader_schema)
        self._decoder = BinaryDecoder(buf)
        self.buf = buf

    def read(self):
        """Deserialize and return the next datum from the buffer."""
        return self._reader.read(self._decoder)
def unpack_datum(schema, datum):
    """Convert binary ndarray fields of *datum* back into numpy arrays.

    *schema* an avro record schema (object, dict, or JSON/YAML text)
    *datum* a dict as returned by BinaryDatumReader.read()

    Fields whose schema carries ``logicalType: ndarray`` and a binary
    base type are re-interpreted with the declared ``dtype`` and, when
    present, reshaped to the declared ``shape``.  *datum* is modified
    in place and also returned.  Note that np.frombuffer yields a
    read-only array sharing memory with the source bytes.

    Raises SyntaxError if a declared shape cannot be parsed.
    """
    if isinstance(schema, (dict, str)):
        schema = load_schema(schema)
    for field in schema.fields:
        props = field.type.props
        if props.get('logicalType') != 'ndarray':
            continue
        # 'bytes' is the actual Avro type name; 'binary' is kept for
        # backward compatibility with the original check.
        if props['type'] not in ('fixed', 'bytes', 'binary'):
            continue
        name = field.name
        datum[name] = np.frombuffer(datum[name], dtype=np.dtype(props['dtype']))
        if 'shape' in props:
            try:
                shape = literal_eval(props['shape'])
            # BUG FIX: literal_eval raises ValueError (not SyntaxError)
            # for well-formed but non-literal input; catch both.
            except (SyntaxError, ValueError):
                msg = 'Could not parse shape: "%s"' % props['shape']
                raise SyntaxError(msg)
            datum[name] = datum[name].reshape(shape)
    return datum
def load_schema(schema):
    """Parse *schema* (a dict, or JSON/YAML text) into an avro schema.

    YAML is a superset of JSON, so plain JSON schema text is accepted.
    """
    if not isinstance(schema, dict):
        # BUG FIX: yaml.load() without an explicit Loader is unsafe on
        # untrusted input and requires a Loader argument in PyYAML >= 6;
        # safe_load is sufficient for plain schema documents.
        schema = yaml.safe_load(schema)
    return SchemaFromJSONData(schema, Names())
if __name__ == '__main__':
    import io
    import time

    # Demo schema: scalars, a union, an array, and a fixed-size binary
    # field tagged as an ndarray via logicalType/dtype/shape props.
    schema = {
        "namespace": "example.avro",
        "type": "record",
        "name": "User",
        "fields": [
            {"name": "name", "type": "string"},
            {"name": "favorite_number", "type": ["int", "null"]},
            {"name": "favorite_color", "type": ["string", "null"]},
            {"name": "yo",
             "type": {"type": "array",
                      "items": ["null", "string"]}},
            {"name": "fixed_big",
             "type": {"type": "fixed", "name": "image",
                      "size": 1920 * 1024,
                      "logicalType": "ndarray", "dtype": "uint8",
                      "shape": "(1920,1024)"}},
        ],
    }

    buf = io.BytesIO()
    bdata = np.ones((1920, 1024), dtype=np.uint8)

    writer = BinaryDatumWriter(schema, buf)
    writer.write({"name": "Alyssa", "yo": [None, "hey"],
                  "favorite_number": 256, "fixed_big": bdata})

    # Time 1000 writes of a record carrying the ~2 MB array.
    t0 = time.time()
    for _ in range(1000):
        writer.write({"name": "Ben", "favorite_number": 7,
                      "yo": [None],
                      "favorite_color": "red", "fixed_big": bdata})
    print('Write to buf:', time.time() - t0)

    with open('test.dat', 'wb') as fid:
        fid.write(buf.getbuffer())
    print('Write to disk:', time.time() - t0)

    # Read everything back and restore the first record's ndarray field.
    buf.seek(0)
    reader = BinaryDatumReader(schema, buf)
    d1 = reader.read()
    for _ in range(1000):
        reader.read()
    print('Read from buf:', time.time() - t0)
    print(unpack_datum(schema, d1))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import sqlite3 | |
try: | |
import cPickle as pickle | |
except ImportError: | |
import pickle | |
import numpy as np | |
# True when running under Python 3.
PY3 = sys.version.startswith('3')
if PY3:
    # Alias the Python 2-only names used by the isinstance checks in
    # get_dtype/convert so the same code runs on both major versions.
    long = int
    unicode = str
""" | |
NULL. The value is a NULL value. | |
INTEGER. The value is a signed integer, stored in 1, 2, 3, 4, 6, or 8 bytes depending on the magnitude of the value. | |
REAL. The value is a floating point value, stored as an 8-byte IEEE floating point number. | |
TEXT. The value is a text string, stored using the database encoding (UTF-8, UTF-16BE or UTF-16LE). | |
BLOB. The value is a blob of data, stored exactly as it was input. | |
""" | |
def get_dtype(value):
    """Map a Python value to the SQLite storage class used to hold it.

    None -> NULL, ints -> INTEGER, floats -> REAL, bytes -> BLOB,
    everything else -> TEXT.
    """
    if value is None:
        return 'NULL'
    if isinstance(value, (int, long)):
        return 'INTEGER'
    if isinstance(value, float):
        return 'REAL'
    return 'BLOB' if isinstance(value, bytes) else 'TEXT'
def convert(value):
    """Coerce *value* into something sqlite3 can store natively.

    Natively-storable scalars pass through untouched; every other
    object is pickled to bytes so it round-trips through a BLOB.
    """
    if value is None:
        return None
    if isinstance(value, (int, long, float, bytes, unicode)):
        return value
    return pickle.dumps(value, protocol=-1)
class SqlHandler(object):
    """Manage nested data reading and writing from dictionaries to SQL.

    Nested dictionaries are stored as a hierarchy of tables whose names
    are joined with '__' (parent table ``a`` with nested key ``b``
    becomes table ``a__b``), all linked by a shared primary key.

    NOTE(review): table names, field names, and criteria are
    interpolated directly into SQL text; only row values are passed as
    parameters.  Use this class with trusted identifiers/criteria only.
    """

    def __init__(self, db):
        """*db* is an sqlite3 connection or a path to a database file."""
        if isinstance(db, str):
            db = sqlite3.connect(db)
        self.db = db
        self.cursor = self.db.cursor()
        self.tables = self.list_tables()
        # Cache of table name -> ordered list of its field names.
        self.field_names = dict()

    def create_table(self, datum, name, primary_key,
                     foreign=False):
        """Create a hierarchical table in a DB from a dictionary.

        *datum* dictionary
        *name* name of the top level table
        *primary_key* which key in the dictionary to use as primary
        *foreign* Used internally by sub-tables to indicate linkage
        """
        # BUG FIX: type the primary-key column from its *value*; the
        # original passed the key name (a string), so the column was
        # always declared TEXT regardless of the actual value type.
        lines = ['create table %s\n(%s %s primary key' %
                 (name, primary_key, get_dtype(datum[primary_key]))]
        for (key, value) in datum.items():
            if key == primary_key:
                continue
            if not isinstance(value, dict):
                lines.append('%s %s' % (key, get_dtype(value)))
        if foreign:
            # NOTE(review): this references the *root* table of the
            # hierarchy (name.split) rather than the immediate parent;
            # works because the primary key is shared all the way down.
            lines.append('foreign key (%s) references %s(%s)'
                         % (primary_key, name.split('__')[0], primary_key))
        if name not in self.tables:
            self.cursor.execute(',\n'.join(lines) + ')')
            self.tables.append(name)
            self.field_names[name] = self.list_fields(name)
        # Recurse into nested dicts, propagating the primary key so the
        # child rows can be joined back to their parent.
        for (key, value) in datum.items():
            if isinstance(value, dict):
                value[primary_key] = datum[primary_key]
                self.create_table(value, '%s__%s' % (name, key),
                                  primary_key, foreign=True)

    def remove_table(self, name, cmd='drop'):
        """Remove a table and its children.

        *cmd* is the leading SQL verb (kept for backward compatibility;
        only 'drop' is meaningful for SQLite).
        """
        self.cursor.execute('%s table %s' % (cmd, name))
        for table in self.list_tables():
            if table.startswith(name + '__'):
                self.cursor.execute('%s table %s' % (cmd, table))

    def clear_table(self, name):
        """Empty a table and its children without dropping them.

        BUG FIX: SQLite has no TRUNCATE statement — the original issued
        'truncate table ...', which is a syntax error.  DELETE without a
        WHERE clause is the SQLite way to empty a table.
        """
        self.cursor.execute('delete from %s' % name)
        for table in self.list_tables():
            if table.startswith(name + '__'):
                self.cursor.execute('delete from %s' % table)

    def add_data(self, datum, tbl_name):
        """Add datum to a given table hierarchically.

        *datum* a dictionary
        *tbl_name* name of a valid, existing table
        Assumes the data is in the right format for the table.
        """
        if tbl_name not in self.field_names:
            # BUG FIX: cache the full field list.  The original stored
            # list_fields(...)[0] (a single field-name string), so the
            # primary-key lookup below yielded one *character* of it.
            self.field_names[tbl_name] = self.list_fields(tbl_name)
        # By construction the primary key is the first declared column.
        primary_key = self.field_names[tbl_name][0]
        # Nested dicts become rows in child tables, not columns here.
        keys = [k for (k, v) in datum.items() if not isinstance(v, dict)]
        query = "insert into {0} ({1}) values (?{2})"
        query = query.format(tbl_name, ",".join(keys), ",?" * (len(keys) - 1))
        self.cursor.execute(query, [convert(datum[k]) for k in keys])
        for (key, value) in datum.items():
            if isinstance(value, dict):
                value[primary_key] = datum[primary_key]
                self.add_data(value, '%s__%s' % (tbl_name, key))

    def read_table(self, tbl_name, which='*', criteria=None):
        """Read data from a hierarchical table as a dictionary.

        *tbl_name* is the top level table in the hierarchy
        *which* can be a list of field names or a csv string of names
        *criteria* is a valid SQL WHERE statement
        Notes:
        *tbl_name* and *which* can be a path to nested data
        in /path/format.
        *criteria* can include AND or OR and use any valid Comparison
        or Logical Operators like >, <, =, LIKE, NOT, etc.
        """
        if tbl_name.startswith('/'):
            tbl_name = tbl_name[1:]
        tbl_name = tbl_name.replace('/', '__')
        if tbl_name not in self.field_names:
            self.field_names[tbl_name] = self.list_fields(tbl_name)
        fields = self.field_names[tbl_name]
        if isinstance(which, list):
            which = ','.join(which)
        which = which.split(',')
        if which == ['*']:
            qwhich = fields
        else:
            # Only query fields that actually exist on this table.
            qwhich = [w for w in which if w in fields]
        if qwhich:
            query = 'SELECT %s from %s' % (','.join(qwhich), tbl_name)
            if criteria:
                # BUG FIX: the original stripped every space out of the
                # criteria, corrupting quoted values containing spaces
                # and fusing keywords ("a=1 AND b=2" -> "a=1ANDb=2").
                query += ' where %s' % criteria
            self.cursor.execute(query)
            values = self.cursor.fetchall()
            data = [dict(zip(qwhich, v)) for v in values]
        else:
            data = None
        # Merge in direct children (exactly one level deeper).
        for table in self.tables:
            if (table.startswith(tbl_name + '__')
                    and table.count('__') == tbl_name.count('__') + 1):
                data = self._read_subtable(table, which, criteria, data)
        return data

    def _read_subtable(self, table, which, criteria, data):
        """Retrieve data from a child table and merge it into *data*."""
        key = table.rpartition('__')[-1]
        # Strip one leading path component from each requested field so
        # '/inner/b' becomes 'b' for the child lookup.
        sub_which = []
        for s in which:
            if s.startswith('/'):
                s = s[1:]
            if '/' in s:
                s = s.partition('/')[2]
            sub_which.append(s)
        if key in sub_which:
            sub_data = self.read_table(table, '*', criteria)
        else:
            sub_data = self.read_table(table, sub_which, criteria)
        if not sub_data:
            return data
        if not data:
            data = [{key: d} for d in sub_data]
        else:
            # Rows come back in insertion order on both sides, so zip
            # pairs each parent row with its child row.
            for (datum, sub_datum) in zip(data, sub_data):
                datum[key] = sub_datum
        return data

    def list_tables(self):
        """Get a list of the table names in the DB."""
        # Single quotes: double-quoted "table" is an identifier in
        # standard SQL and only worked via an SQLite fallback.
        query = "SELECT name FROM sqlite_master WHERE type = 'table'"
        self.cursor.execute(query)
        return [t[0] for t in self.cursor.fetchall()]

    def list_fields(self, tbl_name):
        """Get a list of the field names for a given table name."""
        self.cursor.execute("pragma table_info(%s)" % tbl_name)
        return [r[1] for r in self.cursor.fetchall()]
def unpack(datum):
    """Recursively restore pickled values in a (possibly nested) dict.

    Any bytes value that unpickles successfully is replaced by the
    resulting object; all other values are left alone.  *datum* is
    modified in place and also returned.

    NOTE(review): pickle.loads must never be fed untrusted bytes.
    """
    for (key, value) in datum.items():
        if isinstance(value, dict):
            datum[key] = unpack(value)
            continue
        if not isinstance(value, bytes):
            continue
        try:
            restored = pickle.loads(value)
        except Exception:
            continue  # best effort: leave non-pickle bytes untouched
        datum[key] = restored
    return datum
if __name__ == '__main__':
    import os
    # Start from a fresh database file on every run.
    if os.path.exists('test.sqlite'):
        os.remove('test.sqlite')
    db = sqlite3.connect('test.sqlite')
    bdata = np.ones((10, 10), dtype=np.uint8)
    # Nested dicts become child tables: test2__inner, test2__inner__inner2.
    data = {"name": "Alyssa", "yo": [None, "hey"],
            "favorite_number": 256, "fixed_big": bdata,
            'inner': dict(b=10, inner2=dict(foo=3))}
    sh = SqlHandler(db)
    sh.create_table(data, 'test2', 'name')
    sh.add_data(data, 'test2')
    # Second row reuses the same dict with a new primary key; the string
    # favorite_number is stored anyway (SQLite type affinity is lax).
    data['name'] = 'Bob'
    data['favorite_number'] = 'forty'
    sh.add_data(data, 'test2')
    # Exercise whole-table, single-field, and /nested/path reads.
    print(sh.read_table('test2'))
    print(sh.read_table('test2', which='favorite_number'))
    print(sh.read_table('test2', which='inner'))
    print(sh.read_table('test2', which='/inner/b'))
    print(sh.read_table('test2', which='/inner/inner2'))
    print(sh.read_table('test2', which='inner/inner2/foo'))
    bob = sh.read_table('test2', criteria='name="Bob"')[0]
    # unpack() restores the pickled values (the ndarray and list fields).
    print('Bob:', unpack(bob))
    print(sh.read_table('test2/inner'))
    print(sh.list_fields('test2'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment