#!/usr/bin/python3

'''TODO: Revisit Part 1 and redefine 'M' — here M is the polynomial degree,
but the basis includes a bias term, so the number of weights is M+1.

'''


import os
import numpy as np
import matplotlib.pyplot as plt


def get_data(count = 10, sigma = 0.3):
    '''Sample noisy observations of the sine model

    Points are evenly spaced on [-pi, pi]; the noise is Gaussian with a
    fixed seed so repeated calls return identical data.

    @param {int} count number of points returned
    @param {float} sigma standard deviation of the noise
    @return {np.1darray} observation points
            {np.1darray} perturbed observation values

    '''
    rng = np.random.RandomState(0)
    xs = np.linspace(-np.pi, np.pi, count)
    noise = rng.normal(0, sigma, count)
    return xs, np.sin(xs) + noise

def basis(x, M):
    '''Polynomial basis function of M-th degree

    @param {float} x location to evaluate the basis functions
    @param {int} M degree of the basis functions
    @return {np.1darray} basis function values at `x`

    '''
    # Powers x**0 .. x**M; the 0-th power supplies the bias term.
    # Casting x to float64 keeps the returned dtype identical for int input.
    return np.float64(x) ** np.arange(M + 1)

def design_matrix(x, M):
    r'''Generate \Phi matrix

    Row i holds the polynomial basis evaluated at x[i]:
    [1, x_i, x_i**2, ..., x_i**M] (bias term included).

    @param {np.1darray} x observation points
    @param {int} M basis order
    @return {np.2darray} Phi matrix

    '''
    # Vandermonde matrix with increasing powers == row-wise basis expansion.
    return np.vander(np.asarray(x, dtype=np.float64), M + 1, increasing=True)

def gram_matrix(x, M):
    r'''Return \Phi^T \Phi matrix together with the design matrix

    @param {np.1darray} x observation points
    @param {int} M polynomial basis degree
    @return {np.2darray} Phi^T Phi gram matrix
            {np.2darray} Phi design matrix

    '''
    Phi = design_matrix(x, M)
    gram = Phi.T @ Phi
    return gram, Phi

def log_likelihood(x, t, M, alpha, beta):
    '''Evaluate the log marginal likelihood (evidence) for fixed alpha, beta

    Implements the evidence formula (Bishop, PRML, eq. 3.86).  Note the
    polynomial basis includes a bias term, so the number of weights is
    D = M+1; the alpha term must therefore be (M+1)/2 * ln(alpha), not
    M/2 * ln(alpha) -- this addresses the file-top TODO about 'M'.

    @param {np.array1d} x observation points
    @param {np.array1d} t observation values
    @param {int} M model complexity (polynomial degree)
    @param {float} alpha prior distribution precision
    @param {float} beta model precision
    @return {float} log likelihood

    '''
    N = len(x)
    D = M + 1  # number of weights, bias included
    # Design matrix with columns x**0 .. x**M (same layout as design_matrix())
    Phi = np.vander(np.asarray(x, dtype=np.float64), D, increasing=True)
    A = beta * Phi.T.dot(Phi) + alpha * np.eye(D)
    # Posterior mean m_N = beta * A^{-1} Phi^T t; solve() is more accurate
    # and cheaper than forming the explicit inverse (A is symmetric).
    m = beta * np.linalg.solve(A, Phi.T.dot(t))
    Em = 0.5 * (beta * np.linalg.norm(t - Phi.dot(m))**2 + alpha * m.dot(m))
    # slogdet avoids overflow/underflow of the raw determinant
    _, logdet_A = np.linalg.slogdet(A)
    return N/2.0 * np.log(beta) \
        - N/2.0 * np.log(2*np.pi) \
        + D/2.0 * np.log(alpha) \
        - 0.5 * logdet_A \
        - Em



