#!/usr/bin/python3

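'''Create a video of the weight-parameter probability estimate

Each frame adds noisy samples from the linear target model and redraws the
Gaussian posterior over the weights (w0, w1).
'''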

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation


# Shared generator with a fixed seed, so every run draws the same data
RSTATE = np.random.RandomState(seed=0)

# Target model coefficients [w0, w1] of t = w0 + w1 * x
TARGET = [0, 1]

def get_data(count=10, sigma=0.3):
    '''Draw noisy samples of the linear target model

    Observation points are uniform on [-1, 1]. Each value is
    TARGET[0] + TARGET[1] * x plus zero-mean Gaussian noise. All draws come
    from the shared RSTATE generator.

    @param {int} count number of points returned
    @param {float} sigma standard deviation of the noise
    @return {np.1darray} observation points
            {np.1darray} perturbed observation values
    '''
    points = RSTATE.uniform(-1, 1, count)
    intercept, slope = TARGET[0], TARGET[1]
    noise = RSTATE.normal(0, sigma, len(points))
    values = intercept + slope * points + noise
    return points, values

def gaussianfn(mu, sigma2):
    '''Return a Gaussian probability density function

    A scalar variance gives the univariate density N(x | mu, sigma2).
    Otherwise sigma2 is taken as a D x D covariance matrix, and the returned
    function evaluates the multivariate density
        exp(-0.5 (x-mu)^T Sigma^{-1} (x-mu)) / ((2 pi)^{D/2} |Sigma|^{1/2})

    @param {float|np.1darray} mu mean (scalar, or length-D vector)
    @param {float|np.2darray} sigma2 variance (scalar) or D x D covariance matrix
    @return {function} callable evaluating the density at a point `x`
    '''

    # np.ndim also accepts numpy scalars (np.float32, np.int64, ...), which
    # are not subclasses of the Python float/int types.
    if np.ndim(sigma2) == 0:
        coeff = 1.0/np.sqrt(2*np.pi*sigma2)
        return lambda x: coeff*np.exp(-(x-mu)**2/(2*sigma2))

    sigma2 = np.asarray(sigma2, dtype=np.float64)
    # Dimension is the matrix order; sigma2.size would be D*D
    D = sigma2.shape[0]
    detSigma = np.linalg.det(sigma2)
    # The whole denominator is (2 pi)^{D/2} * |Sigma|^{1/2}
    coeff = 1.0/((2*np.pi)**(D/2.0) * np.sqrt(detSigma))
    sigmaInverse = np.linalg.inv(sigma2)

    def density(x):
        diff = x - mu
        return coeff*np.exp(-0.5*np.inner(diff, sigmaInverse.dot(diff)))

    return density

def basis(x, M):
    '''Evaluate the polynomial basis [1, x, x^2, ..., x^M] at one point

    @param {float} x location to evaluate the basis functions
    @param {int} M degree of the basis functions
    @return {np.1darray} float64 array of length M+1 whose k-th element is
                         x**k (element 0 is the bias term)
    '''
    n_terms = M + 1  # x^0 (bias) through x^M
    powers = np.zeros(n_terms, dtype=np.float64)
    for k in range(n_terms):
        powers[k] = x ** k
    return powers

def phi(x, M):
    '''Build the design matrix \\Phi for polynomial basis order M

    Row n is basis(x[n], M), so the result has shape (len(x), M+1).

    @param {np.1darray} x observation points
    @param {int} M basis order
    @return {np.2darray} Phi matrix
    '''
    design = np.zeros((len(x), M + 1), dtype=np.float64)
    # Each row is a view into `design`, so slice assignment fills it in place
    for row, point in zip(design, x):
        row[:] = basis(point, M)
    return design

def phit_phi(x, M):
    '''Return \\Phi^{t} * \\Phi together with \\Phi itself

    Both matrices are built from the full dataset `x`.

    @param {np.1darray} x observation points
    @param {int} M basis order
    @return {np.2darray} Phi^{t}*Phi matrix, shape (M+1, M+1)
            {np.2darray} Phi matrix, shape (len(x), M+1)
    '''
    design = phi(x, M)
    gram = design.transpose().dot(design)
    return gram, design

def posterior_props(x, t, beta, alpha, M=1):
    '''Return posterior distribution mean and covariance (Gaussian)

    Assumes the isotropic Gaussian prior p(w) = N(w | 0, alpha^{-1} I) and
    Gaussian observation noise with precision beta, which gives
        S_N^{-1} = alpha I + beta Phi^T Phi
        m_N      = beta S_N Phi^T t

    @param {np.1darray} x observation points
    @param {np.1darray} t observation values
    @param {float} beta noise precision (inverse noise variance)
    @param {float} alpha prior precision of the weights (inverse variance)
    @param {int} M polynomial basis order; default 1 (straight line)
    @return {np.1darray} posterior mean m_N, length M+1
            {np.2darray} posterior covariance S_N, shape (M+1, M+1)
    '''
    PhiTPhi, Phi = phit_phi(x, M)
    # The identity must match the number of weights (M+1)
    Sn_inv = alpha * np.eye(M + 1) + beta * PhiTPhi
    Sn = np.linalg.inv(Sn_inv)
    mN = beta * Sn.dot(Phi.transpose().dot(t))
    return mN, Sn

def gaussian_grid(w0, w1, gaussian):
    '''Evaluate a density function over the [w0 x w1] parameter grid

    The result uses the np.meshgrid(w0, w1) / contourf layout: rows index
    w1 and columns index w0.

    @param {np.1darray} w0 range for w0 parameter space
    @param {np.1darray} w1 range for w1 parameter space
    @param {func} gaussian callable taking a point np.array([w0, w1]) and
                  returning a scalar
    @return {np.2darray} array of shape (len(w1), len(w0)) with
                         ret[j, i] = gaussian([w0[i], w1[j]])
    '''
    # Shape must be (len(w1), len(w0)) to match the ret[j, i] indexing below;
    # the transposed shape breaks for non-square grids.
    ret = np.zeros((len(w1), len(w0)), dtype=np.float64)
    for i, a in enumerate(w0):
        for j, b in enumerate(w1):
            ret[j, i] = gaussian(np.array([a, b]))
    return ret

def posterior_grid(x, t, w0, w1, beta, alpha):
    '''Evaluate the weight posterior density over the [w0 x w1] grid

    Fits the Gaussian weight posterior to (x, t) and samples its density
    on the parameter grid.

    @param {np.1darray} x observation points
    @param {np.1darray} t observation values
    @param {np.1darray} w0 range for w0 parameter space
    @param {np.1darray} w1 range for w1 parameter space
    @param {float} beta noise precision
    @param {float} alpha weight prior precision
    @return {np.2darray} posterior density evaluated over [w0 x w1]
            {np.1darray} mean of the weight posterior
    '''
    mean, cov = posterior_props(x, t, beta, alpha)
    density = gaussianfn(mean, cov)
    grid = gaussian_grid(w0, w1, density)
    return grid, mean

def predict_1d(x, w):
    '''Evaluate the straight-line model w[0] + w[1] * x at each location

    @param {iterable} x locations to predict values
    @param {array} w weights; only w[0] (intercept) and w[1] (slope) are used
    @return {np.1darray} predicted values for x
    '''
    return np.array([w[0] + w[1] * point for point in x])

def update_plot(frame, data, W0, W1, beta, alpha, fig_data):
    '''FuncAnimation plot update function

    Adds 5 new random observations to the dataset and recomputes the weight
    posterior. It then refreshes the contour plot, data scatter plot,
    posterior-mean model line and colorbar limits.

    @param {int} frame frame number being rendered
    @param {dict} data 'x': observation points, 't': observation values.
           Both entries are replaced by extended arrays by this call
    @param {np.ndarray} W0 2d parameter space for w0
    @param {np.ndarray} W1 2d parameter space for w1
    @param {float} beta noise precision
    @param {float} alpha weight prior precision
    @param {dict} fig_data plot state, updated in place:
           'ax'           array of subplot axes,
                          [0]: contour plot, [1]: scatter plot
           'cset'         current contour set (replaced by this call)
           'cbar_map'     ScalarMappable backing the colorbar
           'scatter_plot' line object showing the observations
           'model_plot'   line object showing the posterior-mean model
    @return {list} contour artists drawn for this frame
    '''
    # ContourSet is itself an Artist only from Matplotlib 3.8 on
    from matplotlib.artist import Artist

    print('frame {}'.format(frame))

    # Remove existing contour plot to speed up replotting
    old_cset = fig_data['cset']
    if isinstance(old_cset, Artist):
        # Matplotlib >= 3.8: single Artist; ContourSet.collections is
        # deprecated there and removed in 3.10
        old_cset.remove()
    else:
        # Older Matplotlib: one collection per contour level
        for c in old_cset.collections:
            c.remove()

    cplot_ax = fig_data['ax'][0]
    # fetch 5 more random points
    x_, t_ = get_data(5)
    # Append new data to the existing dataset
    data['x'] = np.append(data['x'], x_)
    data['t'] = np.append(data['t'], t_)
    # Assumes np.meshgrid(w0, w1) layout: each row of W0 is the w0 axis,
    # each column of W1 is the w1 axis
    p, w_mean = posterior_grid(data['x'], data['t'], W0[0, :], W1[:, 0], beta, alpha)
    cf = cplot_ax.contourf(W0, W1, p, 50)

    # Update data scatter plot
    fig_data['scatter_plot'].set_data(data['x'], data['t'])

    # Update model plot across the current x-range of the data axes
    ax = fig_data['ax'][1]
    xlim = ax.get_xlim()
    y = predict_1d(xlim, w_mean)
    fig_data['model_plot'].set_data(xlim, y)

    # Update colorbar map -- seems to be the least hacky way of achieving this
    fig_data['cset'] = cf
    fig_data['cbar_map'].set_clim(np.min(p), np.max(p))
    # Hack to force update the colorbar
    plt.draw()
    # With blit=True FuncAnimation requires a sequence of Artists
    return [cf] if isinstance(cf, Artist) else cf.collections

def main():
    '''Build the figure, animate the posterior update and save it as MP4

    Renders 100 frames of update_plot with the ffmpeg writer (15 fps) into
    ../video/parameter-distribution.mp4, relative to this script's directory.
    '''
    BETA = 1
    ALPHA = 1
    # Range where parameter space is plotted
    w0 = np.linspace(-1, 1, 100)
    w1 = np.linspace(0, 1.5, 100)
    W0, W1 = np.meshgrid(w0, w1)

    # Configure the subplots: [0] posterior contour, [1] data and models
    fig, ax_array = plt.subplots(1, 2)
    ax = ax_array[0]
    ax.set_xlabel('$w_0$')
    ax.set_ylabel('$w_1$')
    ax.set_xlim([w0[0], w0[-1]])
    ax.set_ylim([w1[0], w1[-1]])
    ax.set_title(r'$\beta = {}, \alpha = {}$'.format(BETA, ALPHA))

    x, t = get_data(10)
    scatter_ax = ax_array[1]
    scatter_ax.grid(True)
    scatter_ax.set_xlabel('$x$')
    scat_obj, = scatter_ax.plot(x, t, linestyle='', marker='o', markersize=3, label='Data')
    model_obj, = scatter_ax.plot([], [], linewidth=2, label='Model mean')

    x_vals = [-1, 1]
    scatter_ax.plot(x_vals, predict_1d(x_vals, TARGET),
                    color='r',
                    linestyle='--',
                    label='Target')
    # A single legend call, once every labelled line exists
    scatter_ax.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='small')

    # Model data
    data = {'x': x, 't': t}
    p, _ = posterior_grid(x, t, w0, w1, BETA, ALPHA)
    # Draw explicitly on the parameter axes: plt.contourf would use the
    # current axes, which after plt.subplots is the last (data) subplot
    contour_set = ax.contourf(W0, W1, p, 100)

    # Standalone mappable backs the colorbar so update_plot can reset its limits
    cbar_map = plt.cm.ScalarMappable()
    cbar_map.set_clim(np.min(p), np.max(p))
    fig.colorbar(cbar_map, ax=ax, format='%.0e')

    # prevent overlap of axes objects
    plt.tight_layout()

    fig_data = {'fig': fig,
                'ax': ax_array,
                'cbar_map': cbar_map,
                'cset': contour_set,
                'scatter_plot': scat_obj,
                'model_plot': model_obj}

    # Set up formatting for the movie files
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)
    c_ani = animation.FuncAnimation(fig, update_plot, 100,
                                    fargs=(data, W0, W1, BETA, ALPHA, fig_data),
                                    interval=50, blit=True)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.normpath(
        os.path.join(script_dir, '..', 'video', 'parameter-distribution.mp4'))
    # Make sure the output directory exists before ffmpeg writes into it
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    c_ani.save(fn, writer=writer)


if __name__ == "__main__":  # render the video only when run as a script
    main()