"""
A reference implementation of the RKPI2 format based on its specification.
This module uses NumPy arrays for the input and output of samples.

A trivial example looks like this:
```
>>> import rkpio
>>> import numpy as np
>>> out = rkpio.File('sine.rkp', 'w',
            rkpio.Header(rkpio.Format.F32,
                samplerate=44100, channels=1))
>>> a = [[np.sin(2 * np.pi * 440 * s/out.samplerate)] 
            for s in range(out.samplerate)]
>>> out.write(np.array(a))
```
"""
from enum      import IntEnum
from functools import partial
from types     import SimpleNamespace
from zstandard import ZstdCompressor, ZstdDecompressor
from numpy     import ndarray, array, zeros, frombuffer, dtype, interp

_samplerates  = [8000,12000,22050,32000,44100,64000,96000,192000]


def _dtyperange(t) -> tuple:
    """Return the nominal ``(low, high)`` full-scale range of sample dtype `t`.

    Integer samples span their whole representable range; floating point
    samples are normalised to [-1, +1].  The bounds are derived from the
    dtype's kind and width rather than its type character, because codes
    such as 'l' (C long) differ in width between platforms.

    Raises ValueError for dtypes that cannot hold PCM samples.
    """
    t = dtype(t)
    bits = 8 * t.itemsize
    if t.kind == 'f':
        return (-1, +1)
    if t.kind == 'i':
        return (-(1 << (bits - 1)), (1 << (bits - 1)) - 1)
    if t.kind == 'u':
        return (0, (1 << bits) - 1)
    raise ValueError('unsupported sample dtype: %s' % t)


# Nominal sample range per NumPy type character.  Computed from the dtype so
# platform-dependent codes ('l' is 32-bit on Windows, 64-bit elsewhere) get
# the correct bounds.
_formatranges = {c: _dtyperange(c) for c in 'bhilqfd'}


def _fmtconvert(a: ndarray, dst: dtype) -> ndarray:
    """Linearly rescale samples `a` from the range of its own dtype to the
    range of `dst`, then cast the result to `dst`.

    Values outside the source range are clamped to the destination bounds
    (``numpy.interp`` semantics).
    """
    dst = dtype(dst)
    return interp(a, _dtyperange(a.dtype), _dtyperange(dst)).astype(dst)


def _isreadable(f) -> bool:
    """Return True if `f` behaves like a readable *binary* stream.

    A zero-length read consumes nothing; text streams return '' rather
    than b'' and are therefore rejected.
    """
    try:
        return f.read(0) == b''
    except (AttributeError, TypeError, OSError, ValueError):
        # No read(), incompatible signature, write-only or closed stream.
        return False


def _iswritable(f) -> bool:
    """Return True if `f` accepts bytes through write().

    Writes an empty bytes object, which emits no data.  Writers that return
    None instead of a byte count are accepted.
    """
    try:
        return f.write(b'') in (0, None)
    except (AttributeError, TypeError, OSError, ValueError):
        # No write(), text-only, read-only or closed stream.
        return False


class Format(IntEnum):
    """PCM sample format used when encoding samples to the
    machine-readable binary form.

    Members prefixed with `F` are floating point and those prefixed
    with `S` are signed integers; the number that follows is the
    sample width in bits.
    """
    S8  = 0
    S16 = 1
    S32 = 2
    S64 = 3
    F32 = 4
    F64 = 5


# NumPy dtype name for each Format, indexed by the Format value.  Explicit
# widths are used because the C-type code 'l' is 64 bits wide on
# Linux/macOS and only 32 bits on Windows.
_fmtdtypes = ('int8', 'int16', 'int32', 'int64', 'float32', 'float64')


class Util:
    """Conversions between NumPy dtypes and RKPI2 sample formats."""

    @staticmethod
    def npdtype_to_rkpi2fmt(t: 'str | dtype') -> Format:
        """Return the Format whose samples have the same kind and width as
        dtype `t` (byte order is ignored).

        Raises ValueError if no RKPI2 format matches.
        """
        t = dtype(t)
        for fmt in Format:
            cand = dtype(_fmtdtypes[fmt])
            if (cand.kind, cand.itemsize) == (t.kind, t.itemsize):
                return fmt
        raise ValueError('no RKPI2 sample format for dtype %s' % t)

    @staticmethod
    def rkpi2fmt_to_npdtype(t: Format) -> dtype:
        """Return the native-byte-order NumPy dtype for Format `t`.

        Raises ValueError if `t` is not a valid Format value.
        """
        return dtype(_fmtdtypes[Format(t)])

