from micropython import const
from machine import Pin, mem32, idle


# Memory-mapped register addresses and field constants used by this driver.
# Where a register exists in both a high-speed (HS) and a low-speed (LS) bank,
# the LS value is the one selected here; the HS alternative is noted inline.
_DPORT_PERIP_CLK_EN_REG     = const(0x3FF000C0)
_DPORT_PERIP_RST_EN_REG     = const(0x3FF000C4)
_DPORT_LEDC_RST_MASK        = const(1 << 11)  # LEDC bit, applied to both DPORT registers above
_GPIO_FUNC0_OUT_SEL_CFG_REG = const(0x3FF44530)  # entry for GPIO 0; indexed with pin_num * 4
_RTC_CNTL_CLK_CONF_REG      = const(0x3FF48070)
_LEDC_CONF_REG              = const(0x3FF59190)
_LEDC_TIMER0_CONF_REG       = const(0x3FF59160)  # HS:0x3FF59140  LS:0x3FF59160
_LEDC_TIMER1_OFFSET         = const(0x08)  # address stride between consecutive timers
# Per-channel registers, expressed relative to channel 0's CONF0 register.
_LEDC_CH0_CONF0_REG         = const(0x3FF590A0)  # HS:0x3FF59000  LS:0x3FF590A0
_LEDC_CH0_CONF1_REG         = const(_LEDC_CH0_CONF0_REG + 0x0C)
_LEDC_CH0_HPOINT_REG        = const(_LEDC_CH0_CONF0_REG + 0x04)
_LEDC_CH0_DUTY_REG          = const(_LEDC_CH0_CONF0_REG + 0x08)
_LEDC_CH0_DUTY_R_REG        = const(_LEDC_CH0_CONF0_REG + 0x10)
_LEDC_CH1_OFFSET            = const(0x14)  # address stride between consecutive channels
_LEDC_SIG_OUT0_IDX          = const(79)  # HS:71  LS:79
_LEDC_COUNTER_BITS          = const(20)  # upper bound applied to the requested resolution
_LEDC_CHANNELS              = const(8)   # number of channels managed by this module
_LEDC_TIMERS                = const(4)   # number of timers managed by this module
_SIG_GPIO_OUT_IDX           = const(256)  # written to the OUT_SEL register when a pin is released


@micropython.viper
def _bit_length(v: uint) -> int:
    # Count the bits needed to represent v: shift right until nothing is
    # left, counting the shifts. Returns 0 when v is 0.
    bits = 0
    while v:
        v >>= 1
        bits += 1
    return bits


# Enable the LEDC peripheral. This runs once, as a side effect of importing
# the module: set the LEDC bit in the clock-enable register, then pulse the
# same bit in the reset register (set, then clear).
mem32[_DPORT_PERIP_CLK_EN_REG] |= _DPORT_LEDC_RST_MASK
mem32[_DPORT_PERIP_RST_EN_REG] |= _DPORT_LEDC_RST_MASK
mem32[_DPORT_PERIP_RST_EN_REG] &= ~_DPORT_LEDC_RST_MASK
# mem32[_RTC_CNTL_CLK_CONF_REG] |= 1 << 10  # RTC_CNTL_DIG_CLK8M_EN, enable RC_FAST_CLK
mem32[_LEDC_CONF_REG] = 1  # RTC_SLOW_CLK source  0:RC_FAST_CLK 8MHz  1:APB_CLK 80MHz

# Module-wide allocation tables shared by every PWM object.
_chan_gpio  = [-1] * _LEDC_CHANNELS  # pin number bound to each channel, -1 when the channel is free
_timer_freq = [0] * _LEDC_TIMERS     # frequency last assigned to each timer
# Number of PWM objects using each timer. There is one extra slot at index
# _LEDC_TIMERS, which is the index PWM objects hold while they have no timer.
_timer_refs = bytearray(_LEDC_TIMERS + 1)  # timer reference count


