"""Adam Langevin Dynamics for optax, by Katherine Crowson."""

from typing import NamedTuple

import jax
import jax.numpy as jnp
import optax


def unwrap_schedule(scalar_or_schedule, count):
    """Unwraps a scalar or schedule into a scalar."""
    if callable(scalar_or_schedule):
        return scalar_or_schedule(count)
    return scalar_or_schedule


def keys_like_tree(key, tree):
    """Generates a tree of random keys with the same structure as `tree`."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    return jax.tree_util.tree_unflatten(treedef, jax.random.split(key, len(leaves)))


def noise_like_tree(key, tree):
    """Generates standard normal noise with the same structure as `tree`."""
    keys = keys_like_tree(key, tree)
    return jax.tree_map(
        lambda x, key: jax.random.normal(key, x.shape, x.dtype), tree, keys
    )


def inverse_schedule(init_value, gamma=1.0, power=1.0):
    """Constructs an inverse decay schedule (for the SGLD convergence
    guarantee).

    Args:
        init_value: the initial value of the schedule.
        gamma: the multiplicative factor applied to the step count.
        power: the power of the inverse decay. The SGLD convergence guarantee
            requires a power in (0.5, 1].
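
    A minimal sketch (the numbers are purely illustrative):

        schedule = inverse_schedule(1e-6, gamma=1.0, power=0.55)
        schedule(0)   # 1e-6
        schedule(99)  # 1e-6 * 100.0 ** -0.55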
    """

    def schedule(count):
        return init_value * (1.0 + count * gamma) ** -power

    return schedule


def make_priors_flax(params, prior_fun=None):
    """Constructs the prior mean and variance trees for a Flax model.

    Args:
        params: the Flax model parameters.
        prior_fun: a function that takes a parameter's path and value and returns
            a tuple of (mean, variance) for the prior distribution of that value.
            If None, uses a prior corresponding to the Flax default initialization.
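
    A minimal usage sketch (`model` and `dummy_input` are placeholders for your
    own Flax module and a correctly shaped input batch):

        variables = model.init(jax.random.PRNGKey(0), dummy_input)
        prior_means, prior_variances = make_priors_flax(variables["params"])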
    """
    import flax

    def default_prior_fun(path, value):
        if path[-1] == "bias":
            return jnp.array(0.0), jnp.array(1.0)
        if path[-1] == "embedding":
            fan_out = value.shape[-1]
            return jnp.array(0.0), jnp.array(1.0 / fan_out)
        if path[-1] == "kernel":
            fan_in = value.shape[-2]
            return jnp.array(0.0), jnp.array(1.0 / fan_in)
        if path[-1] == "scale":
            return jnp.array(1.0), jnp.array(1.0)
        raise ValueError(f'Unknown param type: {"/".join(path)}')

    prior_fun = prior_fun or default_prior_fun

    priors = flax.traverse_util.path_aware_map(prior_fun, params)
    means = flax.core.FrozenDict(
        jax.tree_map(lambda x: x[0], priors, is_leaf=lambda x: isinstance(x, tuple))
    )
    variances = flax.core.FrozenDict(
        jax.tree_map(lambda x: x[1], priors, is_leaf=lambda x: isinstance(x, tuple))
    )
    return means, variances


def prior_potential(tree, priors):
    """Computes the potential of a prior distribution evaluated at a tree of
    parameters."""
    return jax.tree_util.tree_reduce(
        jnp.add,
        jax.tree_map(
            lambda x, m, v: 0.5 * jnp.sum((x - m) ** 2 / v), tree, priors[0], priors[1]
        ),
        0.0,
    )


def prior_sample(key, params, priors):
    """Samples parameters from the prior distribution."""
    noise = noise_like_tree(key, params)
    return jax.tree_map(
        lambda x, m, v: x * jnp.sqrt(v) + m, noise, priors[0], priors[1]
    )


class SGLDState(NamedTuple):
    count: jax.Array
    key: jax.Array


def sgld(learning_rate, key, priors, tau=1.0):
    """Stochastic Gradient Langevin Dynamics.

    The gradient provided to the update function should be a stochastic estimate
    of the gradient of the negative log likelihood summed over the full training
    set. (That is, the minibatch gradient should already be divided by the batch
    size and multiplied by the number of samples in the training set.)

    Args:
        learning_rate: a scalar or a function that maps the step count to a scalar.
        key: a JAX PRNG key.
        priors: a tuple of two trees with the same structure as the parameters,
            containing the means and variances of the diagonal normal prior.
        tau: the temperature of the Langevin dynamics.
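
    A minimal training-step sketch (`loss_fn`, `batch`, `num_train`, `params`,
    and `priors` are placeholders and the step size is purely illustrative;
    `loss_fn` returns the mean negative log likelihood over the minibatch):

        opt = sgld(inverse_schedule(1e-6), jax.random.PRNGKey(0), priors)
        opt_state = opt.init(params)

        grads = jax.grad(loss_fn)(params, batch)
        # Rescale the minibatch mean gradient to estimate the full-dataset sum.
        grads = jax.tree_map(lambda g: g * num_train, grads)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)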
    """

    prior_means, prior_variances = priors

    def init_fn(params):
        return SGLDState(count=jnp.zeros([], jnp.int32), key=key)

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError("No params provided to update_fn.")

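        # Gradient of the negative log posterior: the Gaussian prior term plus the
        # stochastic likelihood gradient passed in as `updates`.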
        posterior_grads = jax.tree_map(
            lambda x, m, v, u: (x - m) / v + u,
            params,
            prior_means,
            prior_variances,
            updates,
        )

        key, subkey = jax.random.split(state.key)
        noise = noise_like_tree(subkey, params)

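        # Langevin step: gradient descent with step size lr plus Gaussian noise
        # with variance 2 * lr * tau.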
        lr = unwrap_schedule(learning_rate, state.count)
        updates = jax.tree_map(
            lambda g, n: -lr * g + jnp.sqrt(2 * lr * tau) * n,
            posterior_grads,
            noise,
        )

        state = SGLDState(count=state.count + 1, key=key)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)


class AdamLDState(NamedTuple):
    count: jax.Array
    key: jax.Array
    m: optax.Updates
    v: optax.Updates


def adamld(learning_rate, key, priors, tau=1.0, b1=0.9, b2=0.99, eps=1e-8):
    """Adam Langevin Dynamics, by Katherine Crowson.

    The gradient provided to the update function should be a stochastic estimate
    of the gradient of the negative log likelihood summed over the full training
    set. (That is, the minibatch gradient should already be divided by the batch
    size and multiplied by the number of samples in the training set.)

    Args:
        learning_rate: a scalar or a function that maps the step count to a scalar.
        key: a JAX PRNG key.
        priors: a tuple of two trees with the same structure as the parameters,
            containing the means and variances of the diagonal normal prior.
        tau: the temperature of the Langevin dynamics.
        b1: the exponential decay rate for the first moment estimate.
        b2: the exponential decay rate for the second moment estimate.
        eps: a small constant for numerical stability.
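
    A minimal construction sketch (`params` is a placeholder for your Flax model
    parameters and the step size is purely illustrative; gradients passed to
    `update` follow the same scaling convention as in `sgld` above):

        priors = make_priors_flax(params)
        opt = adamld(1e-3, jax.random.PRNGKey(0), priors, tau=1.0)
        opt_state = opt.init(params)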
    """

    prior_means, prior_variances = priors

    def init_fn(params):
        m = jax.tree_map(jnp.zeros_like, params)
        v = jax.tree_map(jnp.zeros_like, params)
        return AdamLDState(
            count=jnp.zeros([], jnp.int32),
            key=key,
            m=m,
            v=v,
        )

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError("No params provided to update_fn.")

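        # Adam-style exponential moving averages of the likelihood gradient and
        # its elementwise square, with the usual bias correction.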
        m = jax.tree_map(lambda x, u: b1 * x + (1 - b1) * u, state.m, updates)
        v = jax.tree_map(
            lambda x, u: b2 * x + (1 - b2) * u * u.conj(), state.v, updates
        )
        m_hat = jax.tree_map(lambda x: x / (1 - b1 ** (state.count + 1)), m)
        v_hat = jax.tree_map(lambda x: x / (1 - b2 ** (state.count + 1)), v)

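        # Gradient of the negative log posterior: the Gaussian prior term plus the
        # bias-corrected first moment of the likelihood gradient.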
        posterior_grads = jax.tree_map(
            lambda x, m, v, u: (x - m) / v + u,
            params,
            prior_means,
            prior_variances,
            m_hat,
        )
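        # Adam-style preconditioner. The prior variance term makes the preconditioner
        # approach the prior variance when the gradient second moment is small.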
        precond = jax.tree_map(
            lambda v, pv: 1 / (jnp.sqrt(v + pv**-2.0) + eps), v_hat, prior_variances
        )

        key, subkey = jax.random.split(state.key)
        noise = noise_like_tree(subkey, params)

        lr = unwrap_schedule(learning_rate, state.count)
        updates = jax.tree_map(
            lambda p, g, n: -lr * p * g + jnp.sqrt(jnp.abs(2 * lr * p * tau)) * n,
            precond,
            posterior_grads,
            noise,
        )

        state = AdamLDState(count=state.count + 1, key=key, m=m, v=v)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)