class Header(SimpleNamespace):
    "RKPI2 header which carries metadata about the PCM data."
    def __init__(self, format, samplerate, channels, compressed=False):
        """Initialise RKPI2 header fields with given values.

        * format      PCM sample format, a `Format` member (S8,S16,S32,S64,F32,F64)
        * samplerate  PCM samplerate (8000,12000,22050,32000,44100,64000,96000,192000)
        * channels    number of audio channels, 1..8 inclusive
        * compressed  whether the sample data is ZSTD-compressed (True,False)

        Raises TypeError if `format` is not a `Format` member, and
        ValueError for an unsupported samplerate or channel count.
        """
        if not isinstance(format, Format):
            raise TypeError('invalid sampleformat')
        self.format = format

        if samplerate not in _samplerates:
            # %s rather than %d so non-integer input still yields ValueError.
            raise ValueError('invalid samplerate: %s' % (samplerate,))
        self.samplerate = samplerate

        if channels not in range(1, 8+1):
            raise ValueError('invalid number of channels')
        self.channels = channels
        self.compressed = bool(compressed)

    @staticmethod
    def frombytes(b):
        """Deserialize an RKPI2 header from its 2-byte interchangeable,
        machine-readable form into a normal Python object.

        Bit layout, most significant bit first:
          byte 0: 111101 | compressed | format bit 2
          byte 1: format bits 1-0 | samplerate index (3) | channels - 1 (3)

        Raises SyntaxError if `b` is shorter than 2 bytes or lacks the
        RKPI2 identifier, and ValueError for an unknown format code.
        """
        if len(b) < 2:
            raise SyntaxError('truncated RKPI2 header')
        if b[0] >> 2 != 0x3d:
            raise SyntaxError('invalid RKPI2 format identifier')

        return Header(
            Format(((b[0] & 1) << 2) | ((b[1] >> 6) & 3)),
            _samplerates[(b[1] >> 3) & 7],
            (b[1] & 7) + 1,
            (b[0] >> 1) & 1)

    def tobytes(self):
        """Serialize the header to its 2-byte interchangeable,
        machine-readable form (the bit layout described in `frombytes`)."""
        samplerate_index = _samplerates.index(self.samplerate)
        fmt = self.format.value

        return bytes([
            0xf4 | (self.compressed << 1) | (fmt >> 2),
            ((fmt & 3) << 6) | (samplerate_index << 3) | (self.channels - 1),
        ])

