Created: December 13, 2023 14:43
Save evansd/4a0dc4018648a8a9906be02f8f434d90 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
import platform | |
import struct | |
import timeit | |
import warnings | |
import numpy | |
import pyarrow | |
from scipy.sparse import csc_matrix | |
# Silence the PyArrow deprecation warnings for the legacy serialization API:
# this benchmark deliberately exercises the deprecated code path.
for _deprecation_message in (
    "'pyarrow.SerializationContext' is deprecated",
    "'pyarrow.serialize' is deprecated",
    "'pyarrow.deserialize' is deprecated",
):
    warnings.filterwarnings("ignore", message=_deprecation_message, module=".")

# Serialization context which the csc_matrix handlers are registered on below
context = pyarrow.SerializationContext()
def serialize_csc(matrix):
    """
    Break a Compressed Sparse Column matrix down into primitives (a tuple of
    numpy arrays plus the shape tuple) which PyArrow can serialize natively
    """
    components = (matrix.data, matrix.indices, matrix.indptr)
    return (components, matrix.shape)
def deserialize_csc(args):
    """
    Rebuild a Compressed Sparse Column matrix from the parts produced by
    `serialize_csc`
    """
    # Allocate a blank instance and assign its members directly instead of
    # going through `__init__`, whose validation checks significantly slow
    # down deserialization. The inputs are known to come from well-formed
    # matrices, so skipping validation is safe here.
    (data, indices, indptr), shape = args
    reconstructed = csc_matrix.__new__(csc_matrix)
    reconstructed.data = data
    reconstructed.indices = indices
    reconstructed.indptr = indptr
    reconstructed._shape = shape
    return reconstructed
# Register a custom PyArrow serialization context which knows how to handle
# Compressed Sparse Column (csc) matrices. The "csc" string is the tag PyArrow
# embeds in the serialized stream to identify which deserializer to dispatch to.
context.register_type(
    csc_matrix,
    "csc",
    custom_serializer=serialize_csc,
    custom_deserializer=deserialize_csc,
)
def random_array(seed=123, shape=(12500, 60)):
    """
    Return a deterministic random dense-ish matrix in CSC format.

    Generalized so the fixture size is a parameter rather than hard-coded;
    the default reproduces the original benchmark fixture exactly.

    :param seed: seed for numpy's global RNG, so repeated calls with the same
        seed yield identical matrices
    :param shape: (rows, columns) of the generated matrix
    """
    numpy.random.seed(seed)
    matrix = csc_matrix(numpy.random.rand(*shape))
    return matrix
def pyarrow_setup():
    """Produce a pre-serialized fixture for the pyarrow deserialize benchmark."""
    fixture = random_array()
    return pyarrow_serialize(fixture)
def pyarrow_serialize(obj):
    """Serialize `obj` via the custom PyArrow context, returning raw bytes."""
    serialized = context.serialize(obj)
    return serialized.to_buffer().to_pybytes()
def pyarrow_read(data):
    """Deserialize bytes previously produced by `pyarrow_serialize`."""
    return context.deserialize(data)
def pickle_setup():
    """Produce a pre-serialized fixture for the pickle deserialize benchmark."""
    fixture = random_array()
    return pickle_serialize(fixture)
def pickle_serialize(obj):
    """Pickle `obj` using the newest protocol this interpreter supports."""
    protocol = pickle.HIGHEST_PROTOCOL
    return pickle.dumps(obj, protocol=protocol)
def pickle_read(data):
    """Unpickle `data` as produced by `pickle_serialize`."""
    return pickle.loads(data)
def pickle_zc_setup():
    """Produce a pre-serialized fixture for the zero-copy pickle benchmark."""
    fixture = random_array()
    return pickle_zc_serialize(fixture)
def pickle_zc_serialize(obj):
    """
    Pickle `obj` with protocol 5 out-of-band buffers (so large arrays are
    captured without copying), then pack everything into one bytes payload
    """
    raw_buffers = []

    def _collect(buffer):
        raw_buffers.append(buffer.raw())

    pickled = pickle.dumps(obj, protocol=5, buffer_callback=_collect)
    # The pickle stream itself travels as the final buffer
    raw_buffers.append(pickled)
    return serialize_buffers(raw_buffers)
def pickle_zc_read(data):
    """Inverse of `pickle_zc_serialize`: split the buffers and unpickle."""
    buffers = deserialize_buffers(data)
    # The last buffer is the pickle stream; the preceding ones are the
    # out-of-band buffers it references, consumed in order by `loads`
    pickle_stream = buffers[-1]
    return pickle.loads(pickle_stream, buffers=buffers)
def serialize_buffers(buffers):
    """
    Concatenate `buffers` into a single bytes object, prefixed with a header
    recording each buffer's length so they can be split apart again
    """
    header = serialize_ints([len(buffer) for buffer in buffers])
    return header + b"".join(buffers)
def deserialize_buffers(data):
    """
    Split `data` (as produced by `serialize_buffers`) back into a list of
    zero-copy memoryview slices, one per original buffer
    """
    view = memoryview(data)
    sizes, cursor = deserialize_ints(view)
    buffers = []
    for size in sizes:
        buffers.append(view[cursor:cursor + size])
        cursor += size
    return buffers
def serialize_ints(ints):
    """Pack `ints` as little-endian uint32s, preceded by a uint32 count."""
    values = [len(ints), *ints]
    return struct.pack("<{}I".format(len(values)), *values)
def deserialize_ints(data):
    """
    Inverse of `serialize_ints`: return a (tuple_of_ints, bytes_consumed)
    pair so the caller knows where the header ends
    """
    (count,) = struct.unpack_from("<I", data, 0)
    end = 4 * (count + 1)
    return struct.unpack("<{}I".format(count), data[4:end]), end
if __name__ == "__main__":
    print(
        "Python {}, Pickle {}".format(
            platform.python_version(), pickle.HIGHEST_PROTOCOL
        )
    )

    LOOPS = 5000
    # `timeit.repeat` defaults to 5 repetitions on modern Python, which made
    # the original hard-coded "best of 3" message wrong; pin the repeat count
    # and report it so the output is accurate
    REPEAT = 3

    def _report(fn_name, phase, results):
        # Report the fastest run: slower runs reflect system interference
        # rather than the code under test
        usec = min(results) * 1e6 / LOOPS
        print(
            "{} {}: {} loops, best of {}: {:.1f} usec per loop".format(
                fn_name, phase, LOOPS, REPEAT, usec
            )
        )

    for fn_name in ["pickle", "pickle_zc", "pyarrow"]:
        serialize_fn = globals()[f"{fn_name}_serialize"]
        read_fn = globals()[f"{fn_name}_read"]
        # Sanity check: a round trip must reproduce the original matrix
        obj = random_array()
        assert numpy.array_equal(read_fn(serialize_fn(obj)).todense(), obj.todense())
        _report(
            fn_name,
            "deserialize",
            timeit.repeat(
                setup="from __main__ import {0}_setup, {0}_read; data = {0}_setup()".format(
                    fn_name
                ),
                stmt="n = {0}_read(data)".format(fn_name),
                number=LOOPS,
                repeat=REPEAT,
            ),
        )
        _report(
            fn_name,
            "serialize",
            timeit.repeat(
                setup="from __main__ import random_array, {0}_serialize; data = random_array()".format(
                    fn_name
                ),
                stmt="n = {0}_serialize(data)".format(fn_name),
                number=LOOPS,
                repeat=REPEAT,
            ),
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment