#!/usr/bin/python3


'''Create a video of the predictive distribution of a Bayesian polynomial
regression model as noisy samples of sin(x) are added frame by frame

'''

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Shared random generator with a fixed seed (0), so every run draws the same
# sequence of synthetic observations
RSTATE = np.random.RandomState(0)


def get_data(count=10, sigma=0.3):
    '''Draw noisy samples of sin(x) at random points in [-2*pi, 2*pi)

    @param {int} count number of points returned
    @param {float} sigma standard deviation of the additive Gaussian noise
    @return {np.1darray} observation points
            {np.1darray} perturbed observation values

    '''
    half_width = 2*np.pi
    # Draw the locations first, then the noise (keeps the RNG draw order)
    points = RSTATE.uniform(-half_width, half_width, count)
    noise = RSTATE.normal(0, sigma, len(points))
    return points, np.sin(points) + noise

def design_matrix(x, M):
    r'''Generate the \Phi design matrix of polynomial features

    Row i holds the polynomial basis evaluated at x[i], i.e. the powers
    x[i]**0, x[i]**1, ..., x[i]**M (the first column is the bias term).

    @param {np.1darray} x observation points
    @param {int} M basis order (polynomial degree)
    @return {np.2darray} Phi design matrix, shape (len(x), M+1)

    '''
    # Cast up front so integer inputs cannot overflow when raised to powers
    x = np.asarray(x, dtype=np.float64)
    # np.vander with increasing=True builds exactly the [x**0 ... x**M] rows
    return np.vander(x, M + 1, increasing=True)

def gram_matrix(x, M):
    r'''Return the Gram matrix \Phi^{T} \Phi together with \Phi itself

    @param {np.1darray} x observation points
    @param {int} M basis order (polynomial degree)
    @return {np.2darray} Phi^{T}*Phi matrix, shape (M+1, M+1)
            {np.2darray} Phi design matrix, shape (len(x), M+1)

    '''
    phi = design_matrix(x, M)
    return phi.T.dot(phi), phi

def basis(x, M):
    '''Evaluate the polynomial basis functions of degree 0..M at `x`

    @param {float} x location at which to evaluate the basis functions
    @param {int} M highest polynomial degree
    @return {np.1darray} the M+1 values x**0, x**1, ..., x**M

    '''
    # Entry 0 is the bias term x**0 == 1
    n_terms = M + 1
    powers = np.empty(n_terms, dtype=np.float64)
    powers[:] = [x ** k for k in range(n_terms)]
    return powers

def train(x, t, M, beta, alpha):
    '''Return mean and covariance of the *posterior* weight distribution

    The prior is p(w) = N(w | 0, alpha^{-1} I) and the observation noise is
    Gaussian with precision beta. The posterior is N(w | mN, Sn) with

        Sn^{-1} = alpha * I + beta * Phi^T Phi
        mN      = beta * Sn * Phi^T t

    @param {np.1darray} x observation points
    @param {np.1darray} t observation values
    @param {int} M model complexity (polynomial degree)
    @param {float} beta noise precision (inverse of the noise variance)
    @param {float} alpha prior precision of the weights (inverse variance)
    @return {np.1darray} posterior mean mN
            {np.2darray} posterior covariance Sn, shape (M+1, M+1)

    '''
    gram, phi = gram_matrix(x, M)
    precision = alpha * np.eye(M + 1) + beta * gram
    # The covariance itself is returned, so an explicit inverse is required
    Sn = np.linalg.inv(precision)
    # For the mean, solving the linear system is numerically better behaved
    # than multiplying by the explicit inverse
    mN = beta * np.linalg.solve(precision, phi.T.dot(t))
    return mN, Sn

def predict(x, M, mu):
    '''Evaluate the polynomial model with weight vector `mu` at each input

    @param {iterable} x locations to predict values
    @param {int} M model complexity (polynomial degree)
    @param {np.1darray} mu weight vector, e.g. the posterior mean from train()
    @return {np.1darray} model prediction <mu, basis(x_i)> for each x_i

    '''
    return np.array([np.inner(mu, basis(point, M)) for point in x])

def predict_std(x, M, Sn, beta):
    '''Standard deviation of the predictive distribution at each input

    The predictive variance at x is 1/beta + phi(x)^T Sn phi(x): observation
    noise plus the uncertainty in the weights.

    @param {iterable} x locations to evaluate the standard deviation
    @param {int} M model complexity (polynomial degree)
    @param {np.2darray} Sn posterior weight covariance, e.g. from train()
    @param {float} beta noise precision (the noise variance is 1/beta)
    @return {np.1darray} predictive standard deviation at each x

    '''
    variances = [1.0/beta + phi.dot(Sn.dot(phi))
                 for phi in (basis(point, M) for point in x)]
    return np.sqrt(np.array(variances))

