Skip to content

Instantly share code, notes, and snippets.

@evansd
Created December 13, 2023 14:43
Show Gist options
  • Save evansd/4a0dc4018648a8a9906be02f8f434d90 to your computer and use it in GitHub Desktop.
Save evansd/4a0dc4018648a8a9906be02f8f434d90 to your computer and use it in GitHub Desktop.
import pickle
import platform
import struct
import timeit
import warnings
import numpy
import pyarrow
from scipy.sparse import csc_matrix
# Suppress the warnings raised for the PyArrow serialization entry points used
# below, then create the serialization context shared by the rest of the file.
for _deprecation_message in (
    "'pyarrow.SerializationContext' is deprecated",
    "'pyarrow.serialize' is deprecated",
    "'pyarrow.deserialize' is deprecated",
):
    warnings.filterwarnings("ignore", message=_deprecation_message, module=".")
context = pyarrow.SerializationContext()
def serialize_csc(matrix):
    """
    Decompose a matrix in Compressed Sparse Column format into more basic data
    types (tuples and numpy arrays) which PyArrow knows how to serialize

    Returns a pair ``((data, indices, indptr), shape)`` holding the matrix's
    own attribute values (no copies are made here).
    """
    return ((matrix.data, matrix.indices, matrix.indptr), matrix.shape)
def deserialize_csc(args):
    """
    Reconstruct a Compressed Sparse Column matrix from its decomposed parts

    ``args`` is a pair ``((data, indices, indptr), shape)``, i.e. the layout
    produced by ``serialize_csc``. The arrays are attached to the new matrix
    as-is, without copying or validation.
    """
    # We construct a `csc_matrix` instance by directly assigning its members,
    # rather than using `__init__` which runs additional checks that
    # significantly slow down deserialization. Because we know these values
    # came from properly constructed matrices we can skip these checks
    (data, indices, indptr), shape = args
    matrix = csc_matrix.__new__(csc_matrix)
    matrix.data = data
    matrix.indices = indices
    matrix.indptr = indptr
    # `shape` is exposed through the private `_shape` attribute, which is
    # normally populated by `__init__`
    matrix._shape = shape
    return matrix
# Teach the PyArrow serialization context about Compressed Sparse Column (csc)
# matrices by pairing the type with its decompose/reconstruct helpers
context.register_type(
    csc_matrix, "csc", custom_serializer=serialize_csc, custom_deserializer=deserialize_csc
)
def random_array(seed=123, shape=(12500, 60)):
    """
    Build a reproducible CSC matrix of uniform random values in [0, 1).

    ``seed`` selects the pseudo-random sequence; ``shape`` is the
    ``(rows, columns)`` size of the dense array that gets converted.

    A private ``RandomState`` is used instead of reseeding NumPy's global
    generator, so calling this function does not disturb random state used
    elsewhere. The values are the same as ``numpy.random.seed(seed)`` followed
    by ``numpy.random.rand(*shape)``.
    """
    rng = numpy.random.RandomState(seed)
    return csc_matrix(rng.rand(*shape))
def pyarrow_setup():
    """Return the default random matrix serialized with PyArrow (benchmark fixture)."""
    return pyarrow_serialize(random_array())
def pyarrow_serialize(obj):
    """Serialize ``obj`` through the module's PyArrow context and return it as ``bytes``."""
    return context.serialize(obj).to_buffer().to_pybytes()
def pyarrow_read(data):
    """Deserialize ``data`` through the module's PyArrow context."""
    return context.deserialize(data)
def pickle_setup():
    """Return the default random matrix serialized with plain pickle (benchmark fixture)."""
    return pickle_serialize(random_array())
def pickle_serialize(obj):
    """Pickle ``obj`` to ``bytes`` using the highest protocol this interpreter supports."""
    return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
def pickle_read(data):
    """
    Unpickle ``data`` and return the resulting object.

    NOTE: unpickling can execute arbitrary code; only pass data that this
    program produced itself.
    """
    return pickle.loads(data)
def pickle_zc_setup():
    """Return the default random matrix serialized with ``pickle_zc_serialize`` (benchmark fixture)."""
    return pickle_zc_serialize(random_array())
def pickle_zc_serialize(obj):
    """
    Pickle ``obj`` with protocol 5, keeping large buffers out of the pickle
    stream, and pack everything into a single ``bytes`` value.

    The packed layout is: every out-of-band buffer in the order pickle
    reported it, followed by the pickle stream itself as the final entry.
    """
    buffers = []
    pickled = pickle.dumps(
        obj,
        protocol=5,
        # `list.append` returns None, which tells pickle to keep the buffer
        # out-of-band; `.raw()` exposes it as a flat byte-oriented memoryview
        buffer_callback=lambda buffer: buffers.append(buffer.raw()),
    )
    # The pickle stream goes last: `pickle_zc_read` looks for it there
    buffers.append(pickled)
    return serialize_buffers(buffers)