class PWM:
    """PWM output on a GPIO pin, driven by a low-speed LEDC channel and timer.

    Channels and timers are allocated from module-wide tables: every
    initialised object owns one channel (``_chan_gpio``), and objects that
    request the same frequency share a timer (``_timer_freq`` /
    ``_timer_refs``).
    """

    def __init__(self, pin, freq = None, duty_u16 = None, phase_u16 = 0, invert = False, bits = None):
        """Create the object; the hardware is touched only if both freq and duty_u16 are given.

        pin       -- anything accepted by ``Pin()``
        freq      -- frequency in Hz, or None to set it later with ``freq()``
        duty_u16  -- duty cycle scaled to 0..0x10000, or None to set it later
        phase_u16 -- 16-bit fraction of the period, scaled into the channel's
                     HPOINT register
        invert    -- truthy to set the output-invert bit of the GPIO matrix
        bits      -- maximum duty resolution in bits (default 16, capped at
                     ``_LEDC_COUNTER_BITS``)
        """

        self._pin      = Pin(pin)
        self._freq     = freq
        self._duty     = duty_u16
        self._phase    = phase_u16
        self._invert   = invert
        self._bits     = min(bits or 16, _LEDC_COUNTER_BITS)
        self._channel  = None
        self._timer    = _LEDC_TIMERS  # index meaning "no timer assigned"
        self._overflow = 0

        if freq is not None and duty_u16 is not None:
            self.init()


    def init(self, freq = None, duty_u16 = None):
        """Allocate (or reset) the channel for this pin and route it to the GPIO.

        ``freq`` and ``duty_u16`` override the stored values when given. If a
        frequency is known afterwards, ``freq()`` is called to program the
        timer (and, through it, the duty).

        Raises ValueError when every channel is in use, and AssertionError
        ("pin locked") when the pin is already bound to a channel that this
        object does not own.
        """

        # NOTE(review): derives the pin number from the distance between two
        # Pin objects in memory; this assumes Pin objects are laid out
        # contiguously, 4 bytes apart, by the port -- confirm for the target
        # firmware version.
        pin_num = (id(self._pin) - id(Pin(0))) >> 2

        # select channel
        if pin_num in _chan_gpio:
            assert self._channel is not None, "pin locked"
            chan = _chan_gpio.index(pin_num)
        else:
            chan = _chan_gpio.index(-1)  # ValueError: no more channel
            _chan_gpio[chan] = pin_num

        # reset channel
        offset_chan = _LEDC_CH1_OFFSET * chan
        mem32[_LEDC_CH0_CONF0_REG + offset_chan] = 0
        mem32[_LEDC_CH0_CONF1_REG + offset_chan] = 0
        mem32[_LEDC_CH0_DUTY_REG  + offset_chan] = 0

        # init pin: route the channel's LEDC signal to the pad. The invert
        # flag must be shifted *left* into bit 9; a right shift always
        # yields 0 and silently drops the inversion.
        self._pin.init(Pin.OUT)
        mem32[_GPIO_FUNC0_OUT_SEL_CFG_REG + pin_num * 4] = (
            ( _LEDC_SIG_OUT0_IDX + chan ) << 0 |  # GPIO_FUNCn_OUT_SEL
            ( 1 if self._invert else 0  ) << 9 )  # GPIO_FUNCn_OUT_INV_SEL

        self._channel = chan
        if freq     is not None: self._freq = freq
        if duty_u16 is not None: self._duty = duty_u16

        if self._freq is not None:
            self.freq(self._freq)


    def freq(self, freq):
        """Set the output frequency in Hz, (re)assigning a timer as needed.

        The channel is initialised first if it has not been yet. A timer
        already set to ``freq`` is shared; otherwise the first timer that no
        other object uses is taken. If a duty cycle is stored it is
        re-applied at the new resolution.

        Raises AssertionError for an out-of-range frequency or a divider
        that does not fit the hardware field, and ValueError when no timer
        is available. The shared timer tables are only modified once all of
        these checks have passed.
        """

        assert 0 <= freq <= 40_000_000, "freq out of range"

        if self._channel is None:
            self._freq = None  # keeps init() from calling back into freq()
            self.init()

        chan = self._channel

        # clock
        sel = 1  # HS 1:APB_CLK 80MHz   LS 1:RTC_SLOW_CLK
        clk = 80_000_000 << 8  # 80MHz + 8bits fraction
        if freq <= clk // (1 << (self._bits + 17)):
            sel = 0  # 0:REF_TICK 1MHz
            clk = 1_000_000 << 8

        # prescale: total division ratio (8 fractional bits), split into a
        # power-of-two counter length (res) and the remaining divider (div)
        div = int(freq and (clk + (freq // 2)) // freq)
        res = min(self._bits, _bit_length(div >> 9))
        div = (div + ((1 << res) >> 1)) >> res
        assert div < (1 << 18), "divider overflow"

        # phase
        ovf = 1 << res
        hpoint = (ovf * self._phase + 0x7fff) // 0x10000

        # select timer -- done after every check above so that a failure
        # cannot leave the reference counts or the frequency table modified
        prev = self._timer
        if freq in _timer_freq:
            # NOTE(review): the shared timer is reprogrammed below with this
            # object's resolution, which also affects the other users.
            timer = _timer_freq.index(freq)
        else:
            for timer in range(_LEDC_TIMERS):
                # free = nobody uses it, or this object is its only user
                held = 1 if timer == prev else 0
                if _timer_refs[timer] == held:
                    break
            else:
                raise ValueError("no more timer")
            _timer_freq[timer] = freq
        if prev < _LEDC_TIMERS:  # skip the "no timer" index
            _timer_refs[prev] -= 1
        _timer_refs[timer] += 1
        self._timer = timer

        mem32[_LEDC_TIMER0_CONF_REG + timer * _LEDC_TIMER1_OFFSET] = (
            res   << 0  |  # LEDC_LSTIMER_DUTY_RES
            div   << 5  |  # LEDC_DIV_NUM_LSTIMER
            0     << 24 |  # LEDC_LSTIMER_RST
            sel   << 25 |  # LEDC_TICK_SEL_LSTIMER
            1     << 26 )  # LEDC_PARA_UP_LSCH
        mem32[_LEDC_CH0_CONF0_REG + chan * _LEDC_CH1_OFFSET] = (
            timer << 0  |  # LEDC_TIMER_SEL_LSCH
            1     << 2  |  # LEDC_SIG_OUT_EN_LSCH - enable output
            0     << 3  )  # LEDC_IDLE_LV_LSCH - level timer paused
        mem32[_LEDC_CH0_HPOINT_REG + chan * _LEDC_CH1_OFFSET] = (
            hpoint      )  # LEDC_HPOINT_LSCH

        self._freq = freq
        self._overflow = ovf

        if self._duty is not None:
            self.duty_u16(self._duty)


    def duty_u16(self, duty_u16, ramp = 0):
        """Set the duty cycle, scaled to 0..0x10000, optionally as a hardware ramp.

        The channel is initialised first if needed, which requires a
        frequency to have been set. With ``ramp`` == 0 the new duty is
        written directly. Otherwise the channel's step fields are programmed
        so the duty moves from its current value toward the target; the step
        size is chosen so that a full-scale change spans about ``ramp``
        counter periods (NOTE(review): derived from the scale/cycle
        arithmetic below -- confirm against the LEDC register description).

        Blocks (calling ``idle()``) while a previous update is still flagged
        as in progress. Raises AssertionError for an out-of-range duty, a
        missing frequency, or a ramp step that does not fit its field.
        """

        assert 0 <= duty_u16 <= 0x10000, "duty out of range"

        if self._channel is None:
            assert self._freq, "freq not set"
            self._duty = None  # keeps freq() from calling back into duty_u16()
            self.init()

        duty_wr = (self._overflow * duty_u16 + 0x7fff) // 0x10000
        conf1 = 1 << 31  # LEDC_DUTY_START_LSCH

        # wait end of previous ramping, LEDC_DUTY_START_LSCH == 0
        offset_chan = self._channel * _LEDC_CH1_OFFSET
        while mem32[_LEDC_CH0_CONF1_REG + offset_chan] & (1 << 31):
            idle()

        if ramp:
            duty_rd = mem32[_LEDC_CH0_DUTY_R_REG + offset_chan] >> 4
            cycle   = 1 + (ramp >> 9)  # multiple of 10-1 bits
            scale   = (self._overflow * cycle // ramp) or 1
            num     = abs(duty_wr - duty_rd) // scale
            inc     = duty_wr > duty_rd

            assert scale < 0x400, "step overflow"
            conf1   |= scale << 0 | cycle << 10 | num << 20 | inc << 30
            # start value: the target minus the whole number of steps, so
            # that num steps of size scale end exactly on the target
            duty_wr -= scale * (num if inc else -num)

        mem32[_LEDC_CH0_DUTY_REG  + offset_chan] = duty_wr << 4  # LEDC_DUTY_LSCH
        mem32[_LEDC_CH0_CONF1_REG + offset_chan] = conf1
        mem32[_LEDC_CH0_CONF0_REG + offset_chan] |= 1 << 4  # LEDC_PARA_UP_LSCH

        self._duty = duty_u16


    def deinit(self):
        """Release the channel and timer and hand the pin back to plain GPIO output.

        Does nothing if the object holds no channel.
        """

        chan = self._channel
        if chan is not None:

            pin_num = _chan_gpio[chan]

            mem32[_GPIO_FUNC0_OUT_SEL_CFG_REG + pin_num * 4] = _SIG_GPIO_OUT_IDX
            mem32[_LEDC_CH0_CONF0_REG + chan * _LEDC_CH1_OFFSET] = 0
            mem32[_LEDC_CH0_CONF1_REG + chan * _LEDC_CH1_OFFSET] = 0

            # only drop a reference if a timer was actually assigned
            if self._timer < _LEDC_TIMERS:
                _timer_refs[self._timer] -= 1
            _chan_gpio[chan] = -1

            self._channel = None
            self._timer   = _LEDC_TIMERS


    def __repr__(self):
        """Describe the pin, frequency, effective duty (percent), resolution, channel and timer."""

        # before a timer is programmed, fall back to the requested resolution
        ovf = self._overflow or (1 << self._bits)
        res = _bit_length(ovf) - 1
        duty = ((self._duty * ovf + 0x7fff) // 0x10000) * 100 / ovf if self._duty else 0
        return "PWM(%s, freq=%s, duty=%.3f%%, resolution=%s, channel=%s, timer=%s)" % (
            self._pin, self._freq, duty, res, self._channel, self._timer)