def log_likelihood_plot(x, t, complexities, alpha, beta):
    '''Plot log-likelihood vs different model complexities

    @param {iter[float]} x observation points
    @param {iter[float]} t observation values
    @param {iter[int]} complexities list of complexities to consider
    @param {float} alpha prior distribution precision
    @param {float} beta model precision
    @return {matplotlib.figure.Figure} figure containing the plot
            {matplotlib.axes.Axes} axes the curve was drawn on

    '''
    # Evaluate the evidence at every candidate complexity.
    # (Removed unused xmin/xmax/xt locals that were dead code.)
    xp = complexities
    yp = [log_likelihood(x, t, m, alpha, beta) for m in complexities]

    fig, ax = plt.subplots()
    ax.plot(xp, yp)
    ax.grid(True)

    # Annotate the hyperparameters used so saved figures are self-describing
    param_string = r'$\alpha$ = {:.3g}, $\beta$ = {:.3g}'.format(alpha, beta)
    ax.set_title('Log likelihood\n' + param_string)
    ax.set_xlabel('Complexity, M')
    ax.set_ylabel(r'$\log \, p(\mathbf{t} | \alpha, \beta)$')
    return fig, ax

def train(x, t, M, alpha, beta):
    '''Evaluate the posterior weight distribution for a given model complexity

    @param {np.array1d} x input observation points
    @param {np.array1d} t observation data
    @param {int} M model complexity
    @param {float} alpha prior distribution precision
    @param {float} beta model precision
    @return {np.array1d} mean of the predictive distribution,
                         excluding the basis functions
            {np.array1d} covariance deviation matrix, excluding
                         basis functions

    '''
    G, Phi = gram_matrix(x, M)
    Phi_t = Phi.transpose().dot(t)
    # Posterior precision S_N^{-1} = alpha*I + beta*Phi^T Phi (Bishop eq. 3.54)
    Sn_inv = alpha * np.eye(M+1, M+1) + beta * G
    # The full covariance matrix is returned, so the explicit inverse is needed
    Sn = np.linalg.inv(Sn_inv)
    # Posterior mean m_N = beta * S_N * Phi^T t (Bishop eq. 3.53)
    # (removed unused tr_Sn local - dead code)
    mn = beta * Sn.dot(Phi_t)
    return mn, Sn


def predict(x, mn, Sn):
    '''Predict observation values at `x`

    Accepts a scalar, an array, or any iterable of locations; the original
    implementation raised TypeError for the documented scalar-float case.

    @param {float|np.1darray} x location(s) to evaluate the model
    @param {np.1darray} mn weight mean
    @param {np.2darray} Sn weight covariance, excluding basis functions
    @return {np.array} Mean (expected value) of the predicted value at @x
            {np.array} Variance of the predicted value at @x

    '''
    # Normalize input: scalar -> length-1 array; generators/lists/arrays -> 1d array
    if np.isscalar(x):
        xs = np.array([x], dtype=np.float64)
    else:
        xs = np.asarray(list(x), dtype=np.float64)
    M = len(mn)

    mean, var = [], []
    for x_ in xs:
        phi = basis(x_, M - 1)
        mean.append(mn.dot(phi))            # E[t | x] = m_N^T phi(x)
        var.append(phi.dot(Sn).dot(phi))    # weight-uncertainty part of the variance
    return np.array(mean), np.array(var)