def pickle_zc_read(data):
    """
    Rebuild the object packed by ``pickle_zc_serialize``.

    The out-of-band buffers are passed to pickle as slices of ``data`` rather
    than as copies.

    NOTE: unpickling can execute arbitrary code; only pass data that this
    program produced itself.
    """
    buffers = deserialize_buffers(data)
    # The final entry is the pickle stream; only the entries before it are
    # out-of-band buffers, so do not offer the stream to pickle as one of them
    payload = buffers.pop()
    return pickle.loads(payload, buffers=buffers)
def serialize_buffers(buffers):
    """
    Concatenate bytes-like objects into one ``bytes`` value, prefixed by a
    header (built by ``serialize_ints``) recording each buffer's size in bytes.
    """
    # Materialise once so the input can be any iterable and is walked twice safely
    buffers = list(buffers)
    # `nbytes` is the true byte length; `len()` would count elements instead
    # for a memoryview whose items are wider than one byte
    sizes = [memoryview(buffer).nbytes for buffer in buffers]
    header = serialize_ints(sizes)
    return b"".join([header, *buffers])
def deserialize_buffers(data):
    """
    Split ``data`` (as produced by ``serialize_buffers``) back into its buffers.

    Returns a list of ``memoryview`` slices of ``data``; the payload bytes are
    not copied.

    Raises ``ValueError`` if ``data`` is shorter than the sizes recorded in its
    header claim.
    """
    view = memoryview(data)
    sizes, offset = deserialize_ints(view)
    total = len(view)
    output = []
    for size in sizes:
        next_offset = offset + size
        # Slicing past the end would silently yield a short buffer
        if next_offset > total:
            raise ValueError(
                "truncated data: expected at least {} bytes, got {}".format(
                    next_offset, total
                )
            )
        output.append(view[offset:next_offset])
        offset = next_offset
    return output
def serialize_ints(ints):
    """
    Pack integers as little-endian unsigned 32-bit values, preceded by their
    count in the same format.

    ``ints`` may be any iterable. ``struct.error`` is raised for values that do
    not fit in an unsigned 32-bit integer.
    """
    values = tuple(ints)
    count = len(values)
    return struct.pack(f"<{count + 1}I", count, *values)
def deserialize_ints(data):
    """
    Read a header written by ``serialize_ints`` from the start of ``data``.

    Returns ``(values, end)`` where ``values`` is a tuple of the decoded
    integers and ``end`` is the offset of the first byte after the header.
    Raises ``struct.error`` if ``data`` is too short to hold the header.
    """
    # `unpack_from` reads in place, avoiding the copies that slicing a
    # `bytes` object would make
    (count,) = struct.unpack_from("<I", data, 0)
    end = 4 + (count * 4)
    return struct.unpack_from(f"<{count}I", data, 4), end
if __name__ == "__main__":
    # Benchmark driver: for each serialization strategy, check that a matrix
    # survives a round trip, then time deserialization and serialization.
    print(
        "Python {}, Pickle {}".format(
            platform.python_version(), pickle.HIGHEST_PROTOCOL
        )
    )
    loops = 5000
    # (label, timeit setup template, timeit statement template); `{0}` is
    # replaced by the strategy name
    benchmarks = [
        (
            "deserialize",
            "from __main__ import {0}_setup, {0}_read; data = {0}_setup()",
            "n = {0}_read(data)",
        ),
        (
            "serialize",
            "from __main__ import random_array, {0}_serialize; data = random_array()",
            "n = {0}_serialize(data)",
        ),
    ]
    for fn_name in ["pickle", "pickle_zc", "pyarrow"]:
        # Look the strategy's functions up by name in the module namespace
        serialize_fn = globals()[f"{fn_name}_serialize"]
        read_fn = globals()[f"{fn_name}_read"]
        obj = random_array()
        # Sanity check: the round-tripped matrix must equal the original
        assert numpy.array_equal(read_fn(serialize_fn(obj)).todense(), obj.todense())
        for action, setup_template, stmt_template in benchmarks:
            results = timeit.repeat(
                setup=setup_template.format(fn_name),
                stmt=stmt_template.format(fn_name),
                number=loops,
            )
            best = min(results)
            usec = best * 1e6 / loops
            # Report the real number of repeats instead of a hard-coded "3"
            print(
                "{} {}: {} loops, best of {}: {:.1f} usec per loop".format(
                    fn_name, action, loops, len(results), usec
                )
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment