Skip to content

Instantly share code, notes, and snippets.

@tripulse
Created May 24, 2020 02:02
Show Gist options
  • Save tripulse/c396836567274aa8b937215e551c1382 to your computer and use it in GitHub Desktop.
Save tripulse/c396836567274aa8b937215e551c1382 to your computer and use it in GitHub Desktop.
RKPI2 implementation in pure Python.
"""
A reference implementation of the RKPI2 format, based on the specification.
This module uses NumPy arrays for the input and output of samples.
A trivial example looks like this:
```
>>> import rkpio
>>> import numpy as np
>>> out = rkpio.File('sine.rkp', 'w',
rkpio.Header(rkpio.Format.F32,
samplerate=44100, channels=1))
>>> a = [[np.sin(2 * np.pi * 440 * s/out.samplerate)]
for s in range(out.samplerate)]
>>> out.write(np.array(a))
```
"""
from enum import IntEnum
from functools import partial
from types import SimpleNamespace
from zstandard import ZstdCompressor, ZstdDecompressor
from numpy import ndarray, array, zeros, frombuffer, dtype, interp
_samplerates = [8000,12000,22050,32000,44100,64000,96000,192000]
_formatranges = {
'b': (-128, +127),
'h': (-65536, +65535),
'l': (-2147483648, +2147483647),
'q': (-9223372036854775808, +9223372036854775807),
'f': (-1, +1),
'd': (-1, +1)
}
def _fmtconvert(a: ndarray, dst: dtype):
dst = dtype(dst)
return interp(a, _formatranges[a.dtype.char],
_formatranges[dst.char]) \
.astype(dst)
_isreadable = lambda f: hasattr(f, 'read') or f.read(0) == b''
_iswritable = lambda f: hasattr(f, 'write') or f.write(b'') == 0
class Format(IntEnum):
"""PCM sampleformat to use when encoding to machine
readable binary format.
Variants prefixed with `F` represent floating point
and `S` signed. Integers exceeding it represent amount
of bits required.
"""
S8 = 0
S16 = 1
S32 = 2
S64 = 3
F32 = 4
F64 = 5
class Util:
@staticmethod
def npdtype_to_rkpi2fmt(t: str or dtype) -> Format:
return Format('bhlqfd'.index(dtype(t).char))
@staticmethod
def rkpi2fmt_to_npdtype(t: Format) -> dtype:
return dtype('bhlqfd'[Format(t).value])
class Header(SimpleNamespace):
    """RKPI2 header which carries metadata about the PCM data.

    In serialised form the header is two bytes, most significant bit first:

        byte 0:  1 1 1 1 0 1 C F    identifier 0x3d, compressed flag,
                                    bit 2 of the sampleformat
        byte 1:  F F R R R N N N    bits 1-0 of the sampleformat,
                                    samplerate index, channels - 1
    """

    def __init__(self, format, samplerate, channels, compressed=False):
        """Initialise RKPI2 header fields with given values.

        * format      PCM sampleformat (Format.S8,S16,S32,S64,F32,F64)
        * samplerate  PCM samplerate (8000,12000,22050,32000,44100,64000,96000,192000)
        * channels    number of audio channels [+1..+8]
        * compressed  is compressed with ZSTD? (True,False)

        Raises TypeError when `format` is not a `Format` member and
        ValueError when `samplerate` or `channels` is out of range.
        """
        if not isinstance(format, Format):
            raise TypeError('invalid sampleformat')
        if samplerate not in _samplerates:
            # %r instead of %d: a non-numeric value must still produce the
            # ValueError rather than a TypeError from the formatting itself.
            raise ValueError('invalid samplerate: %r' % (samplerate,))
        if channels not in range(1, 8 + 1):
            raise ValueError('invalid number of channels')

        self.format = format
        self.samplerate = samplerate
        self.channels = channels
        self.compressed = bool(compressed)

    @staticmethod
    def frombytes(b):
        """Deserialize RKPI2 header from interchangable and
        machine-readable form to a normal Python object.

        * b  at least two bytes; only the first two are inspected.

        Raises SyntaxError when the identifier bits are wrong, IndexError
        when `b` holds fewer than two bytes and ValueError when the encoded
        sampleformat is not a `Format` member.
        """
        if (b[0] >> 2) != 0x3d:
            raise SyntaxError('invalid RKPI2 format identifier')

        format_code = ((b[0] & 1) << 2) | ((b[1] >> 6) & 3)
        samplerate_index = (b[1] >> 3) & 7
        channels = (b[1] & 7) + 1
        compressed = (b[0] >> 1) & 1

        return Header(Format(format_code),
                      _samplerates[samplerate_index],
                      channels,
                      compressed)

    def tobytes(self):
        """Serialize RKPI2 header from a normal Python object
        to an interchangable and machine-readable form.

        Returns two bytes.  Raises ValueError when `samplerate` was changed
        to a value that is not a supported samplerate.
        """
        samplerate_index = _samplerates.index(self.samplerate)
        format_code = self.format.value

        # 0xf4 is the identifier 0x3d placed in the upper six bits.  The flag
        # is coerced to bool so that a reassigned attribute cannot spill into
        # the identifier bits.
        first = 0xf4 | (bool(self.compressed) << 1) | (format_code >> 2)
        second = (((format_code & 3) << 6)
                  | (samplerate_index << 3)
                  | (self.channels - 1))
        return bytes([first, second])
