#!/usr/bin/python3

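'''Nadaraya-Watson kernel regression demo with a Gaussian kernel

Generates noisy samples of a sine curve, evaluates the kernel regression
model on the full data set and on a random sub-sample of it, and saves
both plots as SVG files under ../img (relative to this file).

'''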

import os
import random
import numpy as np
import matplotlib.pyplot as plt


def get_data(count=10, sigma=0.3):
    '''Draw a reproducible noisy sample of the sine curve

    Abscissas come from a uniform distribution over [-2*pi, 2*pi);
    ordinates are sin(x) plus zero-mean Gaussian noise. The generator is
    seeded with 0 on every call, so equal arguments give equal output.

    @param {int} count number of points returned
    @param {float} sigma standard deviation of the noise
    @return {np.1darray} observation points
            {np.1darray} perturbed observation values

    '''
    two_pi = 2*np.pi
    generator = np.random.RandomState(0)

    # Draw order matters for reproducibility: abscissas first, noise second
    points = generator.uniform(-two_pi, two_pi, count)
    noise = generator.normal(0, sigma, len(points))

    return points, np.sin(points) + noise

def gaussianfn(mu, sigma2):
    '''Return a Gaussian probability density function

    A scalar sigma2 gives the univariate density

        N(x | mu, sigma2) = exp(-(x - mu)^2 / (2 sigma2)) / sqrt(2 pi sigma2)

    A (D, D) covariance matrix gives the multivariate density

        N(x | mu, S) = exp(-0.5 (x - mu)^T S^-1 (x - mu))
                       / sqrt((2 pi)^D det(S))

    @param {float|np.ndarray} mu mean: numeric, or a vector of length D
    @param {float|np.ndarray} sigma2 variance (sigma^2): numeric, or a
           (D, D) covariance matrix
    @return {function} density function of x (numeric, or length D vector)
    @throws {ValueError} if a non-scalar sigma2 is not a square matrix

    '''
    # np.ndim also recognises numpy scalars (np.int64, 0-d arrays), which an
    # isinstance check against float/int would send down the matrix branch
    if np.ndim(sigma2) == 0:
        coeff = 1.0/np.sqrt(2*np.pi*sigma2)
        return lambda x: coeff*np.exp(-(x-mu)**2/(2*sigma2))

    covariance = np.asarray(sigma2, dtype=float)
    if covariance.ndim != 2 or covariance.shape[0] != covariance.shape[1]:
        raise ValueError('sigma2 must be a scalar or a square matrix')

    mean = np.asarray(mu, dtype=float)
    # Dimension is the number of rows; .size would give D*D for a matrix
    D = covariance.shape[0]
    # The determinant belongs in the denominator of the normalisation
    coeff = 1.0/np.sqrt((2*np.pi)**D * np.linalg.det(covariance))
    sigmaInverse = np.linalg.inv(covariance)

    def density(x):
        diff = np.asarray(x, dtype=float) - mean
        return coeff*np.exp(-0.5*np.inner(diff, np.dot(sigmaInverse, diff)))

    return density

def _as_float_vector(values):
    '''Convert an array or any iterable of numbers to a float array'''
    if not isinstance(values, np.ndarray):
        # list() lets tuples from zip() and one-shot iterators through
        values = list(values)
    return np.asarray(values, dtype=float)


def predict(x_values, x, t, sigma2):
    '''Evaluate model mean and std using Gaussian kernel regression

    Nadaraya-Watson model with an isotropic Gaussian kernel of variance s2:

                        exp{-1/(2*s2) * (x - x_n)^2}
        k(x, x_n) = -------------------------------------
                     sum_m{ exp{-1/(2*s2) * (x - x_m)^2} }

        mean(x) = sum_n{ k(x, x_n) * t_n }
        var(x)  = s2 + sum_n{ k(x, x_n) * t_n^2 } - mean(x)^2

    The conditional distribution of t given x is a mixture of Gaussians
    N(t | t_n, s2) weighted by k(x, x_n); the variance above is the
    variance of that mixture.

    @param {iterator} x_values 1d set of points to evaluate the model
    @param {iterator} x 1d set of observation points
    @param {iterator} t 1d set of observation values
    @param {float} sigma2 variance (sigma^2) of the Gaussian kernel
    @return {np.1darray} model mean
            {np.1darray} model standard deviation
    @throws {ValueError} if sigma2 is not positive, there are no
            observations, or x and t differ in length

    '''
    if sigma2 <= 0:
        raise ValueError('sigma2 must be positive')

    x_eval = _as_float_vector(x_values)
    x_obs = _as_float_vector(x)
    t_obs = _as_float_vector(t)
    if x_obs.size == 0:
        raise ValueError('at least one observation is required')
    if x_obs.shape != t_obs.shape:
        raise ValueError('x and t must have the same length')

    # Row i holds the kernel exponents of evaluation point i against
    # every observation
    exponent = -(x_eval[:, np.newaxis] - x_obs[np.newaxis, :])**2/(2.0*sigma2)
    # Shifting each row by its maximum leaves the normalised weights
    # unchanged but keeps at least one term at exp(0) = 1, so points far
    # from all observations no longer produce 0/0
    exponent -= exponent.max(axis=1, keepdims=True)
    weights = np.exp(exponent)
    weights /= weights.sum(axis=1, keepdims=True)

    mu = np.dot(weights, t_obs)
    var = sigma2 + np.dot(weights, t_obs**2) - mu**2
    # var >= sigma2 mathematically; the clamp only guards against rounding
    std = np.sqrt(np.maximum(var, 0))

    return mu, std

def predict_std(x_values, means, x, t, sigma2):
    '''Evaluate a ratio of Gaussian sums at each (point, mean) pair

    For every evaluation point x_ paired with its mean mu_, the numerator
    sums the 2d Gaussian from gaussianfn (zero mean, covariance
    sigma2 * I) evaluated at (x_ - x_n, mu_ - t_n) over all observations;
    the denominator sums exp{-1/(2*sigma2) * (x_ - x_n)^2} over the same
    observations.

    NOTE(review): originally marked "DEPRECATED?" - whether this ratio is
    a valid standard deviation has not been verified.

    @param {iterator} x_values 1d set of points to evaluate
    @param {iterator} means model mean at each point of x_values
    @param {iterator} x 1d set of observation points
    @param {iterator} t 1d set of observation values
    @param {float} sigma2 variance (sigma^2) of the Gaussian kernel
    @return {np.1darray} one ratio per evaluation point

    '''
    joint_density = gaussianfn(np.array([0, 0]),
                               np.eye(2, 2) * sigma2)
    scale = -1./(2*sigma2)

    ratios = []
    for query, mean in zip(x_values, means):
        joint_sum = 0
        kernel_sum = 0
        # Accumulate both sums in one pass over the observations
        for obs_x, obs_t in zip(x, t):
            joint_sum += joint_density(np.array([query-obs_x, mean-obs_t]))
            kernel_sum += np.exp(scale * (query - obs_x)**2)
        ratios.append(joint_sum/kernel_sum)

    return np.array(ratios)

def plot_full_model(x, t, x_plot, N, sigma2, std_offset):
    '''Plot the Nadaraya Watson model fitted on the whole data set

    Draws the data, the model mean, the sine target and a band of
    +/- std_offset standard deviations around the mean, then writes the
    figure to ../img/nadaraya-watson-gaussian.svg relative to this file.

    @param {iterator} x training points
    @param {iterator} t training values
    @param {iterator} x_plot values to plot model
    @param {int} N number of training points (only shown in the title)
    @param {float} sigma2 variance (sigma^2) of the Gaussian kernel
    @param {float} std_offset Plotted standard deviation (half) width

    '''

    mean, spread = predict(x_plot, x, t, sigma2)
    half_width = spread*std_offset
    band_label = r'$E \pm {}\sigma$'.format(std_offset)

    figure, axes = plt.subplots()

    # Curves and markers, in legend order
    axes.plot(x, t, linestyle='', marker='o', markersize=3, label='Data set')
    axes.plot(x_plot, mean, linewidth=2, label='Mean')
    axes.plot(x_plot, np.sin(x_plot), linestyle='--', label='Target')
    axes.fill_between(x_plot,
                      mean - half_width,
                      mean + half_width,
                      color='red', alpha=0.35,
                      label=band_label)

    # Axes decoration
    axes.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='small')
    axes.set_title(r'$\sigma^2 = {}, N = {}$'.format(sigma2, N))
    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_ylim([-1.5, 1.5])
    axes.grid(True)

    # Save next to this script with a transparent figure background; the
    # alpha must be set before the face colour is read back
    script_dir = os.path.dirname(os.path.realpath(__file__))
    target = os.path.join(script_dir, '../img/nadaraya-watson-gaussian.svg')
    figure.patch.set_alpha(0.0)
    plt.savefig(target,
                facecolor=figure.get_facecolor(),
                edgecolor='none',
                bbox_inches=0)


def plot_sampled_model(x, t, x_plot, N, sigma2, std_offset):
    '''Plot the Nadaraya Watson model fitted on a random sub-sample

    Picks N (x, t) pairs without replacement using a locally seeded
    generator, evaluates the model on the picked pairs only, and writes
    the figure to ../img/nadaraya-watson-gaussian-sampled.svg relative to
    this file.

    @param {iterator} x training points
    @param {iterator} t training values
    @param {iterator} x_plot values to plot model
    @param {int} N number of points to pick from the training data
    @param {float} sigma2 variance (sigma^2) of the Gaussian kernel
    @param {float} std_offset Plotted standard deviation (half) width
    @return None
    '''

    # A private, seeded generator keeps the pick reproducible without
    # touching the global random state
    picker = random.Random(0)
    picked = picker.sample(list(zip(x, t)), N)
    x_sampled, t_sampled = zip(*picked)

    # Model mean and standard deviation from the picked points only
    mean, spread = predict(x_plot, x_sampled, t_sampled, sigma2)
    half_width = spread*std_offset
    band_label = r'$E \pm {}\sigma$'.format(std_offset)

    figure, axes = plt.subplots()

    # Full data set, picked points, mean, target and deviation band
    axes.plot(x, t, linestyle='', marker='o', markersize=3, label='Data set')
    axes.plot(x_sampled, t_sampled,
              linestyle='', marker='x', markeredgewidth=2,
              color='r', label='Picked Points')
    axes.plot(x_plot, mean, linewidth=2, label='Mean')
    axes.plot(x_plot, np.sin(x_plot), linestyle='--', label='Target')
    axes.fill_between(x_plot,
                      mean - half_width,
                      mean + half_width,
                      color='red', alpha=0.35,
                      label=band_label)

    # Axes decoration
    axes.legend(loc='best', fancybox=True, framealpha=0.5, fontsize='small')
    axes.set_title(r'$\sigma^2 = {}, N = {}$'.format(sigma2, N))
    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_ylim([-1.5, 1.5])
    axes.grid(True)

    # Save next to this script with a transparent figure background; the
    # alpha must be set before the face colour is read back
    script_dir = os.path.dirname(os.path.realpath(__file__))
    target = os.path.join(script_dir,
                          '../img/nadaraya-watson-gaussian-sampled.svg')
    figure.patch.set_alpha(0.0)
    plt.savefig(target,
                facecolor=figure.get_facecolor(),
                edgecolor='none',
                bbox_inches=0)


def main():
    '''Generate the data set and render both figures'''
    sample_count = 200      # number of samples
    kernel_variance = 1     # sigma^2 of the Gaussian kernel
    band_half_width = 0.25  # plotted half width, in standard deviations

    # Noisy sine observations and the grid the model is evaluated on
    x, t = get_data(sample_count)
    grid = np.linspace(-2*np.pi, 2*np.pi, 100)

    # Model on the full data set
    plot_full_model(x, t, grid,
                    sample_count, kernel_variance, band_half_width)

    # Model on a fifth of the data set
    plot_sampled_model(x, t, grid,
                       int(sample_count/5), kernel_variance, band_half_width)


# Render the figures only when the file is executed as a script, so that
# importing the module has no side effects
if __name__ == '__main__':
    main()