#!/usr/bin/python3


import os
import matplotlib.pyplot as plt
import numpy as np


__dirname = os.path.dirname(os.path.realpath(__file__))

def create_dataset(point_count,
                   mean = [0, 0],
                   variance = 1,
                   seed = 0):
    '''Create a 2D data cluster with specified mean and variance

    @param {int} point_count number of points
    @param {list} mean mean of the dataset
    @param {float} variance diagonal variance value
    @param {int} seed random seed number for reproducible results
    @return {tuple} (x, y) coordinate vectors

    '''
    RSTATE = np.random.RandomState(seed)
    cov = np.identity(2) * variance
    x, y = RSTATE.multivariate_normal(mean, cov, point_count).T
    return x, y
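
# Usage sketch (hypothetical values): a single isotropic cluster of 50 points
# centered at (1, 2) with unit variance,
#
#     x, y = create_dataset(50, mean=[1, 2], seed=3)
#
# The covariance is variance * identity, so every generated cluster is
# spherical.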


def train_mus(datasets):
    '''Compute in-class means given training set

    @param {list} datasets list of datasets, each with [x, y] coordinates
                 as vectors
    @return {list} trained list of in-class means
    '''
    mus = []
    for dset in datasets:
        Nk = len(dset[0])
        mu = np.zeros(2)
        for x, y in zip(*dset):
            mu += np.array([x, y]) / Nk
        mus.append(mu)
    return mus
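
# Maximum-likelihood in-class mean computed above:
#     mu_k = (1/N_k) * sum_{n in C_k} x_n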


def train_pis(datasets):
    '''Compute class priors, p(C_k), given training set

    @param {list} datasets list of datasets, each with [x, y] coordinates
                 as vectors
    @return {list} trained list of class priors
    '''
    N = 0
    for dset in datasets:
        N += len(dset[0])
    ret = []
    for dset in datasets:
        Nk = len(dset[0])
        ret.append(Nk / N)
    return ret
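
# Maximum-likelihood class prior computed above:
#     pi_k = N_k / N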


def train_sigmas(datasets, mus):
    '''Compute the shared class covariance given training set

    @param {list} datasets list of datasets, each with [x, y] coordinates
                 as vectors
    @param {list} mus trained list of in-class means
    @return {np.ndarray} (shared) class covariance

    '''
    sigma = np.zeros((2, 2), np.float64)
    N = 0
    for dset in datasets:
        N += len(dset[0])
    for set_id, dset in enumerate(datasets):
        for x, y in zip(*dset):
            v = np.array([x, y]) - mus[set_id]
            sigma += np.outer(v, v)
    # Normalize by the total point count to obtain the maximum-likelihood
    # shared covariance rather than an unnormalized scatter matrix.
    return sigma / N
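
# Shared maximum-likelihood covariance computed above:
#     Sigma = (1/N) * sum_k sum_{n in C_k} (x_n - mu_k)(x_n - mu_k)^T
#           = sum_k (N_k / N) * S_k
# i.e. the weighted average of the per-class covariances S_k.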


def train(datasets):
    '''Find in-class means, (shared) covariance, and priors given training
    dataset

    @param {list} datasets list of datasets, each with [x, y] coordinates
                 as vectors
    @return {list} class priors for each class
            {list} in-class means for each class
            {np.ndarray} (shared) class covariance

    '''
    mus = train_mus(datasets)
    sigma = train_sigmas(datasets, mus)
    pis = train_pis(datasets)
    return pis, mus, sigma
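
# Usage sketch: given a list of clusters (as produced by create_clusters),
#
#     pis, mus, sigma = train(datasets)
#     region = predict(np.array([0.0, 0.0]), pis, mus, sigma)
#
# where `region` is the zero-indexed class with the largest posterior.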


def gaussianfn(mu, sigma2):
    '''Return a gaussian density function

    @param {float|np.ndarray} mu mean, scalar or vector
    @param {float|np.ndarray} sigma2 variance (sigma^2), scalar or covariance matrix
    @return {function} function evaluating the gaussian density at a point

    '''

    # Univariate case: sigma2 is a plain number.
    if isinstance(sigma2, (float, int)):
        coeff = 1.0 / np.sqrt(2 * np.pi * sigma2)
        return lambda x: coeff * np.exp(-(x - mu)**2 / (2 * sigma2))

    # Multivariate case: sigma2 is a DxD covariance matrix.
    D = sigma2.shape[0]
    detSigma = np.linalg.det(sigma2)
    coeff = 1.0 / ((2 * np.pi)**(D / 2.0) * detSigma**0.5)
    sigmaInverse = np.linalg.inv(sigma2)
    return lambda x: coeff * np.exp(-0.5 *
                                    np.inner(x - mu,
                                             np.dot(sigmaInverse, x - mu)))
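
# Multivariate gaussian density evaluated by the returned closure:
#     N(x | mu, Sigma) = 1 / ((2*pi)^(D/2) * |Sigma|^(1/2))
#                        * exp(-1/2 * (x - mu)^T Sigma^{-1} (x - mu))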

def predict(x, pis, mus, sigma):
    '''Return the classification region of input point

    @param {np.ndarray} x input point as an (x, y) vector
    @param {list} pis trained list of class prior distributions
    @param {list} mus trained list of in-class means for training set
    @param {np.ndarray} sigma shared covariance between classes
    @return {int} predicted classification region, indexed at zero
    '''
    # Log of the unnormalized posterior for each class:
    #     a_k = ln(p(x | C_k) * p(C_k))
    aks = []
    for k, mu in enumerate(mus):
        gaussian = gaussianfn(mu, sigma)
        p_xc = gaussian(x)
        p_c = pis[k]
        aks.append(np.log(p_xc * p_c))

    # Normalized posteriors are the softmax of a_k; subtracting the maximum
    # keeps the exponentials numerically stable and leaves argmax unchanged.
    aks = np.array(aks)
    posteriors = np.exp(aks - aks.max())
    posteriors /= posteriors.sum()
    return np.argmax(posteriors)
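
# Note: with a covariance shared across all classes the quadratic term of
# a_k is identical for every class, so a_k is linear in x and the decision
# boundaries between regions are straight lines.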


def classify_grid(x, y, pis, mus, sigma):
    '''Classify a grid defined by [x, y]

    @param {list} x 1-d list of x coordinates
    @param {list} y 1-d list of y coordinates
    @param {list} pis trained list of class prior distributions
    @param {list} mus trained list of in-class means for training set
    @param {np.ndarray} sigma shared covariance between classes
    @return {np.ndarray} 2-dimensional array of int containing
                         region classification id, indexed at zero
    '''

    Z = np.zeros((len(y), len(x)), np.int32)
    for i, x_ in enumerate(x):
        for j, y_ in enumerate(y):
            Z[j, i] = predict(np.array([x_, y_]), pis, mus, sigma)
    return Z
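
# Alternative sketch: the nested loop above is equivalent to evaluating
# predict over the flattened meshgrid,
#
#     X, Y = np.meshgrid(x, y)
#     pts = np.column_stack([X.ravel(), Y.ravel()])
#     Z = np.array([predict(p, pis, mus, sigma) for p in pts]).reshape(X.shape)
#
# The explicit loop is kept for readability.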


def create_clusters(counts, means, seeds):
    '''Create random and reproducible set of clusters with `counts` points
       each

    @param {int} counts number of points per dataset
    @param {list} means list of generated class means
    @param {list} seeds list of seeds used to generate the random clusters
    @return {list} list of (x, y) coordinate vectors, one pair per cluster

    '''
    return [create_dataset(counts, mean, seed) for mean, seed in zip(means, seeds)]


def plot_clusters(datasets):
    '''Plot dataset scatters

    @param {list} datasets list of datasets, each with [x, y]
                  coordinates as vectors
    @return {matplotlib.Figure} figure object
            {matplotlib.Axis} axis object

    '''

    fig, ax = plt.subplots()
    for i, dset in enumerate(datasets):
        ax.plot(dset[0], dset[1],
                linestyle='',
                marker='o',
                markersize=5,
                label='Set {}'.format(i+1))
    return fig, ax


def create_classification_plot(datasets):
    '''Evaluate classification regions and plot on top of dataset

    @param {list} datasets list of datasets, each with [x, y]
                  coordinates as vectors
    @return {matplotlib.Figure} figure object
            {matplotlib.Axis} axis object

    '''
    pis, mus, sigma = train(datasets)

    fig, ax = plot_clusters(datasets)

    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    x = np.linspace(xlim[0], xlim[1], 100)
    y = np.linspace(ylim[0], ylim[1], 100)

    X, Y = np.meshgrid(x, y)
    Z = classify_grid(x, y, pis, mus, sigma)
    cplot = ax.contourf(X, Y, Z, alpha=0.8)

    ax.legend(loc='best', fancybox=True, framealpha=0.5)
    ax.grid(True)
    return fig, ax


def save_fig(fig, fn):
    '''Save a figure to file with a transparent background

    @param {matplotlib.Figure} fig figure object
    @param {str} fn output file name
    '''
    fig.patch.set_alpha(0.0)
    fig.savefig(fn,
                facecolor=fig.get_facecolor(),
                edgecolor='none')

def main():
    means = [[-3, -3],
             [+3, +2],
             [-3, +2],
             [+3, -2]]
    seeds = [1, 4, 8, 10]
    datasets = create_clusters(20, means, seeds)
    fig, ax = create_classification_plot(datasets)
    save_fig(fig,
             os.path.join(__dirname, '../img/prob-gen-ex1.svg'))

    means = [[-5, -5],
             [+5, +5],
             [-5, +5],
             [-3, +3],
             [+5, -5]]
    seeds = [1, 2, 3, 4, 5]
    datasets = create_clusters(20, means, seeds)
    fig, ax = create_classification_plot(datasets)
    save_fig(fig,
             os.path.join(__dirname, '../img/prob-gen-ex2.svg'))


if __name__ == '__main__':
    main()