#!/usr/bin/python3


import os
import numpy as np
import matplotlib.pyplot as plt


def get_data(count=10, sigma=0.3, seed=0):
    '''Generate noisy samples of a sine curve on [-pi, pi]

    A fresh generator is seeded on every call, so two calls with the same
    arguments return exactly the same data.

    @param {int} count number of points returned
    @param {float} sigma standard deviation of the Gaussian noise
    @param {int} seed seed of the noise generator
    @return {np.1darray} observation points, evenly spaced on [-pi, pi]
            {np.1darray} perturbed observation values, sin(x) + noise

    '''
    rng = np.random.RandomState(seed)
    x = np.linspace(-np.pi, np.pi, count)
    y = np.sin(x) + rng.normal(0, sigma, len(x))
    return x, y

def basis(x, M):
    '''Evaluate the first M monomials 1, x, x**2, ..., x**(M-1)

    The constant (bias) term x**0 is included, so the highest degree is
    M-1, not M.

    @param {float} x location to evaluate the basis functions
    @param {int} M number of basis functions
    @return {np.1darray} basis function values at `x`, length M

    '''
    # Float exponents keep the result float64 even when `x` is an integer
    exponents = np.arange(M, dtype=np.float64)
    return np.power(x, exponents)

def design_matrix(x, M):
    r'''Generate the \Phi design matrix

    Row i holds the M basis function values evaluated at x[i].

    @param {np.1darray} x observation points
    @param {int} M number of basis functions (columns of the matrix)
    @return {np.2darray} Phi matrix of shape (len(x), M)

    '''
    phi = np.zeros((len(x), M), dtype=np.float64)
    for row, point in enumerate(x):
        phi[row, :] = basis(point, M)
    return phi

def gram_matrix(x, M):
    r'''Return the \Phi^{t} * \Phi matrix together with \Phi itself

    The design matrix is returned as well so callers do not have to
    build it a second time.

    @param {np.1darray} x observation points
    @param {int} M number of basis functions
    @return {np.2darray} Phi^{t}*Phi gram matrix, shape (M, M)
            {np.2darray} Phi design matrix, shape (len(x), M)

    '''
    phi = design_matrix(x, M)
    return phi.T.dot(phi), phi

def predict(x, w):
    '''Evaluate the polynomial model with weights `w` at the points `x`

    Computes sum_i w[i] * x**i for every entry of `x`.

    @param {np.array1d} x values to evaluate the model at (a scalar or an
           array of any shape is accepted as well)
    @param {np.array1d} w model weights, w[i] multiplies x**i
    @return {np.array1d} model values evaluated at `x`, same shape as `x`

    '''
    x = np.asarray(x, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64)
    # A trailing axis of length len(w) holds the monomials x**0 .. x**(M-1);
    # the dot product with `w` then sums that axis away again.
    monomials = np.power(x[..., np.newaxis], np.arange(len(w)))
    return monomials.dot(w)

def train(x, t, M, alpha_beta):
    '''Compute the regularized least-squares weights

    Solves (Phi^t Phi + alpha_beta * I) w = Phi^t t, where Phi is the
    design matrix built from `x` with M basis functions.

    @param {np.array1d} x input observation points
    @param {np.array1d} t observation data
    @param {int} M model complexity (number of weights)
    @param {float} alpha_beta alpha/beta ratio, the regularizer coefficient
    @return {np.array1d} model weights, length M

    '''
    gram, phi = gram_matrix(x, M)
    # The regularizer adds alpha_beta to the diagonal of the gram matrix
    lhs = gram + alpha_beta * np.eye(M)
    rhs = phi.T.dot(t)
    # np.linalg.solve raises LinAlgError when `lhs` is singular
    return np.linalg.solve(lhs, rhs)


def plot_results(x, t, M, alpha_beta):
    '''Draw the observations, the sine target and two fitted models

    The model is fitted twice on (x, t): once with a zero regularizer and
    once with `alpha_beta`; both curves are drawn on 100 evenly spaced
    points in [-pi, pi].

    @param {np.array1d} x input observation points
    @param {np.array1d} t observation data
    @param {int} M model complexity
    @param {float} alpha_beta alpha/beta ratio
    @return {matplotlib.figure.Figure} Matplotlib figure object
            {matplotlib.axes.Axes} Matplotlib axes object

    '''
    fig, ax = plt.subplots()
    ax.set_xlabel('$x$')
    ax.set_ylabel('$y$')
    ax.set_title(
        r'$M = {}, N = {}, \alpha/\beta = {}$'.format(M, len(x), alpha_beta))

    grid = np.linspace(-np.pi, np.pi, 100)
    sine = np.sin(grid)

    # (label, curve) pairs: unregularized fit first, regularized fit second
    fits = [
        (r'Model, $\alpha/\beta$=0',
         predict(grid, train(x, t, M, 0))),
        (r'Model, $\alpha/\beta$={}'.format(alpha_beta),
         predict(grid, train(x, t, M, alpha_beta))),
    ]

    ax.plot(x, t, linestyle='', marker='o', label='Data')
    ax.plot(grid, sine, linestyle='--', label='Target')
    for label, curve in fits:
        ax.plot(grid, curve, linestyle='-', label=label)

    ax.grid(True)
    ax.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='small')

    # Fully transparent figure background
    fig.patch.set_alpha(0.0)
    return fig, ax


def main():
    '''Fit and plot the regularized model for 10 and 100 data points

    For each data set size N the figure is written to
    ../img/regularized_regression_n<N>.svg, relative to the directory
    containing this script.

    '''
    # Regularizer coefficient: alpha/beta
    ALPHA_BETA = 10
    # Model complexity
    M = 10

    script_dir = os.path.dirname(os.path.realpath(__file__))

    for N in (10, 100):
        x, t = get_data(N)
        fig, _ = plot_results(x, t, M, ALPHA_BETA)
        fn = '../img/regularized_regression_n{}.svg'.format(N)
        fn = os.path.join(script_dir, fn)
        # Save the figure just built rather than whichever one is current.
        # NOTE(review): bbox_inches=0 is not a documented value ('tight',
        # a Bbox or None are); kept unchanged -- confirm the intent.
        fig.savefig(fn,
                    facecolor=fig.get_facecolor(),
                    edgecolor='none',
                    bbox_inches=0)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()