def optimize_hyperparameters(x, t, M, iter_count = 10):
    '''Variational-Bayes point estimates of the precisions alpha, beta

    Iterates the factorized-posterior updates
        q(alpha) = Gamma(alpha | aN, bN)
        q(beta)  = Gamma(beta  | cN, dN)
    (cf. Bishop, PRML, sec. 10.3) and returns the posterior means.
    Fixes vs the previous version: the Gamma updates carry the 1/2
    factors of the standard equations, and the weight count is M+1
    (the basis includes a bias term; see the file-top TODO).

    @param {np.1darray} x observation points
    @param {np.1darray} t observation values
    @param {int} M polynomial basis degree
    @param {int} iter_count number of update sweeps
    @return {float} alpha estimated prior precision
            {float} beta estimated noise precision

    '''
    N = len(x)
    D = M + 1  # number of weights, bias term included
    # Broad Gamma priors: p(alpha) = Gamma(a0, b0), p(beta) = Gamma(c0, d0)
    a0, b0, c0, d0 = 1, 1, 1, 1
    alpha = 1
    beta = 10
    G, Phi = gram_matrix(x, M)
    Phi_t = Phi.transpose().dot(t)
    tt = t.dot(t)

    for i in range(iter_count):
        # Gaussian factor q(w) = N(w | mn, Sn) given current alpha, beta
        Sn_inv = alpha * np.eye(D) + beta * G
        Sn = np.linalg.inv(Sn_inv)
        mn = beta * Sn.dot(Phi_t)
        # Second moments needed by the Gamma updates
        exp_wtw = mn.dot(mn) + np.trace(Sn)   # E[w^T w]
        exp_wwt = np.outer(mn, mn) + Sn       # E[w w^T]
        # Gamma factor updates (note the 1/2 factors and D = M+1 weights)
        aN = a0 + 0.5 * D
        bN = b0 + 0.5 * exp_wtw
        cN = c0 + 0.5 * N
        # E[||t - Phi w||^2] = t^T t - 2 mn^T Phi^T t + tr(Phi^T Phi E[w w^T])
        dN = d0 + 0.5 * (tt - 2*mn.dot(Phi_t) + np.sum(G * exp_wwt))
        # Point estimates: posterior means of the Gamma factors
        alpha = aN/bN
        beta = cN/dN
        print(i, alpha, beta)
    return alpha, beta


def main():
    '''Run the full experiment: fit hyperparameters, plot the evidence
    curve and the fitted model, and save both figures as SVG.'''
    # Polynomial degree for the final fitted model
    M = 3
    x, t = get_data(50)
    # Variational point estimates of the precisions from the data
    alpha, beta = optimize_hyperparameters(x, t, M, 10)

    # Evidence vs. complexity for degrees 0..15 at the fitted alpha, beta
    xp = np.arange(0, 16)
    fig, ax = log_likelihood_plot(x, t, xp,
                                  alpha, beta)
    # Transparent figure background (e.g. for embedding in a web page)
    fig.patch.set_alpha(0.0)
    # Save relative to this script's directory, not the CWD
    __dirname = os.path.dirname(os.path.realpath(__file__))
    fn = '../img/variational-complexity-likelihood-a{:.3g}-b{:.3g}.svg'.format(alpha, beta)
    fn = os.path.join(__dirname, fn)
    plt.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none',
                bbox_inches=0)

    # Plot model: dataset, true sine target, posterior mean and +-1 sigma band
    fig, ax = plt.subplots()
    xmin, xmax = np.min(x), np.max(x)
    xt = np.linspace(xmin, xmax, 50)
    w, Sn = train(x, t, M, alpha, beta)
    y, var = predict(xt, w, Sn)
    ax.plot(x, t, linestyle='', marker='o', markersize=3, label='Dataset')
    ax.plot(xt, np.sin(xt), linestyle='--', label='Target')
    ax.plot(xt, y, linewidth=2, label='Model Mean')
    # Shade one standard deviation of the predictive distribution
    ax.fill_between(xt,
                    y - 1*np.sqrt(var),
                    y + 1*np.sqrt(var),
                    color='red', alpha=0.30,
                    label=r'$E \pm {}\sigma$'.format(1))

    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    param_string = r'$\alpha$ = {:.3g}, $\beta$ = {:.3g}'.format(alpha, beta)
    ax.set_title('Models \n' + param_string)
    ax.grid(True)
    ax.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='small')

    fn = '../img/variational-opt-model-m{}.svg'.format(M)
    __dirname = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.join(__dirname, fn)
    fig.patch.set_alpha(0.0)
    plt.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none',
                bbox_inches=0)


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not on import
    main()