class File:
    """
    Wrap or open a binary filestream and offload the process of
    parsing the header and sampledata.

    PCM samples IO is done with `numpy.ndarray`, a 2D array of shape
    (frames, channels).  For example, with `a = np.array([[0, 1], [2, 3]])`,
    `a[:,0]` and `a[:,1]` are both single channels.  Up to 8 channels
    (`a[:,7]`) are supported; the header rejects more with a ValueError.

    Call `close()`, or use the object in a `with` statement, when done.
    A ZSTD-compressed stream is only finalised on close, and a file that
    this object opened from a path is only closed at that point.
    """
    def __init__(self, f, mode: str='r', hdr: Header=None, complev=11):
        """Initialise underlying operations to do IO with RKPI2 data.

        * f        a binary file-like object, or a path (str, bytes or
                   path-like) to open.
        * mode     'r' to decode or 'w' to encode (only the first character
                   is inspected).
        * hdr      header fields of RKPI2 (required, and only used, when
                   encoding; when decoding the header is read from `f`).
        * complev  level of ZSTD compression if it was enabled (+1..+22).

        Raises ValueError for an unsupported mode or a stream that cannot
        be read/written as binary, TypeError if `hdr` is missing when
        encoding, and whatever `Header.frombytes` raises for a bad header.
        """
        kind = mode[:1]
        if kind not in ('r', 'w'):
            raise ValueError('invalid mode: %r' % (mode,))
        if kind == 'w' and hdr is None:
            # Checked before opening so a path is not truncated for nothing.
            raise TypeError('header fields empty when encoding')

        self._mode = kind
        self._closed = False
        # Only a stream opened here is closed by close(); caller-supplied
        # streams are left open.
        self._owned = isinstance(f, (str, bytes)) or hasattr(f, '__fspath__')
        self._raw = open(f, kind + 'b') if self._owned else f

        fs = self._raw
        try:
            if kind == 'r':
                if not _isreadable(fs):
                    raise ValueError('unreadable binary file')
                hdr = Header.frombytes(fs.read(2))
                if hdr.compressed:
                    fs = ZstdDecompressor().stream_reader(fs)
            else:
                if not _iswritable(fs):
                    raise ValueError('unwritable binary file')
                fs.write(hdr.tobytes())
                if hdr.compressed:
                    fs = ZstdCompressor(abs(int(complev))).stream_writer(fs)
        except BaseException:
            # Do not leak a file we opened ourselves.
            if self._owned:
                self._raw.close()
            raise

        self._file       = fs
        self._compressed = hdr.compressed
        self.format      = hdr.format
        self.samplerate  = hdr.samplerate
        self.channels    = hdr.channels
        self._bytedepth  = Util.rkpi2fmt_to_npdtype(hdr.format).itemsize

    def readable(self) -> bool:
        "Return True if this object decodes samples (mode 'r')."
        return self._mode == 'r'

    def writable(self) -> bool:
        "Return True if this object encodes samples (mode 'w')."
        return self._mode == 'w'

    def _checkopen(self):
        "Raise ValueError if close() has already been called."
        if self._closed:
            raise ValueError('I/O operation on closed file')

    def _readfull(self, size):
        """Read `size` bytes, or fewer only at end of stream; a negative
        `size` reads everything that remains.  Short reads, which some
        streams may return before EOF, are retried."""
        chunks = []
        remaining = size
        while remaining:
            chunk = self._file.read(remaining if remaining > 0 else 1 << 16)
            if not chunk:
                break
            chunks.append(chunk)
            if remaining > 0:
                remaining -= len(chunk)
        return b''.join(chunks)

    def read(self, n=-1, fmt='f') -> ndarray:
        """Read samples from the underlying stream as NumPy array.

        * n    number of samples (frames) to read; negative reads all
               that remain.
        * fmt  data type of samples (same as NumPy's dtype).

        Returns an array of shape (frames, channels), with zero rows at
        end of stream, or None if this object was opened for writing.
        Raises ValueError if the data ends part-way through a frame.
        """
        if not self.readable():
            return
        self._checkopen()

        fmt = dtype(fmt)
        framesize = self._bytedepth * self.channels
        buf = self._readfull(n * framesize if n >= 0 else -1)
        if len(buf) % framesize:
            raise ValueError('uneven number of channels')

        a = frombuffer(buf, dtype=Util.rkpi2fmt_to_npdtype(self.format))
        if a.dtype != fmt:
            a = _fmtconvert(a, fmt)
        return a.reshape((-1, self.channels))

    def blocks(self, blksiz, fmt='f'):
        """Provide an iterator over the remaining sample data as a series
        of blocks.

        * blksiz  size of each block in samples (frames); the last block
                  may be shorter.
        * fmt     data type of samples (same as NumPy's dtype).

        Iteration stops at end of stream.  Nothing is yielded if `blksiz`
        is 0 or this object was opened for writing.
        """
        blksiz = abs(int(blksiz))

        def _iterblocks():
            if not blksiz:
                return
            while True:
                blk = self.read(blksiz, fmt)
                if blk is None or len(blk) == 0:
                    return
                yield blk

        return _iterblocks()

    def write(self, a: ndarray):
        """Write a 2D `numpy.ndarray` of shape (frames, channels) to the
        underlying stream, rescaling it to the header's sample format when
        its dtype differs.

        Returns the underlying stream's write() result, or None if this
        object was opened for reading.  Raises TypeError for a non-array
        and SyntaxError for a wrong layout or channel count.
        """
        if not self.writable():
            return
        self._checkopen()

        fmt = Util.rkpi2fmt_to_npdtype(self.format)

        if not isinstance(a, ndarray):
            raise TypeError('invalid sequence to write as audio samples')
        if a.ndim != 2:
            raise SyntaxError('invalid layout of samples')
        if a.shape[1] != self.channels:
            raise SyntaxError('invalid number of channels')

        # Only data in a different dtype needs rescaling; matching data is
        # written verbatim.
        if a.dtype != fmt:
            a = _fmtconvert(a, fmt)

        return self._file.write(a.tobytes())

    def writebuf(self, b):
        "Writes raw bytes to the underlying filestream."
        return self._file.write(b)

    def readbuf(self, n=-1):
        "Reads `n` raw bytes from the underlying filestream."
        return self._file.read(n)

    def close(self):
        """Finish the stream.  When encoding, this ends the ZSTD frame (if
        compressing) and flushes.  It then closes the underlying file if this
        object opened it.  Calling it again has no effect."""
        if self._closed:
            return
        self._closed = True
        try:
            if self._mode == 'w':
                if self._compressed:
                    # End the frame without closing the compression writer,
                    # whose close() may also close the wrapped stream.
                    from zstandard import FLUSH_FRAME
                    self._file.flush(FLUSH_FRAME)
                if hasattr(self._raw, 'flush'):
                    self._raw.flush()
        finally:
            if self._owned:
                self._raw.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()