#!/usr/bin/python3

import os
import numpy as np
import matplotlib.pyplot as plt


# Module-level generator with a fixed seed (0) so the sampled curves are the
# same on every run; sample_gauss_process draws from it.
RSTATE = np.random.RandomState(0)

def gauss_kernel(x, xp, sigma=0.1):
    '''Gaussian (squared-exponential) kernel, see Eq. (6.23)

    k(x, x') = exp(-||x - x'||^2 / (2 * sigma^2))

    @param {float|array-like} x first argument (scalar or vector)
    @param {float|array-like} xp second argument, same shape as x
    @param {float} sigma width of the kernel (a standard deviation; sigma**2
           is the variance appearing in the exponent's denominator)
    @return {float} kernel value at the specified points
    '''
    # np.subtract also accepts plain lists/tuples, not only numpy arrays
    return np.exp(-np.linalg.norm(np.subtract(x, xp))**2/(2*sigma**2))


def exp_kernel(x, xp, theta=1):
    '''Exponential kernel, see Eq. (6.65)

    k(x, x') = exp(-theta * ||x - x'||)

    @param {float|array-like} x first argument (scalar or vector)
    @param {float|array-like} xp second argument, same shape as x
    @param {float} theta decay rate; larger values make the kernel fall off
           faster with distance
    @return {float} kernel value at the specified points
    '''
    # Use the Euclidean norm (as gauss_kernel does) rather than np.abs: for
    # scalars both give |x - xp|, but np.abs would return an element-wise
    # array for vector inputs instead of a scalar kernel value.
    return np.exp(-theta*np.linalg.norm(np.subtract(x, xp)))


def gram_matrix(x, kernel):
    '''Gram matrix, see Eq. 6.55

    Entry (i, j) is kernel(x[i], x[j]); this is the covariance matrix of the
    Gaussian process evaluated at the sample points.

    @param {array} x sample points
    @param {func} kernel callable kernel k(a, b) returning a scalar
    @return {numpy.array} gram matrix with shape (N, N), where N = len(x);
            an empty x yields a (0, 0) matrix
    '''
    n = len(x)
    # Preallocating guarantees a 2-D result even when x is empty
    # (np.array([]) would otherwise have shape (0,)).
    K = np.empty((n, n), dtype=np.float64)
    for i, xi in enumerate(x):
        for j, xj in enumerate(x):
            K[i, j] = kernel(xi, xj)
    return K


def sample_gauss_process(x, kernel, random_state=None):
    '''Sample a zero-mean Gaussian process with the given kernel

    @param {numpy.array} x sample points
    @param {func} kernel callable kernel function k(a, b)
    @param {numpy.random.RandomState} random_state optional generator to
           draw from; when omitted the module-level RSTATE is used, so
           successive calls continue the same seeded stream
    @return {numpy.array} single sample of a Gaussian process, shape (len(x),)
    '''
    rng = RSTATE if random_state is None else random_state

    # Covariance of the process at the sample points
    K = gram_matrix(x, kernel)

    # Zero-mean prior
    mean = np.zeros(len(x))

    # Omitting `size` returns one N-dimensional draw of shape (N,);
    # size=1 would return shape (1, N) instead.
    return rng.multivariate_normal(mean, K)


def plot_gprocess_gauss(x, N):
    '''Plot N Gaussian-process sample paths drawn with the Gaussian kernel.

    The figure is saved to ../img/gprocess-gauss-kernel.svg relative to the
    directory containing this script.

    @param {numpy.array} x sample points (abscissa of every curve)
    @param {int} N number of sample paths to draw
    '''
    fig, ax = plt.subplots()

    # One curve per requested sample (N was previously ignored here).
    for _ in range(N):
        ax.plot(x, sample_gauss_process(x, gauss_kernel))

    ax.grid(True)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Gaussian Kernel')
    # Transparent figure background
    fig.patch.set_alpha(0.0)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.join(script_dir, '../img/gprocess-gauss-kernel.svg')
    # Save this figure explicitly (not the implicit "current" one) and close
    # it afterwards so repeated calls do not accumulate open figures.
    # NOTE(review): bbox_inches=0 is falsy, so no 'tight' cropping is applied;
    # None would express that more clearly.
    fig.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none',
                bbox_inches=0)
    plt.close(fig)


def plot_gprocess_exp(x, N):
    '''Plot N Gaussian-process sample paths drawn with the exponential kernel.

    The figure is saved to ../img/gprocess-exp-kernel.svg relative to the
    directory containing this script.

    @param {numpy.array} x sample points (abscissa of every curve)
    @param {int} N number of sample paths to draw
    '''
    fig, ax = plt.subplots()

    # One curve per requested sample
    for _ in range(N):
        ax.plot(x, sample_gauss_process(x, exp_kernel))

    ax.grid(True)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Exponential Kernel')
    # Transparent figure background
    fig.patch.set_alpha(0.0)
    script_dir = os.path.dirname(os.path.realpath(__file__))
    fn = os.path.join(script_dir, '../img/gprocess-exp-kernel.svg')
    # Save this figure explicitly (not the implicit "current" one) and close
    # it afterwards so repeated calls do not accumulate open figures.
    # NOTE(review): bbox_inches=0 is falsy, so no 'tight' cropping is applied;
    # None would express that more clearly.
    fig.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none',
                bbox_inches=0)
    plt.close(fig)


def main():
    '''Render the Gaussian-kernel and exponential-kernel figures, in that order.'''
    # 100 evenly spaced points on [-1, 1] shared by both plots
    grid = np.linspace(-1, 1, 100)
    samples_per_plot = 5

    for render in (plot_gprocess_gauss, plot_gprocess_exp):
        render(grid, samples_per_plot)


# Render the figures only when run as a script, not when imported.
if __name__ == '__main__':
    main()