@dwf
Last active February 10, 2017 08:45
A proof-of-concept for a multi-threaded pre-fetching loader for .NPY files.
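
As background: the constructor below finds where the raw array bytes begin by parsing the NPY v1.0 header with numpy.lib.format and recording the stream position. A minimal standalone sketch of that step, assuming a hypothetical file 'data.npy':

    import numpy as np
    from numpy.lib.format import read_magic, read_array_header_1_0

    np.save('data.npy', np.arange(12, dtype=np.float64).reshape(6, 2))
    with open('data.npy', 'rb') as f:
        assert read_magic(f) == (1, 0)  # NPY format version 1.0
        shape, fortran_order, dtype = read_array_header_1_0(f)
        data_offset = f.tell()  # raw, C-ordered array data starts here
    print(shape, fortran_order, dtype, data_offset)
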
"""
A proof-of-concept multi-threaded chunked NPY format reader.
"""
__author__ = "David Warde-Farley"
__credits__ = ["David Warde-Farley"]
__license__ = "3-clause BSD"
__email__ = "d dot warde dot farley at gmail DOT com"
import threading
# `reduce` is a builtin only under Python 2; import it for Python 3.
from functools import reduce

from numpy.lib.format import read_magic, read_array_header_1_0

cimport numpy as np
from libc.stdio cimport (FILE, fopen, fclose, fread, ferror,
                         feof, fseek, SEEK_SET)
from libc.stdlib cimport malloc, free

# For some reason this isn't in libc.stdio
cdef extern from "stdio.h" nogil:
    void clearerr(FILE *fp)

cdef extern from "errno.h":
    char **sys_errlist
    int errno

np.import_array()
cdef size_t fetch_loop(long base_offset, char *buf,
                       size_t row_bytes, size_t rows, FILE *fp) nogil:
    """
    Fetch chunks of an array from an open file pointer,
    wrapping to `base_offset` at EOF.

    Parameters
    ----------
    base_offset : long
        The offset in the file at which non-header data begins.
    buf : char *
        A buffer of length at least `row_bytes * rows` into which
        data will be read.
    row_bytes : size_t
        The number of bytes taken by a single row.
    rows : size_t
        The number of rows to read.
    fp : FILE *
        The file pointer to read from.

    Notes
    -----
    If EOF is encountered (detected via a shortened read), this
    function will seek back to `base_offset`. Otherwise (including
    if there are no more bytes to be read but `feof()` returns false)
    it will leave the stream position directly after the last read.
    """
    cdef size_t n_read
    cdef int error
    n_read = fread(buf, row_bytes, rows, fp)
    if n_read < rows:
        error = ferror(fp)
        if error:
            clearerr(fp)
            with gil:
                # ferror() only reports that an error occurred;
                # the actual error code is in errno.
                global errno, sys_errlist
                raise IOError(sys_errlist[errno])
        elif feof(fp):
            if fseek(fp, base_offset, SEEK_SET):
                with gil:
                    # XXX: what's the etiquette about clearing errno?
                    global errno, sys_errlist
                    raise IOError(sys_errlist[errno])
            if n_read == 0:
                # Nothing was read before EOF; retry from the start
                # of the data so the caller always gets >= 1 row.
                return fetch_loop(base_offset, buf, row_bytes,
                                  rows, fp)
        else:
            with gil:
                raise IOError("requested %d rows, got %d, but "
                              "no error or EOF; this should not "
                              "happen" % (rows, n_read))
    return n_read
cdef class SequentialBatchReader:
    """
    An extension type that manages an NPY file being read from
    in chunks. It pre-fetches the next chunk in a separate thread.

    Parameters
    ----------
    fname : str
        The location of the .npy file.
    batch_rows : int
        The size of the batch to request. If the total number of
        elements along the first axis is not a multiple of
        `batch_rows` then there will be one smaller batch
        before wrap-around.
    """
    # The file.
    cdef str fname
    cdef FILE *fp

    # Sizing and offset details.
    cdef np.npy_intp scalar_bytes
    cdef np.npy_intp row_bytes
    cdef np.npy_intp batch_rows
    cdef long data_offset
    cdef tuple shape
    cdef object dtype

    # The persistent buffer for pre-fetching.
    cdef char *buf
    cdef np.npy_intp n_read

    # A field for the pre-fetch thread object.
    cdef object thread

    def __init__(self, fname, batch_rows):
        with open(fname, 'rb') as f:
            magic = read_magic(f)
            if magic != (1, 0):
                raise ValueError("unsupported NumPy format version %s" %
                                 str(magic))
            shape, fortran, dtype = read_array_header_1_0(f)
            self.data_offset = f.tell()
        self.fname = fname
        if fortran:
            raise ValueError("only C-ordered serialized arrays are "
                             "supported")
        if not dtype.isnative or not dtype.isbuiltin:
            raise ValueError("only builtin dtypes with native byte "
                             "orders are supported")
        self.shape = shape
        self.dtype = dtype
        # Add (1,) just in case shape[1:] is ().
        self.row_bytes = (reduce(lambda a, b: a * b, (1,) + shape[1:]) *
                          self.dtype.itemsize)
        self.batch_rows = batch_rows
        self.fp = NULL
        self.buf = <char *>malloc(self.batch_rows * self.row_bytes *
                                  sizeof(char))
        if not self.buf:
            raise MemoryError()
    def pre_fetch(self):
        if not self.fp:
            self._setup_fp()
        self.thread = threading.Thread(target=self._do_pre_fetch)
        self.thread.start()

    def _do_pre_fetch(self):
        # print "Pre-fetching..."
        with nogil:
            self.n_read = fetch_loop(self.data_offset, self.buf,
                                     self.row_bytes, self.batch_rows,
                                     self.fp)
        # print "Done pre-fetching."

    def _setup_fp(self):
        self.fp = fopen(self.fname, 'rb')
        if not self.fp:
            raise IOError("Couldn't open file %s" % self.fname)
        if fseek(self.fp, self.data_offset, SEEK_SET):
            global errno, sys_errlist
            raise IOError(sys_errlist[errno])
    def next_batch(self):
        cdef int i
        cdef np.ndarray arr, return_arr
        cdef int ndim = len(self.shape)
        cdef np.npy_intp *dims
        if self.thread is None:
            self.pre_fetch()
        try:
            dims = <np.npy_intp *>malloc(len(self.shape) *
                                         sizeof(np.npy_intp))
            if not dims:
                raise MemoryError()
            self.thread.join()
            dims[0] = self.n_read
            for i in range(1, ndim):
                dims[i] = self.shape[i]
            arr = np.PyArray_SimpleNewFromData(ndim, dims, self.dtype.num,
                                               self.buf)
            # Copy the data out before starting the next pre-fetch:
            # `arr` merely wraps `self.buf`, which the background
            # thread is about to overwrite.
            return_arr = arr.copy()
            self.pre_fetch()
            return return_arr
        finally:
            free(dims)

    def __dealloc__(self):
        # Extension types must use __dealloc__, not __del__,
        # for C-level cleanup.
        free(self.buf)
        if self.fp:
            fclose(self.fp)
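
To try the reader out, compile the module first (e.g. with Cython's cythonize, passing numpy.get_include() to the compiler). A minimal usage sketch, assuming the code above is built as a hypothetical npy_reader extension module:

    import numpy as np
    from npy_reader import SequentialBatchReader  # hypothetical module name

    np.save('data.npy', np.arange(100, dtype=np.float64).reshape(20, 5))
    reader = SequentialBatchReader('data.npy', batch_rows=8)
    for _ in range(4):
        batch = reader.next_batch()  # joins the pre-fetch thread
        print(batch.shape)           # (8, 5), (8, 5), (4, 5), then (8, 5)

Each call to next_batch() joins the background thread, copies the filled buffer into a fresh array, and immediately kicks off the next pre-fetch, so disk I/O overlaps with whatever computation happens between calls.
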
@larsmans

The base_offset should be a long, not a size_t; that's what fseek expects.
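
For reference, the C declaration is:

    int fseek(FILE *stream, long offset, int whence);

so a size_t offset gets converted implicitly (and silently) at the call site.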

@dwf (Author) commented Jun 27, 2013

@larsmans Oops, you're right. I'd even declared data_offset as long; it was somehow just being silently downcast through a size_t.