class File:
    """
    Wrap or open a binary filestream and offload the process of
    parsing the header and sampledata.

    PCM samples IO is done with `numpy.ndarray` which is a 2D array with one
    row per sample frame and one column per channel.
    Such as: `a = np.array([[0, 1], [2, 3]])`, `a[:,0]` and `a[:,1]`
    are both single channels.  `write` raises SyntaxError when the number of
    columns differs from the number of channels of the header.

    The object can be used as a context manager, which calls `close` on exit.
    """

    def __init__(self, f, mode: str = 'r', hdr: Header = None, complev=11):
        """Initialise underlying operations to do IO with RKPI2 data.

        * f        a file-like object or str or bytes (a path) to read/write data.
        * mode     tells whether to encode or decode ('r'=decode, 'w'=encode);
                   only the first character is inspected.
        * hdr      header fields of RKPI2 (only used while encoding; when
                   decoding the header is read from the stream instead).
        * complev  level of ZSTD compression if it was enabled (+1..+22).

        Raises ValueError for an unknown mode and TypeError when `hdr` is
        missing while encoding or when `f` lacks the required `read`/`write`.
        """
        kind = mode[:1]
        if kind not in ('r', 'w'):
            raise ValueError("invalid mode: %r (expected 'r' or 'w')" % (mode,))
        # Checked before opening so that a bad call does not truncate the file.
        if kind == 'w' and hdr is None:
            raise TypeError('header fields empty when encoding')

        # file hasn't been already opened: open it.
        owned = isinstance(f, (str, bytes))
        raw = open(f, kind + 'b') if owned else f

        try:
            if kind == 'r':
                if not _isreadable(raw):
                    raise TypeError('unreadable binary file')
                hdr = Header.frombytes(raw.read(2))
                fs = (ZstdDecompressor().stream_reader(raw)
                      if hdr.compressed else raw)
            else:
                if not _iswritable(raw):
                    raise TypeError('unwritable binary file')
                raw.write(hdr.tobytes())
                fs = (ZstdCompressor(abs(int(complev))).stream_writer(raw)
                      if hdr.compressed else raw)
        except BaseException:
            # Do not leak a file this constructor opened itself.
            if owned:
                raw.close()
            raise

        self._mode = kind
        self._owned = owned
        self._raw = raw
        self._file = fs

        self.format = hdr.format
        self.samplerate = hdr.samplerate
        self.channels = hdr.channels
        self._bytedepth = Util.rkpi2fmt_to_npdtype(hdr.format).itemsize

    def readable(self):
        "Tell whether the file was opened for decoding."
        return self._mode == 'r'

    def writable(self):
        "Tell whether the file was opened for encoding."
        return self._mode == 'w'

    def _readexactly(self, nbytes):
        "Read `nbytes` bytes, stopping early only when the stream is exhausted."
        chunks = []
        while nbytes > 0:
            chunk = self._file.read(nbytes)
            if not chunk:
                break
            chunks.append(chunk)
            nbytes -= len(chunk)
        return b''.join(chunks)

    def read(self, n=-1, fmt='f') -> ndarray:
        """Read samples from the underlying stream as NumPy array.

        * n    number of sample frames to read; negative or None reads
               everything that is left.
        * fmt  data type of samples (same as NumPy's dtype).

        Returns an array of shape (frames, channels), which has zero rows at
        the end of the stream, or None when the file is not readable.
        Raises ValueError when the data ends in the middle of a sample frame.
        """
        if not self.readable():
            return None

        fmt = dtype(fmt)
        framesize = self._bytedepth * self.channels

        if n is None or n < 0:
            data = self._file.read(-1) or b''
        else:
            data = self._readexactly(n * framesize)

        if len(data) % framesize:
            raise ValueError('uneven number of channels')

        a = frombuffer(data, dtype=Util.rkpi2fmt_to_npdtype(self.format))
        if a.dtype != fmt:
            a = _fmtconvert(a, fmt)
        return a.reshape((len(a) // self.channels, self.channels))

    def blocks(self, blksiz, fmt='f'):
        """Provide an iterator over reading sample data from the input
        as series of arbitrarily sized blocks.

        * blksiz  size of each block in sample frames.
        * fmt     data type of samples (same as NumPy's dtype).

        The last block may be shorter than `blksiz`.  Iteration ends when
        no more samples can be read.
        """
        blksiz = abs(int(blksiz))

        def generate():
            while True:
                block = self.read(blksiz, fmt)
                if block is None or len(block) == 0:
                    return
                yield block

        return generate()

    def write(self, a: ndarray):
        """Write a 2D `numpy.ndarray` to the underlying stream.

        * a  array of shape (frames, channels); samples are converted to
             the sampleformat of the header when their dtype differs.

        Returns whatever the stream's `write` returns, or None when the file
        is not writable.  Raises TypeError when `a` is not an ndarray and
        SyntaxError when its layout does not match the header.
        """
        if not self.writable():
            return None

        if not isinstance(a, ndarray):
            raise TypeError('invalid sequence to write as audio samples')
        if a.ndim != 2:
            raise SyntaxError('invalid layout of samples')
        if a.shape[1] != self.channels:
            raise SyntaxError('invalid number of channels')

        fmt = Util.rkpi2fmt_to_npdtype(self.format)
        if a.dtype != fmt:
            a = _fmtconvert(a, fmt)
        return self._file.write(a.tobytes())

    def writebuf(self, b):
        "Writes raw bytes to the underlying filestream."
        return self._file.write(b)

    def readbuf(self, n=-1):
        "Reads `n` raw bytes from the underlying filestream."
        return self._file.read(n)

    def close(self):
        """Close the ZSTD stream, if any, and the file this object opened.

        A stream that was passed in by the caller is not closed here
        directly.  NOTE(review): depending on the zstandard version, closing
        the ZSTD wrapper may also close the stream it wraps - confirm.
        """
        if self._file is not self._raw:
            self._file.close()
        if self._owned and not self._raw.closed:
            self._raw.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment