#!/usr/bin/python3

'''Evidence approximation for Bayesian polynomial regression.

Fits polynomial models of degree M to noisy sine data, evaluates the
log model evidence for a range of complexities, and re-estimates the
hyperparameters alpha and beta by maximizing the evidence.

TODO: Revisit Part 1 and redefine 'M'
'''


import os
import numpy as np
import matplotlib.pyplot as plt


def get_data(count=10, sigma=0.3):
    '''Generate random perturbed data from sine model

    @param {int} count number of points returned
    @param {float} sigma standard deviation of the noise
    @return {np.1darray} observation points
            {np.1darray} perturbed observation values
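
    Example (deterministic because the RNG seed is fixed; with sigma=0
    the returned values lie exactly on the sine curve):

    >>> x, y = get_data(count=5, sigma=0.0)
    >>> bool(np.allclose(y, np.sin(x)))
    True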

    '''
    rstate = np.random.RandomState(0)  # fixed seed for reproducibility
    x = np.linspace(-np.pi, np.pi, count)
    y = np.sin(x) + rstate.normal(0, sigma, len(x))
    return x, y

def basis(x, M):
    '''Polynomial basis function of M-th degree

    @param {float} x location to evaluate the basis functions
    @param {int} M degree of the basis functions
    @return {np.1darray} basis function values at `x`
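
    Example:

    >>> bool(np.allclose(basis(2.0, 3), [1.0, 2.0, 4.0, 8.0]))
    True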

    '''

    # Include the bias term: [1, x, x**2, ..., x**M]
    return np.float64(x) ** np.arange(M + 1)

def design_matrix(x, M):
    r'''Generate the design matrix \Phi

    @param {np.1darray} x observation points
    @param {int} M basis order
    @return {np.2darray} Phi matrix
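
    Example (N=2 points with degree M=2 gives an N x (M+1) matrix):

    >>> design_matrix(np.array([0.0, 1.0]), 2).shape
    (2, 3)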

    '''

    # Stack the basis expansion of each observation row-wise; the
    # bias column is included by `basis`
    return np.array([basis(x_, M) for x_ in x])

def gram_matrix(x, M):
    r'''Return the Gram matrix \Phi^T \Phi

    @param {np.1darray} x observation points
    @param {int} M polynomial basis degree
    @return {np.2darray} \Phi^T \Phi Gram matrix
            {np.2darray} \Phi design matrix
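
    Example (the Gram matrix is symmetric by construction):

    >>> G, Phi = gram_matrix(np.array([0.0, 1.0, 2.0]), 2)
    >>> bool(np.allclose(G, G.T))
    True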

    '''
    p = design_matrix(x, M)
    return p.transpose().dot(p), p

def log_likelihood(x, t, M, alpha, beta):
    '''Evaluate the log likelihood for fixed alpha, beta

    @param {np.1darray} x observation points
    @param {np.1darray} t observation values
    @param {int} M model complexity
    @param {float} alpha prior distribution precision
    @param {float} beta model precision
    @return {float} log likelihood
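
    Evaluates the log model evidence (cf. Bishop, PRML, eq. 3.86); note
    that the weight vector has M+1 components since the bias is included:

        ln p(t|alpha, beta) = (M+1)/2 ln(alpha) + N/2 ln(beta)
                              - E(m_N) - 1/2 ln|A| - N/2 ln(2 pi)

    with A = alpha*I + beta*Phi^T*Phi and
    E(m_N) = beta/2 ||t - Phi m_N||^2 + alpha/2 m_N^T m_N.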

    '''

    G, Phi = gram_matrix(x, M)
    A = beta*G + alpha*np.eye(M+1)
    # A is symmetric positive definite, so solve A m = beta*Phi^T*t
    # directly instead of forming the (transposed) inverse
    m = np.linalg.solve(A, beta * Phi.transpose().dot(t))
    Em = 0.5 * (beta*np.linalg.norm(t - Phi.dot(m))**2 + alpha*m.dot(m))
    N = len(x)
    # slogdet avoids overflow in det(A) for large or ill-scaled A
    _, logdet_A = np.linalg.slogdet(A)
    # The weight vector has M+1 components (bias included), hence
    # (M+1)/2 on the ln(alpha) term
    return N/2.0 * np.log(beta) \
        - N/2.0 * np.log(2*np.pi) \
        + (M+1)/2.0 * np.log(alpha) \
        - 1/2.0 * logdet_A \
        - Em

def train(x, t, M, alpha, beta):
    '''Evaluate the weights for a given model complexity

    @param {np.1darray} x input observation points
    @param {np.1darray} t observation data
    @param {int} M model complexity
    @param {float} alpha prior distribution precision
    @param {float} beta model precision
    @return {np.1darray} model weights
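
    Example (shapes only; alpha and beta here are arbitrary
    illustrative values):

    >>> x, t = get_data(20)
    >>> train(x, t, 3, 5e-3, 10).shape
    (4,)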

    '''
    gMatrix, dMatrix = gram_matrix(x, M)
    # Posterior mean solves (alpha*I + beta*Phi^T*Phi) w = beta*Phi^T*t
    A = beta*gMatrix + alpha*np.eye(M+1, dtype=np.float64)
    b = beta*dMatrix.transpose().dot(t)
    return np.linalg.solve(A, b)

def predict(x, w):
    '''Predict observation values at `x`

    @param {float|np.1darray} x location(s) to evaluate the model
    @param {np.1darray} w model coefficients
    @return {np.1darray} predicted values at `x`
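
    Example (a hand-checked linear model, y = 1 + 2x):

    >>> w = np.array([1.0, 2.0])
    >>> bool(np.allclose(predict([0.0, 1.0], w), [1.0, 3.0]))
    True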
    '''
    x = np.atleast_1d(x)  # accept scalars as well as iterables
    M = len(w) - 1
    return np.array([np.inner(basis(x_, M), w) for x_ in x])

def model_plots(x, t, complexities, alpha, beta):
    '''Plot the data set and models of different complexities

    @param {np.1darray} x observation points
    @param {np.1darray} t observation values
    @param {iter[int]} complexities model complexities to plot
    @param {float} alpha prior distribution precision
    @param {float} beta model precision
    @return {tuple} (fig, ax) matplotlib figure and axes

    '''
    fig, ax = plt.subplots()
    # Plot the data-set
    ax.plot(x, t, linestyle='', marker='o', markersize=3)

    xmin = np.min(x)
    xmax = np.max(x)
    xt = np.linspace(xmin, xmax, 50)

    # Plot models with diff complexities trained on data set
    for m in complexities:
        w = train(x, t, m, alpha, beta)
        yt = predict(xt, w)
        ax.plot(xt, yt, label='M = {}'.format(m))

    param_string = r'$\alpha$ = {:.3g}, $\beta$ = {:.3g}'.format(alpha, beta)
    ax.set_title('Models\n' + param_string)
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')

    ax.grid(True)
    ax.legend(loc='best', fancybox=True,
              framealpha=0.5, fontsize='medium')
    return fig, ax

def log_likelihood_plot(x, t, complexities, alpha, beta):
    '''Plot log-likelihood vs different model complexities

    @param {iter[float]} x observation points
    @param {iter[float]} t observation values
    @param {iter[int]} complexities list of complexities to consider
    @param {float} alpha prior distribution precision
    @param {float} beta model precision
    @return {tuple} (fig, ax) matplotlib figure and axes

    '''

    # Log evidence for each complexity under consideration
    xp = complexities
    yp = [log_likelihood(x, t, m, alpha, beta) for m in complexities]

    fig, ax = plt.subplots()
    ax.plot(xp, yp)
    ax.grid(True)

    param_string = r'$\alpha$ = {:.3g}, $\beta$ = {:.3g}'.format(alpha, beta)
    ax.set_title('Log likelihood\n' + param_string)
    ax.set_xlabel('Complexity, M')
    ax.set_ylabel(r'$\log \, p(\mathbf{t} | \alpha, \beta)$')
    return fig, ax


def create_images(alpha, beta):
    '''Create evidence-complexity related images

    @param {float} alpha prior distribution precision
    @param {float} beta model precision
    @return {None}
    '''

    x, t = get_data(50)
    xp = np.arange(0, 16)

    __dirname = os.path.dirname(os.path.realpath(__file__))
    fig, ax = log_likelihood_plot(x, t, xp,
                                  alpha, beta)
    fig.patch.set_alpha(0.0)
    fn = '../img/evidence-complexity-likelihood-a{:.3g}-b{:.3g}.svg'.format(alpha, beta)
    fn = os.path.join(__dirname, fn)
    plt.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none')

    fig, ax = model_plots(x, t, [2, 3, 5, 10, 15], alpha, beta)
    fig.patch.set_alpha(0.0)
    fn = '../img/evidence-complexity-models-a{:.3g}-b{:.3g}.svg'.format(alpha, beta)
    fn = os.path.join(__dirname, fn)
    plt.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none')
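

# Evidence-approximation update equations used below (cf. Bishop, PRML,
# eqs. 3.91-3.95), where lambda_i are the eigenvalues of beta*Phi^T*Phi:
#
#   gamma  = sum_i lambda_i / (alpha + lambda_i)
#   alpha  = gamma / (m_N^T m_N)
#   1/beta = ||t - Phi m_N||^2 / (N - gamma)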

def optimize_hyperparameters(x, t, M, iter_count=10):
    '''Re-estimate alpha and beta by maximizing the evidence

    @param {np.1darray} x observation points
    @param {np.1darray} t observation values
    @param {int} M model complexity
    @param {int} iter_count number of re-estimation iterations
    @return {tuple} (alpha, beta) optimized hyperparameters

    '''
    N = len(x)
    # Initial guess
    alpha = 1
    beta = 10
    G, Phi = gram_matrix(x, M)
    # G is symmetric, so eigvalsh gives real eigenvalues; the eigenvalues
    # of beta*G are beta times these, so compute them only once
    lambdas_bar = np.linalg.eigvalsh(G)

    for i in range(iter_count):
        Sn_inv = alpha * np.eye(M+1) + beta * G
        Sn = np.linalg.inv(Sn_inv)
        lambdas = beta * lambdas_bar
        mn = beta * Sn.dot(Phi.transpose().dot(t))
        gamma = np.sum(lambdas/(alpha + lambdas))
        alpha = gamma/(mn.dot(mn))
        beta_inv = 1/(N - gamma) * np.linalg.norm(t - Phi.dot(mn))**2
        beta = 1.0/beta_inv
    return alpha, beta

def main():
    create_images(5E-6, 10)
    create_images(5E-6, 1)

    M = 3
    x, t = get_data(50)
    alpha, beta = optimize_hyperparameters(x, t, M, 10)
    create_images(alpha, beta)

    # Plot model
    fig, ax = plt.subplots()
    xmin, xmax = np.min(x), np.max(x)
    xt = np.linspace(xmin, xmax, 50)
    w = train(x, t, M, alpha, beta)
    y = predict(xt, w)
    ax.plot(x, t, linestyle='', marker='o', markersize=3, label='Dataset')
    ax.plot(xt, y, linewidth=2, label='Model')
    ax.plot(xt, np.sin(xt), linestyle='--', label='Target')

    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    param_string = r'$\alpha$ = {:.3g}, $\beta$ = {:.3g}'.format(alpha, beta)
    ax.set_title('Model\n' + param_string)
    ax.grid(True)
    ax.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='small')

    fn = '../img/evidence-opt-model-m{}.svg'.format(M)
    __dirname = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.join(__dirname, fn)
    fig.patch.set_alpha(0.0)
    plt.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none')



if __name__ == '__main__':
    main()