def update_plot(frame, data, x_plot, M, std_offset, beta, alpha, fig_data):
    '''FuncAnimation plot update function

    Adds two new random observations, refits the posterior on all data seen
    so far, and refreshes the scatter, mean and shaded-std artists.

    @param {int} frame frame number being rendered
    @param {np.1darray} data.x observation points. Updated after this call
           {np.1darray} data.t observation values. Updated after this call
    @param {np.1darray} x_plot range of values to eval the model estimate
    @param {int} M model complexity (polynomial degree)
    @param {number} std_offset determines +/- stds to shade mean in plot
    @param {float} beta noise precision (the noise variance is 1/beta)
    @param {float} alpha prior precision of the weights
    @param {matplotlib.axes.Axes} fig_data.ax axes the artists live on
           {matplotlib.fig} fig_data.fig parent figure object
           {matplotlib.Line2D} fig_data.scatter_plot data scatter plot obj
           {matplotlib.Line2D} fig_data.mean_plot model mean plot obj
           {matplotlib.collections.PolyCollection} fig_data.fill_plot shaded
               region of std obj. Replaced after this call
    @return {list} artists modified by this frame, as FuncAnimation requires
            when blit=True

    '''
    print('frame {}'.format(frame))
    # Fetch two new pseudo random observations and grow the data set
    x_new, t_new = get_data(2)
    data['x'] = np.append(data['x'], x_new)
    data['t'] = np.append(data['t'], t_new)

    scatter = fig_data['scatter_plot']
    scatter.set_data(data['x'], data['t'])

    # Refit on all data seen so far and redraw the posterior mean
    mn, sn = train(data['x'], data['t'], M, beta, alpha)
    dist = predict(x_plot, M, mn)
    mean_line = fig_data['mean_plot']
    mean_line.set_data(x_plot, dist)

    # The shaded region is recreated each frame. Remove the old one only once
    # the fit has succeeded, so fig_data never points at a removed artist.
    std = predict_std(x_plot, M, sn, beta)
    fig_data['fill_plot'].remove()
    fill = fig_data['ax'].fill_between(x_plot,
                                       dist - std_offset*std,
                                       dist + std_offset*std,
                                       color='red', alpha=0.35,
                                       label=r'$E \pm {}\sigma$'.format(std_offset))
    fig_data['fill_plot'] = fill
    # With blit=True, only the returned artists are redrawn
    return [scatter, mean_line, fill]


def main():
    '''Build the initial plot, animate 100 update frames and save them as mp4.'''
    # Noise precision (inverse of the observation noise variance)
    BETA = 1
    # Prior precision of the weights (inverse of the prior variance)
    ALPHA = 1
    # Model complexity (polynomial degree)
    M = 10
    # Number of standard deviations shaded around the predictive mean
    STD_OFFSET = 1

    fig, ax = plt.subplots()
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_title(r'$M = {}, \beta = {}, \alpha = {}$'.format(M, BETA, ALPHA))
    x, t = get_data(2)

    x_plot = np.linspace(-2*np.pi, 2*np.pi, 100)

    mn, sn = train(x, t, M, BETA, ALPHA)
    dist = predict(x_plot, M, mn)
    std = predict_std(x_plot, M, sn, BETA)
    data_obj, = ax.plot(x, t,
                        linestyle='', marker='o', markersize=4, label='Data')

    # Static reference curve: the noise-free function the data come from.
    # It is never updated, so no handle is kept.
    ax.plot(x_plot, np.sin(x_plot),
            color='m', linestyle='--', label='Target')

    mean_obj, = ax.plot(x_plot, dist,
                        color='g', linewidth=2, label='Model mean')

    fill_obj = ax.fill_between(x_plot,
                               dist - STD_OFFSET*std,
                               dist + STD_OFFSET*std,
                               color='red', alpha=0.35,
                               label=r'$E \pm {}\sigma$'.format(STD_OFFSET))

    ax.grid(True)
    ax.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='small')
    ax.set_xlim([-2*np.pi, 2*np.pi])
    ax.set_ylim([-1.5, 1.5])

    data = {
        'x': x,
        't': t
    }

    fig_data = {
        'ax': ax,
        'fig': fig,
        'scatter_plot': data_obj,
        'fill_plot': fill_obj,
        'mean_plot': mean_obj
    }

    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)
    c_ani = animation.FuncAnimation(fig, update_plot, 100,
                                    fargs=(data, x_plot, M, STD_OFFSET,
                                           BETA, ALPHA, fig_data),
                                    interval=50, blit=True)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.normpath(
        os.path.join(script_dir, '../video/predictive-distribution.mp4'))
    # Make sure the output directory exists before the writer tries to
    # create the file inside it
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    c_ani.save(fn, writer=writer)

if __name__ == '__main__':
    # Render the video only when executed as a script, not on import
    main()